def run(self): print("Running on", self.a.device) self.set_device(self.a.device) np.random.seed(self.a.seed) torch.manual_seed(self.a.seed) torch.backends.cudnn.benchmark = True # build text event vocab and ee_role vocab WordsField = Field(lower=True, include_lengths=True, batch_first=True) PosTagsField = Field(lower=True, batch_first=True) EntityLabelsField = MultiTokenField(lower=False, batch_first=True) AdjMatrixField = SparseField(sequential=False, use_vocab=False, batch_first=True) EntitiesField = EntityField(lower=False, batch_first=True, use_vocab=False) # only for ee LabelField = Field(lower=False, batch_first=True, pad_token='0', unk_token=None) EventsField = EventField(lower=False, batch_first=True) SENTIDField = SparseField(sequential=False, use_vocab=False, batch_first=True) colcc = 'combined-parsing' train_ee_set = ACE2005Dataset( path=self.a.train_ee, fields={ "sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField), "pos-tags": ("POSTAGS", PosTagsField), "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), colcc: ("ADJM", AdjMatrixField), "golden-event-mentions": ("LABEL", LabelField), "all-events": ("EVENT", EventsField), "all-entities": ("ENTITIES", EntitiesField) }, amr=False, keep_events=1) pretrained_embedding = Vectors(self.a.webd, ".", unk_init=partial(torch.nn.init.uniform_, a=-0.15, b=0.15)) LabelField.build_vocab(train_ee_set.LABEL, vectors=pretrained_embedding) EventsField.build_vocab(train_ee_set.EVENT, vectors=pretrained_embedding) # consts.O_LABEL = LabelField.vocab.stoi[consts.O_LABEL_NAME] # # print("O label is", consts.O_LABEL) # consts.ROLE_O_LABEL = EventsField.vocab.stoi[consts.ROLE_O_LABEL_NAME] # # print("O label for AE is", consts.ROLE_O_LABEL) # create testing set if self.a.test_sr: log('loading corpus from %s' % self.a.test_sr) transform = transforms.Compose([ transforms.Resize(256), transforms.RandomHorizontalFlip(), transforms.RandomCrop(224), transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) ]) vocab_noun = Vocab(os.path.join(self.a.vocab, 'vocab_situation_noun.pkl'), load=True) vocab_role = Vocab(os.path.join(self.a.vocab, 'vocab_situation_role.pkl'), load=True) vocab_verb = Vocab(os.path.join(self.a.vocab, 'vocab_situation_verb.pkl'), load=True) # train_sr_set = ImSituDataset(self.a.image_dir, vocab_noun, vocab_role, vocab_verb, # EventsField.vocab.stoi, LabelField.vocab.stoi, # self.a.imsitu_ontology_file, # self.a.train_sr, self.a.verb_mapping_file, # self.a.object_class_map_file, self.a.object_detection_pkl_file, # self.a.object_detection_threshold, # transform, filter_irrelevant_verbs=self.a.filter_irrelevant_verbs, # load_object=self.a.add_object, filter_place=self.a.filter_place) # dev_sr_set = ImSituDataset(self.a.image_dir, vocab_noun, vocab_role, vocab_verb, # EventsField.vocab.stoi, LabelField.vocab.stoi, # self.a.imsitu_ontology_file, # self.a.dev_sr, self.a.verb_mapping_file, # self.a.object_class_map_file, self.a.object_detection_pkl_file, # self.a.object_detection_threshold, # transform, filter_irrelevant_verbs=self.a.filter_irrelevant_verbs, # load_object=self.a.add_object, filter_place=self.a.filter_place) test_sr_set = ImSituDataset( self.a.image_dir, vocab_noun, vocab_role, vocab_verb, EventsField.vocab.stoi, LabelField.vocab.stoi, self.a.imsitu_ontology_file, self.a.test_sr, self.a.verb_mapping_file, self.a.object_class_map_file, self.a.object_detection_pkl_file, self.a.object_detection_threshold, transform, filter_irrelevant_verbs=self.a.filter_irrelevant_verbs, 
load_object=self.a.add_object, filter_place=self.a.filter_place) embeddingMatrix_noun = torch.FloatTensor(np.load(self.a.wnebd)).to( self.device) embeddingMatrix_verb = torch.FloatTensor(np.load(self.a.wvebd)).to( self.device) embeddingMatrix_role = torch.FloatTensor(np.load(self.a.wrebd)).to( self.device) # consts.O_LABEL = vocab_verb.word2id['0'] # verb?? # consts.ROLE_O_LABEL = vocab_role.word2id["OTHER"] #??? # self.a.label_weight = torch.ones([len(vocab_sr.id2word)]) * 5 # more important to learn # self.a.label_weight[consts.O_LABEL] = 1.0 #??? if not self.a.hps_path: self.a.hps = eval(self.a.hps) if self.a.textontology: if "wvemb_size" not in self.a.hps: self.a.hps["wvemb_size"] = len(LabelField.vocab.stoi) if "wremb_size" not in self.a.hps: self.a.hps["wremb_size"] = len(EventsField.vocab.itos) if "wnemb_size" not in self.a.hps: self.a.hps["wnemb_size"] = len(vocab_noun.id2word) if "oc" not in self.a.hps: self.a.hps["oc"] = len(LabelField.vocab.itos) if "ae_oc" not in self.a.hps: self.a.hps["ae_oc"] = len(EventsField.vocab.itos) else: if "wvemb_size" not in self.a.hps: self.a.hps["wvemb_size"] = len(vocab_verb.id2word) if "wremb_size" not in self.a.hps: self.a.hps["wremb_size"] = len(vocab_role.id2word) if "wnemb_size" not in self.a.hps: self.a.hps["wnemb_size"] = len(vocab_noun.id2word) if "oc" not in self.a.hps: self.a.hps["oc"] = len(LabelField.vocab.itos) if "ae_oc" not in self.a.hps: self.a.hps["ae_oc"] = len(EventsField.vocab.itos) tester = self.get_tester() if self.a.textontology: if self.a.finetune: log('init model from ' + self.a.finetune) model = load_sr_model(self.a.hps, embeddingMatrix_noun, LabelField.vocab.vectors, EventsField.vocab.vectors, self.a.finetune, self.device, add_object=self.a.add_object) log('sr model loaded, there are %i sets of params' % len(model.parameters_requires_grads())) else: model = load_sr_model(self.a.hps, embeddingMatrix_noun, LabelField.vocab.vectors, EventsField.vocab.vectors, None, self.device, add_object=self.a.add_object) log('sr model created from scratch, there are %i sets of params' % len(model.parameters_requires_grads())) else: if self.a.finetune: log('init model from ' + self.a.finetune) model = load_sr_model(self.a.hps, embeddingMatrix_noun, embeddingMatrix_verb, embeddingMatrix_role, self.a.finetune, self.device, add_object=self.a.add_object) log('sr model loaded, there are %i sets of params' % len(model.parameters_requires_grads())) else: model = load_sr_model(self.a.hps, embeddingMatrix_noun, embeddingMatrix_verb, embeddingMatrix_role, None, self.device, add_object=self.a.add_object) log('sr model created from scratch, there are %i sets of params' % len(model.parameters_requires_grads())) # for name, para in model.named_parameters(): # if para.requires_grad: # print(name) # exit(1) log('init complete\n') if not os.path.exists(self.a.out): os.mkdir(self.a.out) self.a.word_i2s = vocab_noun.id2word # if self.a.textontology: self.a.acelabel_i2s = LabelField.vocab.itos self.a.acerole_i2s = EventsField.vocab.itos # with open(os.path.join(self.a.out, "label_s2i.vec"), "wb") as f: # pickle.dump(LabelField.vocab.stoi, f) # with open(os.path.join(self.a.out, "role_s2i.vec"), "wb") as f: # pickle.dump(EventsField.vocab.stoi, f) # with open(os.path.join(self.a.out, "label_i2s.vec"), "wb") as f: # pickle.dump(LabelField.vocab.itos, f) # with open(os.path.join(self.a.out, "role_i2s.vec"), "wb") as f: # pickle.dump(EventsField.vocab.itos, f) # else: self.a.label_i2s = vocab_verb.id2word #LabelField.vocab.itos self.a.role_i2s = 
vocab_role.id2word # save as Vocab writer = SummaryWriter(os.path.join(self.a.out, "exp")) self.a.writer = writer # with open(os.path.join(self.a.out, "sr_hyps.json"), "w") as f: # json.dump(self.a.hps, f) test_iter = torch.utils.data.DataLoader(dataset=test_sr_set, batch_size=self.a.batch, shuffle=False, num_workers=2, collate_fn=image_collate_fn) verb_roles = test_sr_set.get_verb_role_mapping() if 'visualize_path' not in self.a: visualize_path = None else: visualize_path = self.a.visualize_path test_loss, test_verb_p, test_verb_r, test_verb_f1, \ test_role_p, test_role_r, test_role_f1, \ test_noun_p, test_noun_r, test_noun_f1, \ test_triple_p, test_triple_r, test_triple_f1, \ test_noun_p_relaxed, test_noun_r_relaxed, test_noun_f1_relaxed, \ test_triple_p_relaxed, test_triple_r_relaxed, test_triple_f1_relaxed = run_over_data_sr(data_iter=test_iter, optimizer=None, model=model, need_backward=False, MAX_STEP=ceil(len( test_sr_set) / self.a.batch), tester=tester, hyps=model.hyperparams, device=model.device, maxnorm=self.a.maxnorm, word_i2s=self.a.word_i2s, label_i2s=self.a.label_i2s, role_i2s=self.a.role_i2s, verb_roles=verb_roles, load_object=self.a.add_object, visualize_path=visualize_path, save_output=os.path.join( self.a.out, "test_final.txt")) print("\nFinally test loss: ", test_loss, "\ntest verb p: ", test_verb_p, " test verb r: ", test_verb_r, " test verb f1: ", test_verb_f1, "\ntest role p: ", test_role_p, " test role r: ", test_role_r, " test role f1: ", test_role_f1, "\ntest noun p: ", test_noun_p, " test noun r: ", test_noun_r, " test noun f1: ", test_noun_f1, "\ntest triple p: ", test_triple_p, " test triple r: ", test_triple_r, " test triple f1: ", test_triple_f1, "\ntest noun p relaxed: ", test_noun_p_relaxed, " test noun r relaxed: ", test_noun_r_relaxed, " test noun f1 relaxed: ", test_noun_f1_relaxed, "\ntest triple p relaxed: ", test_triple_p_relaxed, " test triple r relaxed: ", test_triple_r_relaxed, " test triple f1 relaxed: ", test_triple_f1_relaxed)
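# Illustrative sketch (not project code): how a custom collate_fn like the
# image_collate_fn passed to the DataLoader above plugs in. DataLoader hands
# the collate_fn the list of dataset items for one batch and uses whatever it
# returns as the batch. The toy dataset, tensor shapes, and stacking logic
# here are assumptions for demonstration only.
import torch
from torch.utils.data import DataLoader, Dataset

class _ToyImages(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        # each item is an (image tensor, image id) pair
        return torch.zeros(3, 224, 224), idx

def _toy_collate_fn(batch):
    images, ids = zip(*batch)  # list of pairs -> two tuples
    return torch.stack(images, dim=0), list(ids)

_loader = DataLoader(_ToyImages(), batch_size=4, shuffle=False,
                     num_workers=0, collate_fn=_toy_collate_fn)
_images, _ids = next(iter(_loader))
print(_images.shape, _ids)  # torch.Size([4, 3, 224, 224]) [0, 1, 2, 3]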
def run(self): print("Running on", self.a.device, ', optimizer=', self.a.optimizer, ', lr=%d' % self.a.lr) self.set_device(self.a.device) np.random.seed(self.a.seed) torch.manual_seed(self.a.seed) torch.backends.cudnn.benchmark = True # torch.backends.cudnn.deterministic = False # create training set if self.a.train_ee: log('loading event extraction corpus from %s' % self.a.train_ee) WordsField = Field(lower=True, include_lengths=True, batch_first=True) PosTagsField = Field(lower=True, batch_first=True) EntityLabelsField = MultiTokenField(lower=False, batch_first=True) AdjMatrixField = SparseField(sequential=False, use_vocab=False, batch_first=True) LabelField = Field(lower=False, batch_first=True, pad_token='0', unk_token=None) EventsField = EventField(lower=False, batch_first=True) EntitiesField = EntityField(lower=False, batch_first=True, use_vocab=False) if self.a.amr: colcc = 'amr-colcc' else: colcc = 'stanford-colcc' print(colcc) train_ee_set = ACE2005Dataset(path=self.a.train_ee, fields={ "words": ("WORDS", WordsField), "pos-tags": ("POSTAGS", PosTagsField), "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), colcc: ("ADJM", AdjMatrixField), "golden-event-mentions": ("LABEL", LabelField), "all-events": ("EVENT", EventsField), "all-entities": ("ENTITIES", EntitiesField) }, amr=self.a.amr, keep_events=1) dev_ee_set = ACE2005Dataset(path=self.a.dev_ee, fields={ "words": ("WORDS", WordsField), "pos-tags": ("POSTAGS", PosTagsField), "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), colcc: ("ADJM", AdjMatrixField), "golden-event-mentions": ("LABEL", LabelField), "all-events": ("EVENT", EventsField), "all-entities": ("ENTITIES", EntitiesField) }, amr=self.a.amr, keep_events=0) test_ee_set = ACE2005Dataset(path=self.a.test_ee, fields={ "words": ("WORDS", WordsField), "pos-tags": ("POSTAGS", PosTagsField), "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), colcc: ("ADJM", AdjMatrixField), "golden-event-mentions": ("LABEL", LabelField), "all-events": ("EVENT", EventsField), "all-entities": ("ENTITIES", EntitiesField) }, amr=self.a.amr, keep_events=0) if self.a.webd: pretrained_embedding = Vectors(self.a.webd, ".", unk_init=partial( torch.nn.init.uniform_, a=-0.15, b=0.15)) WordsField.build_vocab(train_ee_set.WORDS, dev_ee_set.WORDS, vectors=pretrained_embedding) LabelField.build_vocab(train_ee_set.LABEL, dev_ee_set.LABEL, vectors=pretrained_embedding) EventsField.build_vocab(train_ee_set.EVENT, dev_ee_set.EVENT, vectors=pretrained_embedding) else: WordsField.build_vocab(train_ee_set.WORDS, dev_ee_set.WORDS) LabelField.build_vocab(train_ee_set.LABEL, dev_ee_set.LABEL) EventsField.build_vocab(train_ee_set.EVENT, dev_ee_set.EVENT) PosTagsField.build_vocab(train_ee_set.POSTAGS, dev_ee_set.POSTAGS) EntityLabelsField.build_vocab(train_ee_set.ENTITYLABELS, dev_ee_set.ENTITYLABELS) consts.O_LABEL = LabelField.vocab.stoi[consts.O_LABEL_NAME] # print("O label is", consts.O_LABEL) consts.ROLE_O_LABEL = EventsField.vocab.stoi[consts.ROLE_O_LABEL_NAME] # print("O label for AE is", consts.ROLE_O_LABEL) dev_ee_set1 = ACE2005Dataset(path=self.a.dev_ee, fields={ "words": ("WORDS", WordsField), "pos-tags": ("POSTAGS", PosTagsField), "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), colcc: ("ADJM", AdjMatrixField), "golden-event-mentions": ("LABEL", LabelField), "all-events": ("EVENT", EventsField), "all-entities": ("ENTITIES", EntitiesField) }, amr=self.a.amr, keep_events=1, only_keep=True) test_ee_set1 = ACE2005Dataset(path=self.a.test_ee, fields={ "words": 
("WORDS", WordsField), "pos-tags": ("POSTAGS", PosTagsField), "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), colcc: ("ADJM", AdjMatrixField), "golden-event-mentions": ("LABEL", LabelField), "all-events": ("EVENT", EventsField), "all-entities": ("ENTITIES", EntitiesField) }, amr=self.a.amr, keep_events=1, only_keep=True) print("train set length", len(train_ee_set)) print("dev set length", len(dev_ee_set)) print("dev set 1/1 length", len(dev_ee_set1)) print("test set length", len(test_ee_set)) print("test set 1/1 length", len(test_ee_set1)) self.a.label_weight = torch.ones([len(LabelField.vocab.itos)]) * 5 self.a.label_weight[consts.O_LABEL] = 1.0 self.a.arg_weight = torch.ones([len(EventsField.vocab.itos)]) * 5 self.a.arg_weight[consts.ROLE_O_LABEL] = 1.0 #???? # add role mask # self.a.role_mask = event_role_mask(self.a.test_ee, self.a.train_ee, self.a.dev_ee, LabelField.vocab.stoi, EventsField.vocab.stoi, self.device) self.a.role_mask = None self.a.hps = eval(self.a.hps) if "wemb_size" not in self.a.hps: self.a.hps["wemb_size"] = len(WordsField.vocab.itos) if "pemb_size" not in self.a.hps: self.a.hps["pemb_size"] = len(PosTagsField.vocab.itos) if "psemb_size" not in self.a.hps: self.a.hps["psemb_size"] = max([ train_ee_set.longest(), dev_ee_set.longest(), test_ee_set.longest() ]) + 2 if "eemb_size" not in self.a.hps: self.a.hps["eemb_size"] = len(EntityLabelsField.vocab.itos) if "oc" not in self.a.hps: self.a.hps["oc"] = len(LabelField.vocab.itos) if "ae_oc" not in self.a.hps: self.a.hps["ae_oc"] = len(EventsField.vocab.itos) tester = self.get_tester(LabelField.vocab.itos, EventsField.vocab.itos) if self.a.finetune: log('init model from ' + self.a.finetune) model = load_ee_model(self.a.hps, self.a.finetune, WordsField.vocab.vectors, self.device) log('model loaded, there are %i sets of params' % len(model.parameters_requires_grads())) log(model.parameters_requires_grads()) else: model = load_ee_model(self.a.hps, None, WordsField.vocab.vectors, self.device) log('model created from scratch, there are %i sets of params' % len(model.parameters_requires_grads())) log(model.parameters_requires_grads()) if self.a.optimizer == "adadelta": optimizer_constructor = partial( torch.optim.Adadelta, params=model.parameters_requires_grads(), weight_decay=self.a.l2decay) elif self.a.optimizer == "adam": optimizer_constructor = partial( torch.optim.Adam, params=model.parameters_requires_grads(), weight_decay=self.a.l2decay) else: optimizer_constructor = partial( torch.optim.SGD, params=model.parameters_requires_grads(), weight_decay=self.a.l2decay, momentum=0.9) log('optimizer in use: %s' % str(self.a.optimizer)) if not os.path.exists(self.a.out): os.mkdir(self.a.out) with open(os.path.join(self.a.out, "word.vec"), "wb") as f: pickle.dump(WordsField.vocab, f) with open(os.path.join(self.a.out, "pos.vec"), "wb") as f: pickle.dump(PosTagsField.vocab.stoi, f) with open(os.path.join(self.a.out, "entity.vec"), "wb") as f: pickle.dump(EntityLabelsField.vocab.stoi, f) with open(os.path.join(self.a.out, "label_s2i.vec"), "wb") as f: pickle.dump(LabelField.vocab.stoi, f) with open(os.path.join(self.a.out, "role_s2i.vec"), "wb") as f: pickle.dump(EventsField.vocab.stoi, f) with open(os.path.join(self.a.out, "label_i2s.vec"), "wb") as f: pickle.dump(LabelField.vocab.itos, f) with open(os.path.join(self.a.out, "role_i2s.vec"), "wb") as f: pickle.dump(EventsField.vocab.itos, f) with open(os.path.join(self.a.out, "ee_hyps.json"), "w") as f: json.dump(self.a.hps, f) log('init complete\n') self.a.word_i2s = 
WordsField.vocab.itos self.a.label_i2s = LabelField.vocab.itos self.a.role_i2s = EventsField.vocab.itos writer = SummaryWriter(os.path.join(self.a.out, "exp")) self.a.writer = writer ee_train(model=model, train_set=train_ee_set, dev_set=dev_ee_set, test_set=test_ee_set, optimizer_constructor=optimizer_constructor, epochs=self.a.epochs, tester=tester, parser=self.a, other_testsets={ "dev 1/1": dev_ee_set1, "test 1/1": test_ee_set1, }, role_mask=self.a.role_mask) log('Done!')
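# Illustrative sketch (not project code) of the optimizer_constructor idiom
# above: functools.partial pre-binds the parameter list and weight decay, and
# the training loop later supplies only the learning rate, e.g.
# optimizer = optimizer_constructor(lr=lr). The linear model and the values
# are stand-ins.
from functools import partial
import torch

_model = torch.nn.Linear(4, 2)
_optimizer_constructor = partial(torch.optim.Adam,
                                 params=_model.parameters(),
                                 weight_decay=1e-5)
_optimizer = _optimizer_constructor(lr=1e-3)  # lr chosen at train time
print(type(_optimizer).__name__)  # Adam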
def run(self): print("Running on", self.a.device) self.set_device(self.a.device) np.random.seed(self.a.seed) torch.manual_seed(self.a.seed) torch.backends.cudnn.benchmark = True # create training set if self.a.train: log('loading corpus from %s' % self.a.train) transform = transforms.Compose([ transforms.Resize(256), transforms.RandomHorizontalFlip(), transforms.RandomCrop(224), transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) ]) IMAGEIDField = SparseField(sequential=False, use_vocab=False, batch_first=True) SENTIDField = SparseField(sequential=False, use_vocab=False, batch_first=True) # IMAGEField = SparseField(sequential=False, use_vocab=False, batch_first=True) WordsField = Field(lower=True, include_lengths=True, batch_first=True) PosTagsField = Field(lower=True, batch_first=True) EntityLabelsField = MultiTokenField(lower=False, batch_first=True) AdjMatrixField = SparseField(sequential=False, use_vocab=False, batch_first=True) EntitiesField = EntityField(lower=False, batch_first=True, use_vocab=False) if self.a.amr: colcc = 'amr-colcc' else: colcc = 'stanford-colcc' print(colcc) train_set = GroundingDataset( path=self.a.train, img_dir=self.a.img_dir, fields={ "id": ("IMAGEID", IMAGEIDField), "sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField), "pos-tags": ("POSTAGS", PosTagsField), "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), colcc: ("ADJM", AdjMatrixField), "all-entities": ("ENTITIES", EntitiesField), # "image": ("IMAGE", IMAGEField), }, transform=transform, amr=self.a.amr, load_object=self.a.add_object, object_ontology_file=self.a.object_class_map_file, object_detection_pkl_file=self.a.object_detection_pkl_file, object_detection_threshold=self.a.object_detection_threshold, ) dev_set = GroundingDataset( path=self.a.dev, img_dir=self.a.img_dir, fields={ "id": ("IMAGEID", IMAGEIDField), "sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField), "pos-tags": ("POSTAGS", PosTagsField), "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), colcc: ("ADJM", AdjMatrixField), "all-entities": ("ENTITIES", EntitiesField), # "image": ("IMAGE", IMAGEField), }, transform=transform, amr=self.a.amr, load_object=self.a.add_object, object_ontology_file=self.a.object_class_map_file, object_detection_pkl_file=self.a.object_detection_pkl_file, object_detection_threshold=self.a.object_detection_threshold, ) test_set = GroundingDataset( path=self.a.test, img_dir=self.a.img_dir, fields={ "id": ("IMAGEID", IMAGEIDField), "sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField), "pos-tags": ("POSTAGS", PosTagsField), "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), colcc: ("ADJM", AdjMatrixField), "all-entities": ("ENTITIES", EntitiesField), # "image": ("IMAGE", IMAGEField), }, transform=transform, amr=self.a.amr, load_object=self.a.add_object, object_ontology_file=self.a.object_class_map_file, object_detection_pkl_file=self.a.object_detection_pkl_file, object_detection_threshold=self.a.object_detection_threshold, ) if self.a.webd: pretrained_embedding = Vectors(self.a.webd, ".", unk_init=partial( torch.nn.init.uniform_, a=-0.15, b=0.15)) WordsField.build_vocab(train_set.WORDS, dev_set.WORDS, vectors=pretrained_embedding) else: WordsField.build_vocab(train_set.WORDS, dev_set.WORDS) # WordsField.build_vocab(train_set.WORDS, dev_set.WORDS) PosTagsField.build_vocab(train_set.POSTAGS, dev_set.POSTAGS) EntityLabelsField.build_vocab(train_set.ENTITYLABELS, dev_set.ENTITYLABELS) # sr model 
initialization self.a.sr_hps = eval(self.a.sr_hps) vocab_noun = Vocab(os.path.join(self.a.vocab, 'vocab_situation_noun.pkl'), load=True) vocab_role = Vocab(os.path.join(self.a.vocab, 'vocab_situation_role.pkl'), load=True) vocab_verb = Vocab(os.path.join(self.a.vocab, 'vocab_situation_verb.pkl'), load=True) embeddingMatrix_noun = torch.FloatTensor(np.load(self.a.wnebd)).to( self.device) embeddingMatrix_verb = torch.FloatTensor(np.load(self.a.wvebd)).to( self.device) embeddingMatrix_role = torch.FloatTensor(np.load(self.a.wrebd)).to( self.device) if "wvemb_size" not in self.a.sr_hps: self.a.sr_hps["wvemb_size"] = len(vocab_verb.id2word) if "wremb_size" not in self.a.sr_hps: self.a.sr_hps["wremb_size"] = len(vocab_role.id2word) if "wnemb_size" not in self.a.sr_hps: self.a.sr_hps["wnemb_size"] = len(vocab_noun.id2word) if "ae_oc" not in self.a.sr_hps: self.a.sr_hps["ae_oc"] = len(vocab_role.id2word) self.a.ee_hps = eval(self.a.ee_hps) if "wemb_size" not in self.a.ee_hps: self.a.ee_hps["wemb_size"] = len(WordsField.vocab.itos) if "pemb_size" not in self.a.ee_hps: self.a.ee_hps["pemb_size"] = len(PosTagsField.vocab.itos) if "psemb_size" not in self.a.ee_hps: self.a.ee_hps["psemb_size"] = max( [train_set.longest(), dev_set.longest(), test_set.longest()]) + 2 if "eemb_size" not in self.a.ee_hps: self.a.ee_hps["eemb_size"] = len(EntityLabelsField.vocab.itos) if "oc" not in self.a.ee_hps: self.a.ee_hps["oc"] = 36 #??? if "ae_oc" not in self.a.ee_hps: self.a.ee_hps["ae_oc"] = 20 #??? tester = self.get_tester() if self.a.finetune_sr: log('init sr model from ' + self.a.finetune_sr) sr_model = load_sr_model(self.a.sr_hps, embeddingMatrix_noun, embeddingMatrix_verb, embeddingMatrix_role, self.a.finetune_sr, self.device) log('sr model loaded, there are %i sets of params' % len(sr_model.parameters_requires_grads())) else: sr_model = load_sr_model(self.a.sr_hps, embeddingMatrix_noun, embeddingMatrix_verb, embeddingMatrix_role, None, self.device) log('sr model created from scratch, there are %i sets of params' % len(sr_model.parameters_requires_grads())) if self.a.finetune_ee: log('init model from ' + self.a.finetune_ee) ee_model = load_ee_model(self.a.ee_hps, self.a.finetune_ee, WordsField.vocab.vectors, self.device) log('model loaded, there are %i sets of params' % len(ee_model.parameters_requires_grads())) else: ee_model = load_ee_model(self.a.ee_hps, None, WordsField.vocab.vectors, self.device) log('model created from scratch, there are %i sets of params' % len(ee_model.parameters_requires_grads())) model = GroundingModel(ee_model, sr_model, self.get_device()) if self.a.optimizer == "adadelta": optimizer_constructor = partial( torch.optim.Adadelta, params=model.parameters_requires_grads(), weight_decay=self.a.l2decay) elif self.a.optimizer == "adam": optimizer_constructor = partial( torch.optim.Adam, params=model.parameters_requires_grads(), weight_decay=self.a.l2decay) else: optimizer_constructor = partial( torch.optim.SGD, params=model.parameters_requires_grads(), weight_decay=self.a.l2decay, momentum=0.9) log('optimizer in use: %s' % str(self.a.optimizer)) if not os.path.exists(self.a.out): os.mkdir(self.a.out) with open(os.path.join(self.a.out, "word.vec"), "wb") as f: pickle.dump(WordsField.vocab, f) with open(os.path.join(self.a.out, "pos.vec"), "wb") as f: pickle.dump(PosTagsField.vocab.stoi, f) with open(os.path.join(self.a.out, "entity.vec"), "wb") as f: pickle.dump(EntityLabelsField.vocab.stoi, f) with open(os.path.join(self.a.out, "ee_hyps.json"), "w") as f: json.dump(self.a.ee_hps, f) with 
open(os.path.join(self.a.out, "sr_hyps.json"), "w") as f: json.dump(self.a.sr_hps, f) log('init complete\n') self.a.word_i2s = vocab_noun.id2word self.a.label_i2s = vocab_verb.id2word # LabelField.vocab.itos self.a.role_i2s = vocab_role.id2word self.a.word_i2s = WordsField.vocab.itos # self.a.label_i2s = LabelField.vocab.itos # self.a.role_i2s = EventsField.vocab.itos writer = SummaryWriter(os.path.join(self.a.out, "exp")) self.a.writer = writer grounding_train( model=model, train_set=train_set, dev_set=dev_set, test_set=test_set, optimizer_constructor=optimizer_constructor, epochs=self.a.epochs, tester=tester, parser=self.a, other_testsets={ # "dev 1/1": dev_set1, # "test 1/1": test_set1, }, transform=transform, vocab_objlabel=vocab_noun.word2id) log('Done!')
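# Illustrative sketch (not project code) of the image pipeline the grounding
# datasets above apply to every image: Resize -> RandomHorizontalFlip ->
# RandomCrop(224) -> ToTensor -> Normalize with ImageNet statistics. The dummy
# PIL image is only for demonstration.
from PIL import Image
from torchvision import transforms

_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
_x = _transform(Image.new("RGB", (640, 480)))
print(_x.shape)  # torch.Size([3, 224, 224])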
def __new__(cls, *args, **kwargs):
    log('created %s with params %s' % (str(cls), str(args)))
    # Python calls __init__ automatically after __new__ returns an instance of
    # cls, so invoking it explicitly here as well would initialize the model twice.
    return super(Model, cls).__new__(cls)
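# Standalone toy demonstration (not the project's Model class) of why logging
# in __new__ sees the constructor arguments: Python evaluates cls.__new__
# first and then calls __init__ automatically on the instance it returns.
class _Logged:
    def __new__(cls, *args, **kwargs):
        print('created %s with params %s' % (str(cls), str(args)))
        return super().__new__(cls)

    def __init__(self, name):
        self.name = name

_m = _Logged('demo')  # prints: created <class '__main__._Logged'> with params ('demo',)
print(_m.name)        # demo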
def run(self): print("Running on", self.a.device) self.set_device(self.a.device) np.random.seed(self.a.seed) torch.manual_seed(self.a.seed) torch.backends.cudnn.benchmark = True # create training set if self.a.test_ee: log('loading event extraction corpus from %s' % self.a.test_ee) WordsField = Field(lower=True, include_lengths=True, batch_first=True) PosTagsField = Field(lower=True, batch_first=True) EntityLabelsField = MultiTokenField(lower=False, batch_first=True) AdjMatrixField = SparseField(sequential=False, use_vocab=False, batch_first=True) LabelField = Field(lower=False, batch_first=True, pad_token='0', unk_token=None) EventsField = EventField(lower=False, batch_first=True) EntitiesField = EntityField(lower=False, batch_first=True, use_vocab=False) SENTIDField = SparseField(sequential=False, use_vocab=False, batch_first=True) if self.a.amr: colcc = 'simple-parsing' else: colcc = 'combined-parsing' print(colcc) train_ee_set = ACE2005Dataset( path=self.a.train_ee, fields={ "sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField), "pos-tags": ("POSTAGS", PosTagsField), "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), colcc: ("ADJM", AdjMatrixField), "golden-event-mentions": ("LABEL", LabelField), "all-events": ("EVENT", EventsField), "all-entities": ("ENTITIES", EntitiesField) }, amr=self.a.amr, keep_events=1) dev_ee_set = ACE2005Dataset(path=self.a.dev_ee, fields={ "sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField), "pos-tags": ("POSTAGS", PosTagsField), "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), colcc: ("ADJM", AdjMatrixField), "golden-event-mentions": ("LABEL", LabelField), "all-events": ("EVENT", EventsField), "all-entities": ("ENTITIES", EntitiesField) }, amr=self.a.amr, keep_events=0) test_ee_set = ACE2005Dataset(path=self.a.test_ee, fields={ "sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField), "pos-tags": ("POSTAGS", PosTagsField), "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), colcc: ("ADJM", AdjMatrixField), "golden-event-mentions": ("LABEL", LabelField), "all-events": ("EVENT", EventsField), "all-entities": ("ENTITIES", EntitiesField) }, amr=self.a.amr, keep_events=0) if self.a.load_grounding: #################### loading grounding dataset #################### if self.a.train_grounding: log('loading grounding corpus from %s' % self.a.train_grounding) # only for grounding IMAGEIDField = SparseField(sequential=False, use_vocab=False, batch_first=True) # IMAGEField = SparseField(sequential=False, use_vocab=False, batch_first=True) transform = transforms.Compose([ transforms.Resize(256), transforms.RandomHorizontalFlip(), transforms.RandomCrop(224), transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) ]) train_grounding_set = GroundingDataset( path=self.a.train_grounding, img_dir=self.a.img_dir_grounding, fields={ "id": ("IMAGEID", IMAGEIDField), "sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField), "pos-tags": ("POSTAGS", PosTagsField), "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), colcc: ("ADJM", AdjMatrixField), "all-entities": ("ENTITIES", EntitiesField), # "image": ("IMAGE", IMAGEField), }, transform=transform, amr=self.a.amr, load_object=self.a.add_object, object_ontology_file=self.a.object_class_map_file, object_detection_pkl_file=self.a.object_detection_pkl_file_g, object_detection_threshold=self.a.object_detection_threshold, ) dev_grounding_set = GroundingDataset( path=self.a.dev_grounding, 
img_dir=self.a.img_dir_grounding, fields={ "id": ("IMAGEID", IMAGEIDField), "sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField), "pos-tags": ("POSTAGS", PosTagsField), "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), colcc: ("ADJM", AdjMatrixField), "all-entities": ("ENTITIES", EntitiesField), # "image": ("IMAGE", IMAGEField), }, transform=transform, amr=self.a.amr, load_object=self.a.add_object, object_ontology_file=self.a.object_class_map_file, object_detection_pkl_file=self.a.object_detection_pkl_file_g, object_detection_threshold=self.a.object_detection_threshold, ) # test_grounding_set = GroundingDataset(path=self.a.test_grounding, # img_dir=self.a.img_dir_grounding, # fields={"id": ("IMAGEID", IMAGEIDField), # "sentence_id": ("SENTID", SENTIDField), # "words": ("WORDS", WordsField), # "pos-tags": ("POSTAGS", PosTagsField), # "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), # colcc: ("ADJM", AdjMatrixField), # "all-entities": ("ENTITIES", EntitiesField), # # "image": ("IMAGE", IMAGEField), # }, # transform=transform, # amr=self.a.amr, # load_object=self.a.add_object, # object_ontology_file=self.a.object_class_map_file, # object_detection_pkl_file=self.a.object_detection_pkl_file_g, # object_detection_threshold=self.a.object_detection_threshold, # ) #################### build vocabulary #################### if self.a.webd: pretrained_embedding = Vectors(self.a.webd, ".", unk_init=partial( torch.nn.init.uniform_, a=-0.15, b=0.15)) WordsField.build_vocab(train_ee_set.WORDS, dev_ee_set.WORDS, train_grounding_set.WORDS, dev_grounding_set.WORDS, vectors=pretrained_embedding) else: WordsField.build_vocab(train_ee_set.WORDS, dev_ee_set.WORDS, train_grounding_set.WORDS, dev_grounding_set.WORDS) PosTagsField.build_vocab(train_ee_set.POSTAGS, dev_ee_set.POSTAGS, train_grounding_set.POSTAGS, dev_grounding_set.POSTAGS) EntityLabelsField.build_vocab(train_ee_set.ENTITYLABELS, dev_ee_set.ENTITYLABELS, train_grounding_set.ENTITYLABELS, dev_grounding_set.ENTITYLABELS) else: if self.a.webd: pretrained_embedding = Vectors(self.a.webd, ".", unk_init=partial( torch.nn.init.uniform_, a=-0.15, b=0.15)) WordsField.build_vocab(train_ee_set.WORDS, dev_ee_set.WORDS, vectors=pretrained_embedding) else: WordsField.build_vocab(train_ee_set.WORDS, dev_ee_set.WORDS) PosTagsField.build_vocab(train_ee_set.POSTAGS, dev_ee_set.POSTAGS) EntityLabelsField.build_vocab(train_ee_set.ENTITYLABELS, dev_ee_set.ENTITYLABELS) LabelField.build_vocab(train_ee_set.LABEL, dev_ee_set.LABEL) EventsField.build_vocab(train_ee_set.EVENT, dev_ee_set.EVENT) consts.O_LABEL = LabelField.vocab.stoi[consts.O_LABEL_NAME] # print("O label is", consts.O_LABEL) consts.ROLE_O_LABEL = EventsField.vocab.stoi[consts.ROLE_O_LABEL_NAME] # print("O label for AE is", consts.ROLE_O_LABEL) self.a.label_weight = torch.ones([len(LabelField.vocab.itos)]) * 5 self.a.label_weight[consts.O_LABEL] = 1.0 self.a.arg_weight = torch.ones([len(EventsField.vocab.itos)]) * 5 # add role mask self.a.role_mask = event_role_mask(self.a.train_ee, self.a.dev_ee, LabelField.vocab.stoi, EventsField.vocab.stoi, self.device) # print('self.a.hps', self.a.hps) if not self.a.hps_path: self.a.hps = eval(self.a.hps) if "wemb_size" not in self.a.hps: self.a.hps["wemb_size"] = len(WordsField.vocab.itos) if "pemb_size" not in self.a.hps: self.a.hps["pemb_size"] = len(PosTagsField.vocab.itos) if "psemb_size" not in self.a.hps: self.a.hps["psemb_size"] = max([ train_ee_set.longest(), dev_ee_set.longest(), test_ee_set.longest() ]) + 2 if 
"eemb_size" not in self.a.hps: self.a.hps["eemb_size"] = len(EntityLabelsField.vocab.itos) if "oc" not in self.a.hps: self.a.hps["oc"] = len(LabelField.vocab.itos) if "ae_oc" not in self.a.hps: self.a.hps["ae_oc"] = len(EventsField.vocab.itos) tester = self.get_tester(LabelField.vocab.itos, EventsField.vocab.itos) visualizer = EDVisualizer(self.a.gt_voa_text) if self.a.finetune: log('init model from ' + self.a.finetune) model = load_ee_model(self.a.hps, self.a.finetune, WordsField.vocab.vectors, self.device) log('model loaded, there are %i sets of params' % len(model.parameters_requires_grads())) else: model = load_ee_model(self.a.hps, None, WordsField.vocab.vectors, self.device) log('model created from scratch, there are %i sets of params' % len(model.parameters_requires_grads())) self.a.word_i2s = WordsField.vocab.itos self.a.label_i2s = LabelField.vocab.itos self.a.role_i2s = EventsField.vocab.itos writer = SummaryWriter(os.path.join(self.a.out, "exp")) self.a.writer = writer # train_iter = BucketIterator(train_ee_set, batch_size=self.a.batch, # train=True, shuffle=False, device=-1, # sort_key=lambda x: len(x.POSTAGS)) # dev_iter = BucketIterator(dev_ee_set, batch_size=self.a.batch, train=False, # shuffle=False, device=-1, # sort_key=lambda x: len(x.POSTAGS)) # test_iter = BucketIterator(test_ee_set, batch_size=self.a.batch, train=False, # shuffle=False, device=-1, # sort_key=lambda x: len(x.POSTAGS)) test_m2e2_set = ACE2005Dataset( path=self.a.gt_voa_text, fields={ "sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField), "pos-tags": ("POSTAGS", PosTagsField), "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), colcc: ("ADJM", AdjMatrixField), "golden-event-mentions": ("LABEL", LabelField), "all-events": ("EVENT", EventsField), "all-entities": ("ENTITIES", EntitiesField) }, amr=self.a.amr, keep_events=self.a.keep_events) # test_m2e2_set = M2E2Dataset(path=gt_voa_text, # img_dir=voa_image_dir, # fields={"image": ("IMAGEID", IMAGEIDField), # "sentence_id": ("SENTID", SENTIDField), # "words": ("WORDS", WordsField), # "pos-tags": ("POSTAGS", PosTagsField), # "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), # colcc: ("ADJM", AdjMatrixField), # "all-entities": ("ENTITIES", EntitiesField), # # "image": ("IMAGE", IMAGEField), # "golden-event-mentions": ("LABEL", LabelField), # "all-events": ("EVENT", EventsField), # }, # transform=transform, # amr=self.a.amr, # load_object=self.a.add_object, # object_ontology_file=self.a.object_class_map_file, # object_detection_pkl_file=self.a.object_detection_pkl_file, # object_detection_threshold=self.a.object_detection_threshold, # ) test_m2e2_iter = BucketIterator(test_m2e2_set, batch_size=1, train=False, shuffle=False, device=-1, sort_key=lambda x: len(x.POSTAGS)) print("\nStarting testing ...\n") # Testing Phrase test_loss, test_ed_p, test_ed_r, test_ed_f1, \ test_ae_p, test_ae_r, test_ae_f1 = run_over_data(data_iter=test_m2e2_iter, optimizer=None, model=model, need_backward=False, MAX_STEP=len(test_m2e2_iter), tester=tester, visualizer=visualizer, hyps=model.hyperparams, device=model.device, maxnorm=self.a.maxnorm, word_i2s=self.a.word_i2s, label_i2s=self.a.label_i2s, role_i2s=self.a.role_i2s, weight=self.a.label_weight, arg_weight=self.a.arg_weight, save_output=os.path.join( self.a.out, "test_final.txt"), role_mask=self.a.role_mask) print("\nFinally test loss: ", test_loss, "\ntest ed p: ", test_ed_p, " test ed r: ", test_ed_r, " test ed f1: ", test_ed_f1, "\ntest ae p: ", test_ae_p, " test ae r: ", test_ae_r, " 
test ae f1: ", test_ae_f1)
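# Side note (illustrative, with made-up sizes): the repeated
# "if key not in hps: hps[key] = ..." defaulting used above is equivalent to
# dict.setdefault, which keeps explicit values and fills only missing ones.
_hps = {"oc": 36}
_hps.setdefault("oc", 99)     # keeps the explicit 36
_hps.setdefault("ae_oc", 20)  # fills the missing default
print(_hps)  # {'oc': 36, 'ae_oc': 20}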
def run(self): print("Running on", self.a.device) self.set_device(self.a.device) np.random.seed(self.a.seed) torch.manual_seed(self.a.seed) torch.backends.cudnn.benchmark = True # build text event vocab and ee_role vocab WordsField = Field(lower=True, include_lengths=True, batch_first=True) PosTagsField = Field(lower=True, batch_first=True) EntityLabelsField = MultiTokenField(lower=False, batch_first=True) AdjMatrixField = SparseField(sequential=False, use_vocab=False, batch_first=True) EntitiesField = EntityField(lower=False, batch_first=True, use_vocab=False) # only for ee LabelField = Field(lower=False, batch_first=True, pad_token='0', unk_token=None) EventsField = EventField(lower=False, batch_first=True) colcc = 'stanford-colcc' train_ee_set = ACE2005Dataset(path=self.a.train_ee, fields={ "words": ("WORDS", WordsField), "pos-tags": ("POSTAGS", PosTagsField), "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), colcc: ("ADJM", AdjMatrixField), "golden-event-mentions": ("LABEL", LabelField), "all-events": ("EVENT", EventsField), "all-entities": ("ENTITIES", EntitiesField) }, amr=False, keep_events=1) pretrained_embedding = Vectors(self.a.webd, ".", unk_init=partial(torch.nn.init.uniform_, a=-0.15, b=0.15)) LabelField.build_vocab(train_ee_set.LABEL, vectors=pretrained_embedding) EventsField.build_vocab(train_ee_set.EVENT, vectors=pretrained_embedding) # consts.O_LABEL = LabelField.vocab.stoi["O"] # # print("O label is", consts.O_LABEL) # consts.ROLE_O_LABEL = EventsField.vocab.stoi["OTHER"] # print("O label for AE is", consts.ROLE_O_LABEL) # create training set if self.a.train_sr: log('loading corpus from %s' % self.a.train_sr) train_sr_set = ImSituDataset( self.a.image_dir, self.vocab_noun, self.vocab_role, self.vocab_verb, LabelField.vocab.stoi, EventsField.vocab.stoi, self.a.imsitu_ontology_file, self.a.train_sr, self.a.verb_mapping_file, self.a.object_class_map_file, self.a.object_detection_pkl_file, self.a.object_detection_threshold, self.transform, filter_irrelevant_verbs=self.a.filter_irrelevant_verbs, load_object=self.a.add_object, filter_place=self.a.filter_place) dev_sr_set = ImSituDataset( self.a.image_dir, self.vocab_noun, self.vocab_role, self.vocab_verb, LabelField.vocab.stoi, EventsField.vocab.stoi, self.a.imsitu_ontology_file, self.a.dev_sr, self.a.verb_mapping_file, self.a.object_class_map_file, self.a.object_detection_pkl_file, self.a.object_detection_threshold, self.transform, filter_irrelevant_verbs=self.a.filter_irrelevant_verbs, load_object=self.a.add_object, filter_place=self.a.filter_place) test_sr_set = ImSituDataset( self.a.image_dir, self.vocab_noun, self.vocab_role, self.vocab_verb, LabelField.vocab.stoi, EventsField.vocab.stoi, self.a.imsitu_ontology_file, self.a.test_sr, self.a.verb_mapping_file, self.a.object_class_map_file, self.a.object_detection_pkl_file, self.a.object_detection_threshold, self.transform, filter_irrelevant_verbs=self.a.filter_irrelevant_verbs, load_object=self.a.add_object, filter_place=self.a.filter_place) embeddingMatrix_noun = torch.FloatTensor(np.load(self.a.wnebd)).to( self.device) embeddingMatrix_verb = torch.FloatTensor(np.load(self.a.wvebd)).to( self.device) embeddingMatrix_role = torch.FloatTensor(np.load(self.a.wrebd)).to( self.device) # consts.O_LABEL = self.vocab_verb.word2id['0'] # verb?? # consts.ROLE_O_LABEL = self.vocab_role.word2id["OTHER"] #??? # self.a.label_weight = torch.ones([len(vocab_sr.id2word)]) * 5 # more important to learn # self.a.label_weight[consts.O_LABEL] = 1.0 #??? 
self.a.hps = eval(self.a.hps) if self.a.textontology: if "wvemb_size" not in self.a.hps: self.a.hps["wvemb_size"] = len(LabelField.vocab.stoi) if "wremb_size" not in self.a.hps: self.a.hps["wremb_size"] = len(EventsField.vocab.itos) if "wnemb_size" not in self.a.hps: self.a.hps["wnemb_size"] = len(self.vocab_noun.id2word) if "oc" not in self.a.hps: self.a.hps["oc"] = len(LabelField.vocab.itos) if "ae_oc" not in self.a.hps: self.a.hps["ae_oc"] = len(EventsField.vocab.itos) else: if "wvemb_size" not in self.a.hps: self.a.hps["wvemb_size"] = len(self.vocab_verb.id2word) if "wremb_size" not in self.a.hps: self.a.hps["wremb_size"] = len(self.vocab_role.id2word) if "wnemb_size" not in self.a.hps: self.a.hps["wnemb_size"] = len(self.vocab_noun.id2word) if "oc" not in self.a.hps: self.a.hps["oc"] = len(LabelField.vocab.itos) if "ae_oc" not in self.a.hps: self.a.hps["ae_oc"] = len(EventsField.vocab.itos) tester = self.get_tester() if self.a.textontology: if self.a.finetune: log('init model from ' + self.a.finetune) model = load_sr_model(self.a.hps, embeddingMatrix_noun, LabelField.vocab.vectors, EventsField.vocab.vectors, self.a.finetune, self.device, add_object=self.a.add_object) log('sr model loaded, there are %i sets of params' % len(model.parameters_requires_grads())) else: model = load_sr_model(self.a.hps, embeddingMatrix_noun, LabelField.vocab.vectors, EventsField.vocab.vectors, None, self.device, add_object=self.a.add_object) log('sr model created from scratch, there are %i sets of params' % len(model.parameters_requires_grads())) else: if self.a.finetune: log('init model from ' + self.a.finetune) model = load_sr_model(self.a.hps, embeddingMatrix_noun, embeddingMatrix_verb, embeddingMatrix_role, self.a.finetune, self.device, add_object=self.a.add_object) log('sr model loaded, there are %i sets of params' % len(model.parameters_requires_grads())) else: model = load_sr_model(self.a.hps, embeddingMatrix_noun, embeddingMatrix_verb, embeddingMatrix_role, None, self.device, add_object=self.a.add_object) log('sr model created from scratch, there are %i sets of params' % len(model.parameters_requires_grads())) if self.a.optimizer == "adadelta": optimizer_constructor = partial( torch.optim.Adadelta, params=model.parameters_requires_grads(), weight_decay=self.a.l2decay) elif self.a.optimizer == "adam": optimizer_constructor = partial( torch.optim.Adam, params=model.parameters_requires_grads(), weight_decay=self.a.l2decay) else: optimizer_constructor = partial( torch.optim.SGD, params=model.parameters_requires_grads(), weight_decay=self.a.l2decay, momentum=0.9) # for name, para in model.named_parameters(): # if para.requires_grad: # print(name) # exit(1) log('optimizer in use: %s' % str(self.a.optimizer)) log('init complete\n') if not os.path.exists(self.a.out): os.mkdir(self.a.out) self.a.word_i2s = self.vocab_noun.id2word # if self.a.textontology: self.a.acelabel_i2s = LabelField.vocab.itos self.a.acerole_i2s = EventsField.vocab.itos with open(os.path.join(self.a.out, "label_s2i.vec"), "wb") as f: pickle.dump(LabelField.vocab.stoi, f) with open(os.path.join(self.a.out, "role_s2i.vec"), "wb") as f: pickle.dump(EventsField.vocab.stoi, f) with open(os.path.join(self.a.out, "label_i2s.vec"), "wb") as f: pickle.dump(LabelField.vocab.itos, f) with open(os.path.join(self.a.out, "role_i2s.vec"), "wb") as f: pickle.dump(EventsField.vocab.itos, f) # else: self.a.label_i2s = self.vocab_verb.id2word #LabelField.vocab.itos self.a.role_i2s = self.vocab_role.id2word # save as Vocab writer = 
SummaryWriter(os.path.join(self.a.out, "exp")) self.a.writer = writer with open(os.path.join(self.a.out, "sr_hyps.json"), "w") as f: json.dump(self.a.hps, f) sr_train( model=model, train_set=train_sr_set, dev_set=dev_sr_set, test_set=test_sr_set, optimizer_constructor=optimizer_constructor, epochs=self.a.epochs, tester=tester, parser=self.a, other_testsets={ # "dev 1/1": dev_sr_loader, # "test 1/1": test_sr_loader, }) log('Done!')
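# Illustrative sketch (not project code) of the vocabulary round-trip used
# above: training runs pickle the stoi/itos mappings so a later test run can
# decode predictions without rebuilding the torchtext fields. The mapping and
# temporary path here are assumptions for demonstration.
import os
import pickle
import tempfile

_label_s2i = {"O": 0, "Attack": 1}
_out = tempfile.mkdtemp()
with open(os.path.join(_out, "label_s2i.vec"), "wb") as f:
    pickle.dump(_label_s2i, f)
with open(os.path.join(_out, "label_s2i.vec"), "rb") as f:
    print(pickle.load(f) == _label_s2i)  # True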
def run(self): print("Running on", self.a.device) self.set_device(self.a.device) np.random.seed(self.a.seed) torch.manual_seed(self.a.seed) torch.backends.cudnn.benchmark = True #################### loading event extraction dataset #################### if self.a.test_ee: log('testing event extraction corpus from %s' % self.a.test_ee) if self.a.test_ee: log('testing event extraction corpus from %s' % self.a.test_ee) # both for grounding and ee WordsField = Field(lower=True, include_lengths=True, batch_first=True) PosTagsField = Field(lower=True, batch_first=True) EntityLabelsField = MultiTokenField(lower=False, batch_first=True) AdjMatrixField = SparseField(sequential=False, use_vocab=False, batch_first=True) EntitiesField = EntityField(lower=False, batch_first=True, use_vocab=False) # only for ee LabelField = Field(lower=False, batch_first=True, pad_token='0', unk_token=None) EventsField = EventField(lower=False, batch_first=True) SENTIDField = SparseField(sequential=False, use_vocab=False, batch_first=True) if self.a.amr: colcc = 'simple-parsing' else: colcc = 'combined-parsing' print(colcc) train_ee_set = ACE2005Dataset(path=self.a.train_ee, fields={"sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField), "pos-tags": ("POSTAGS", PosTagsField), "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), colcc: ("ADJM", AdjMatrixField), "golden-event-mentions": ("LABEL", LabelField), "all-events": ("EVENT", EventsField), "all-entities": ("ENTITIES", EntitiesField)}, amr=self.a.amr, keep_events=1) dev_ee_set = ACE2005Dataset(path=self.a.dev_ee, fields={"sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField), "pos-tags": ("POSTAGS", PosTagsField), "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), colcc: ("ADJM", AdjMatrixField), "golden-event-mentions": ("LABEL", LabelField), "all-events": ("EVENT", EventsField), "all-entities": ("ENTITIES", EntitiesField)}, amr=self.a.amr, keep_events=0) # test_ee_set = ACE2005Dataset(path=self.a.test_ee, # fields={"sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField), # "pos-tags": ("POSTAGS", PosTagsField), # "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), # colcc: ("ADJM", AdjMatrixField), # "golden-event-mentions": ("LABEL", LabelField), # "all-events": ("EVENT", EventsField), # "all-entities": ("ENTITIES", EntitiesField)}, # amr=self.a.amr, keep_events=0) print('self.a.train_ee', self.a.train_ee) LabelField.build_vocab(train_ee_set.LABEL, dev_ee_set.LABEL) print('LabelField.vocab.stoi', LabelField.vocab.stoi) EventsField.build_vocab(train_ee_set.EVENT, dev_ee_set.EVENT) print('EventsField.vocab.stoi', EventsField.vocab.stoi) print('len(EventsField.vocab.itos)', len(EventsField.vocab.itos)) print('len(EventsField.vocab.stoi)', len(EventsField.vocab.stoi)) #################### loading SR dataset #################### # both for grounding and sr if self.a.train_sr: log('loading corpus from %s' % self.a.train_sr) transform = transforms.Compose([ transforms.Resize(256), transforms.RandomHorizontalFlip(), transforms.RandomCrop(224), transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]) vocab_noun = Vocab(os.path.join(self.a.vocab, 'vocab_situation_noun.pkl'), load=True) vocab_role = Vocab(os.path.join(self.a.vocab, 'vocab_situation_role.pkl'), load=True) vocab_verb = Vocab(os.path.join(self.a.vocab, 'vocab_situation_verb.pkl'), load=True) # only need get_role_mask() and sr_mapping() train_sr_set = ImSituDataset(self.a.image_dir, vocab_noun, 
vocab_role, vocab_verb, EventsField.vocab.stoi, LabelField.vocab.stoi, self.a.imsitu_ontology_file, self.a.train_sr, self.a.verb_mapping_file, None, None, 0, transform, filter_irrelevant_verbs=self.a.filter_irrelevant_verbs, load_object=False, filter_place=self.a.filter_place) #################### loading grounding dataset #################### if self.a.train_grounding: log('loading grounding corpus from %s' % self.a.train_grounding) # only for grounding IMAGEIDField = SparseField(sequential=False, use_vocab=False, batch_first=True) SENTIDField = SparseField(sequential=False, use_vocab=False, batch_first=True) # IMAGEField = SparseField(sequential=False, use_vocab=False, batch_first=True) train_grounding_set = GroundingDataset(path=self.a.train_grounding, img_dir=None, fields={"id": ("IMAGEID", IMAGEIDField), "sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField), "pos-tags": ("POSTAGS", PosTagsField), "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), colcc: ("ADJM", AdjMatrixField), "all-entities": ("ENTITIES", EntitiesField), # "image": ("IMAGE", IMAGEField), }, transform=transform, amr=self.a.amr) dev_grounding_set = GroundingDataset(path=self.a.dev_grounding, img_dir=None, fields={"id": ("IMAGEID", IMAGEIDField), "sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField), "pos-tags": ("POSTAGS", PosTagsField), "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), colcc: ("ADJM", AdjMatrixField), "all-entities": ("ENTITIES", EntitiesField), # "image": ("IMAGE", IMAGEField), }, transform=transform, amr=self.a.amr) # test_grounding_set = GroundingDataset(path=self.a.test_grounding, # img_dir=None, # fields={"id": ("IMAGEID", IMAGEIDField), # "sentence_id": ("SENTID", SENTIDField), # "words": ("WORDS", WordsField), # "pos-tags": ("POSTAGS", PosTagsField), # "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), # colcc: ("ADJM", AdjMatrixField), # "all-entities": ("ENTITIES", EntitiesField), # # "image": ("IMAGE", IMAGEField), # }, # transform=transform, # amr=self.a.amr) #################### build vocabulary #################### if self.a.webd: pretrained_embedding = Vectors(self.a.webd, ".", unk_init=partial(torch.nn.init.uniform_, a=-0.15, b=0.15)) WordsField.build_vocab(train_ee_set.WORDS, dev_ee_set.WORDS, train_grounding_set.WORDS, dev_grounding_set.WORDS, vectors=pretrained_embedding) else: WordsField.build_vocab(train_ee_set.WORDS, dev_ee_set.WORDS, train_grounding_set.WORDS, dev_grounding_set.WORDS) PosTagsField.build_vocab(train_ee_set.POSTAGS, dev_ee_set.POSTAGS, train_grounding_set.POSTAGS, dev_grounding_set.POSTAGS) EntityLabelsField.build_vocab(train_ee_set.ENTITYLABELS, dev_ee_set.ENTITYLABELS, train_grounding_set.ENTITYLABELS, dev_grounding_set.ENTITYLABELS) consts.O_LABEL = LabelField.vocab.stoi[consts.O_LABEL_NAME] # print("O label is", consts.O_LABEL) consts.ROLE_O_LABEL = EventsField.vocab.stoi[consts.ROLE_O_LABEL_NAME] # print("O label for AE is", consts.ROLE_O_LABEL) # dev_ee_set1 = ACE2005Dataset(path=self.a.dev_ee, # fields={"sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField), # "pos-tags": ("POSTAGS", PosTagsField), # "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), # colcc: ("ADJM", AdjMatrixField), # "golden-event-mentions": ("LABEL", LabelField), # "all-events": ("EVENT", EventsField), # "all-entities": ("ENTITIES", EntitiesField)}, # amr=self.a.amr, keep_events=1, only_keep=True) # # test_ee_set1 = ACE2005Dataset(path=self.a.test_ee, # fields={"sentence_id": ("SENTID", 
SENTIDField), "words": ("WORDS", WordsField), # "pos-tags": ("POSTAGS", PosTagsField), # "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField), # colcc: ("ADJM", AdjMatrixField), # "golden-event-mentions": ("LABEL", LabelField), # "all-events": ("EVENT", EventsField), # "all-entities": ("ENTITIES", EntitiesField)}, # amr=self.a.amr, keep_events=1, only_keep=True) # print("train set length", len(train_ee_set)) # # print("dev set length", len(dev_ee_set)) # print("dev set 1/1 length", len(dev_ee_set1)) # # print("test set length", len(test_ee_set)) # print("test set 1/1 length", len(test_ee_set1)) # sr model initialization if not self.a.sr_hps_path: self.a.sr_hps = eval(self.a.sr_hps) embeddingMatrix_noun = torch.FloatTensor(np.load(self.a.wnebd)).to(self.device) embeddingMatrix_verb = torch.FloatTensor(np.load(self.a.wvebd)).to(self.device) embeddingMatrix_role = torch.FloatTensor(np.load(self.a.wrebd)).to(self.device) if "wvemb_size" not in self.a.sr_hps: self.a.sr_hps["wvemb_size"] = len(vocab_verb.id2word) if "wremb_size" not in self.a.sr_hps: self.a.sr_hps["wremb_size"] = len(vocab_role.id2word) if "wnemb_size" not in self.a.sr_hps: self.a.sr_hps["wnemb_size"] = len(vocab_noun.id2word) # if "ae_oc" not in self.a.sr_hps: # self.a.sr_hps["ae_oc"] = len(vocab_role.id2word) # self.a.ee_label_weight = torch.ones([len(LabelField.vocab.itos)]) * 5 # self.a.ee_label_weight[consts.O_LABEL] = 1.0 # self.a.ee_arg_weight = torch.ones([len(EventsField.vocab.itos)]) * 5 # if not self.a.ee_hps_path: # self.a.ee_hps = eval(self.a.ee_hps) # if "wemb_size" not in self.a.ee_hps: # self.a.ee_hps["wemb_size"] = len(WordsField.vocab.itos) # if "pemb_size" not in self.a.ee_hps: # self.a.ee_hps["pemb_size"] = len(PosTagsField.vocab.itos) # if "psemb_size" not in self.a.ee_hps: # # self.a.ee_hps["psemb_size"] = max([train_grounding_set.longest(), dev_grounding_set.longest(), test_grounding_set.longest()]) + 2 # self.a.ee_hps["psemb_size"] = max([train_ee_set.longest(), dev_ee_set.longest(), test_ee_set.longest(), train_grounding_set.longest(), dev_grounding_set.longest(), test_grounding_set.longest()]) + 2 # if "eemb_size" not in self.a.ee_hps: # self.a.ee_hps["eemb_size"] = len(EntityLabelsField.vocab.itos) # if "oc" not in self.a.ee_hps: # self.a.ee_hps["oc"] = len(LabelField.vocab.itos) # if "ae_oc" not in self.a.ee_hps: # self.a.ee_hps["ae_oc"] = len(EventsField.vocab.itos) if "oc" not in self.a.sr_hps: self.a.sr_hps["oc"] = len(LabelField.vocab.itos) if "ae_oc" not in self.a.sr_hps: self.a.sr_hps["ae_oc"] = len(EventsField.vocab.itos) ace_classifier = ACEClassifier(self.a.sr_hps["wemb_dim"], self.a.sr_hps["oc"], self.a.sr_hps["ae_oc"], self.device) ee_model = None # # if self.a.score_ee: # if self.a.finetune_ee: # log('init ee model from ' + self.a.finetune_ee) # ee_model = load_ee_model(self.a.ee_hps, self.a.finetune_ee, WordsField.vocab.vectors, self.device, ace_classifier) # log('ee model loaded, there are %i sets of params' % len(ee_model.parameters_requires_grads())) # else: # ee_model = load_ee_model(self.a.ee_hps, None, WordsField.vocab.vectors, self.device, ace_classifier) # log('ee model created from scratch, there are %i sets of params' % len(ee_model.parameters_requires_grads())) # if self.a.score_sr: if self.a.finetune_sr: log('init sr model from ' + self.a.finetune_sr) sr_model = load_sr_model(self.a.sr_hps, embeddingMatrix_noun, embeddingMatrix_verb, embeddingMatrix_role, self.a.finetune_sr, self.device, ace_classifier, add_object=self.a.add_object) log('sr model loaded, there are %i 
        sets of params' % len(sr_model.parameters_requires_grads()))
    else:
        sr_model = load_sr_model(self.a.sr_hps, embeddingMatrix_noun, embeddingMatrix_verb, embeddingMatrix_role,
                                 None, self.device, ace_classifier, add_object=self.a.add_object)
        log('sr model created from scratch, there are %i sets of params' % len(sr_model.parameters_requires_grads()))

    model = GroundingModel(ee_model, sr_model, self.get_device())
    # ee_model = torch.nn.DataParallel(ee_model)
    # sr_model = torch.nn.DataParallel(sr_model)
    # model = torch.nn.DataParallel(model)

    # if self.a.optimizer == "adadelta":
    #     optimizer_constructor = partial(torch.optim.Adadelta, params=model.parameters_requires_grads(),
    #                                     weight_decay=self.a.l2decay)
    # elif self.a.optimizer == "adam":
    #     optimizer_constructor = partial(torch.optim.Adam, params=model.parameters_requires_grads(),
    #                                     weight_decay=self.a.l2decay)
    # else:
    #     optimizer_constructor = partial(torch.optim.SGD, params=model.parameters_requires_grads(),
    #                                     weight_decay=self.a.l2decay, momentum=0.9)
    # log('optimizer in use: %s' % str(self.a.optimizer))

    if not os.path.exists(self.a.out):
        os.mkdir(self.a.out)

    # with open(os.path.join(self.a.out, "word.vec"), "wb") as f:
    #     pickle.dump(WordsField.vocab, f)
    # with open(os.path.join(self.a.out, "pos.vec"), "wb") as f:
    #     pickle.dump(PosTagsField.vocab.stoi, f)
    # with open(os.path.join(self.a.out, "entity.vec"), "wb") as f:
    #     pickle.dump(EntityLabelsField.vocab.stoi, f)
    with open(os.path.join(self.a.out, "label.vec"), "wb") as f:
        pickle.dump(LabelField.vocab.stoi, f)
    with open(os.path.join(self.a.out, "role.vec"), "wb") as f:
        pickle.dump(EventsField.vocab.stoi, f)

    log('init complete\n')

    # # ee mappings
    # self.a.ee_word_i2s = WordsField.vocab.itos
    # self.a.ee_label_i2s = LabelField.vocab.itos
    # self.a.ee_role_i2s = EventsField.vocab.itos
    # self.a.ee_role_mask = None
    # if self.a.apply_ee_role_mask:
    #     self.a.ee_role_mask = event_role_mask(self.a.train_ee, self.a.dev_ee,
    #                                           LabelField.vocab.stoi, EventsField.vocab.stoi, self.device)

    # sr mappings
    self.a.sr_word_i2s = vocab_noun.id2word
    self.a.sr_label_i2s = vocab_verb.id2word  # LabelField.vocab.itos
    self.a.sr_role_i2s = vocab_role.id2word
    self.a.role_masks = train_sr_set.get_role_mask().to_dense().to(self.device)

    writer = SummaryWriter(os.path.join(self.a.out, "exp"))
    self.a.writer = writer

    # loading testing data
    # voa_text = self.a.test_voa_text
    voa_image_dir = self.a.test_voa_image
    gt_voa_image = self.a.gt_voa_image
    gt_voa_text = self.a.gt_voa_text
    gt_voa_align = self.a.gt_voa_align
    sr_verb_mapping, sr_role_mapping = train_sr_set.get_sr_mapping()

    test_m2e2_set = M2E2Dataset(path=gt_voa_text,
                                img_dir=voa_image_dir,
                                fields={"image": ("IMAGEID", IMAGEIDField),
                                        "sentence_id": ("SENTID", SENTIDField),
                                        "words": ("WORDS", WordsField),
                                        "pos-tags": ("POSTAGS", PosTagsField),
                                        "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField),
                                        colcc: ("ADJM", AdjMatrixField),
                                        "all-entities": ("ENTITIES", EntitiesField),
                                        # "image": ("IMAGE", IMAGEField),
                                        "golden-event-mentions": ("LABEL", LabelField),
                                        "all-events": ("EVENT", EventsField),
                                        },
                                transform=transform,
                                amr=self.a.amr,
                                load_object=self.a.add_object,
                                object_ontology_file=self.a.object_class_map_file,
                                object_detection_pkl_file=self.a.object_detection_pkl_file,
                                object_detection_threshold=self.a.object_detection_threshold,
                                keep_events=self.a.keep_events,
                                )
    object_results, object_label, object_detection_threshold = test_m2e2_set.get_object_results()

    # build batch on cpu
    test_m2e2_iter = BucketIterator(test_m2e2_set, batch_size=1, train=False, shuffle=False, device=-1,
                                    sort_key=lambda x: len(x.POSTAGS))

    # scores = 0.0
    # now_bad = 0
    # restart_used = 0
    print("\nStarting testing...\n")
    # lr = parser.lr
    # optimizer = optimizer_constructor(lr=lr)

    # ee_tester = EDTester(LabelField.vocab.itos, EventsField.vocab.itos, self.a.ignore_time_test)
    # sr_tester = SRTester()
    # g_tester = GroundingTester()
    j_tester = JointTester(self.a.ignore_place_sr_test, self.a.ignore_time_test)
    # if self.a.visual_voa_ee_path is not None:
    #     ee_visualizer = EDVisualizer(self.a.gt_voa_text)
    # else:
    #     ee_visualizer = None

    image_gt = json.load(open(gt_voa_image))

    # all_y = []
    # all_y_ = []
    # all_events = []
    # all_events_ = []
    vision_result = dict()

    # if self.a.visual_voa_g_path is not None and not os.path.exists(self.a.visual_voa_g_path):
    #     os.makedirs(self.a.visual_voa_g_path, exist_ok=True)
    # if self.a.visual_voa_ee_path is not None and not os.path.exists(self.a.visual_voa_ee_path):
    #     os.makedirs(self.a.visual_voa_ee_path, exist_ok=True)
    if self.a.visual_voa_sr_path is not None and not os.path.exists(self.a.visual_voa_sr_path):
        os.makedirs(self.a.visual_voa_sr_path, exist_ok=True)
    # grounding_writer = open(self.a.visual_voa_g_path, 'w')

    doc_done = set()
    with torch.no_grad():
        model.eval()
        for batch in test_m2e2_iter:
            vision_result = joint_test_batch(model_g=model,
                                             batch_g=batch,
                                             device=self.device,
                                             transform=transform,
                                             img_dir=voa_image_dir,
                                             # ee_hyps=self.a.ee_hps,
                                             # ee_word_i2s=self.a.ee_word_i2s,
                                             # ee_label_i2s=self.a.ee_label_i2s,
                                             # ee_role_i2s=self.a.ee_role_i2s,
                                             # ee_tester=ee_tester,
                                             # ee_visualizer=ee_visualizer,
                                             sr_noun_i2s=self.a.sr_word_i2s,
                                             sr_verb_i2s=self.a.sr_label_i2s,
                                             sr_role_i2s=self.a.sr_role_i2s,
                                             # sr_tester=sr_tester,
                                             role_masks=self.a.role_masks,
                                             # ee_role_mask=self.a.ee_role_mask,
                                             # j_tester=j_tester,
                                             image_gt=image_gt,
                                             verb2type=sr_verb_mapping,
                                             role2role=sr_role_mapping,
                                             vision_result=vision_result,
                                             # all_y=all_y,
                                             # all_y_=all_y_,
                                             # all_events=all_events,
                                             # all_events_=all_events_,
                                             # visual_g_path=self.a.visual_voa_g_path,
                                             # visual_ee_path=self.a.visual_voa_ee_path,
                                             load_object=self.a.add_object,
                                             object_results=object_results,
                                             object_label=object_label,
                                             object_detection_threshold=object_detection_threshold,
                                             vocab_objlabel=vocab_noun.word2id,
                                             # apply_ee_role_mask=self.a.apply_ee_role_mask
                                             keep_events_sr=self.a.keep_events_sr,
                                             doc_done=doc_done,
                                             )

    print('vision_result size', len(vision_result))
    # pickle.dump(vision_result, open(os.path.join(self.a.out, 'vision_result.pkl'), 'w'))

    # ep, er, ef = ee_tester.calculate_report(all_y, all_y_, transform=True)
    # ap, ar, af = ee_tester.calculate_sets(all_events, all_events_)
    # if self.a.visual_voa_ee_path is not None:
    #     ee_visualizer.rewrite_brat(self.a.visual_voa_ee_path, self.a.visual_voa_ee_gt_ann)
    #
    # print('text ep, er, ef', ep, er, ef)
    # print('text ap, ar, af', ap, ar, af)

    evt_p, evt_r, evt_f1, role_scores = j_tester.calculate_report(vision_result,
                                                                  voa_image_dir,
                                                                  self.a.visual_voa_sr_path,
                                                                  self.a.add_object,
                                                                  keep_events_sr=self.a.keep_events_sr)  # consts.O_LABEL, consts.ROLE_O_LABEL
    print('image event ep, er, ef \n', evt_p, '\n', evt_r, '\n', evt_f1)
    # if not self.a.add_object:
    #     print('image att_iou ap, ar, af', role_scores['role_att_iou_p'], role_scores['role_att_iou_r'],
    #           role_scores['role_att_iou_f1'])
    #     print('image att_hit ap, ar, af', role_scores['role_att_hit_p'], role_scores['role_att_hit_r'],
    #           role_scores['role_att_hit_f1'])
    #     print('image att_cor ap, ar, af', role_scores['role_att_cor_p'], role_scores['role_att_cor_r'],
    #           role_scores['role_att_cor_f1'])
    # else:
    #     print('image obj_iou ap, ar, af', role_scores['role_obj_iou_p'], role_scores['role_obj_iou_r'],
    #           role_scores['role_obj_iou_f1'])
    #     print('image obj_iou_union ap, ar, af', role_scores['role_obj_iou_union_p'], role_scores['role_obj_iou_union_r'],
    #           role_scores['role_obj_iou_union_f1'])

    for key in role_scores:
        print(key)
        for key_ in role_scores[key]:
            print(key_, role_scores[key][key_])
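
# ---------------------------------------------------------------------------
# The "label.vec" and "role.vec" files written above are pickled stoi
# (string -> index) mappings, so a downstream scoring script can rebuild both
# directions of the mapping. A minimal sketch, assuming only that the dumped
# objects behave like str->int dicts as pickled above; the helper name and the
# out_dir argument are illustrative, not part of this repo.
# ---------------------------------------------------------------------------
import os
import pickle


def load_label_vocabs(out_dir):
    """Reload the pickled stoi dicts and derive the inverse itos lists."""
    vocabs = {}
    for name in ("label", "role"):
        with open(os.path.join(out_dir, name + ".vec"), "rb") as f:
            stoi = pickle.load(f)
        itos = [None] * len(stoi)
        for token, idx in stoi.items():
            itos[idx] = token
        vocabs[name] = (stoi, itos)
    return vocabs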
def run(self):
    print("Running on", self.a.device)
    self.set_device(self.a.device)

    np.random.seed(self.a.seed)
    torch.manual_seed(self.a.seed)
    torch.backends.cudnn.benchmark = True

    # create training/dev/test sets
    if self.a.test_ee:
        log('loading event extraction corpus from %s' % self.a.test_ee)

    WordsField = Field(lower=True, include_lengths=True, batch_first=True)
    PosTagsField = Field(lower=True, batch_first=True)
    EntityLabelsField = MultiTokenField(lower=False, batch_first=True)
    AdjMatrixField = SparseField(sequential=False, use_vocab=False, batch_first=True)
    LabelField = Field(lower=False, batch_first=True, pad_token='0', unk_token=None)
    EventsField = EventField(lower=False, batch_first=True)
    EntitiesField = EntityField(lower=False, batch_first=True, use_vocab=False)

    if self.a.amr:
        colcc = 'amr-colcc'
    else:
        colcc = 'stanford-colcc'
    print(colcc)

    train_ee_set = ACE2005Dataset(path=self.a.train_ee,
                                  fields={"words": ("WORDS", WordsField),
                                          "pos-tags": ("POSTAGS", PosTagsField),
                                          "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField),
                                          colcc: ("ADJM", AdjMatrixField),
                                          "golden-event-mentions": ("LABEL", LabelField),
                                          "all-events": ("EVENT", EventsField),
                                          "all-entities": ("ENTITIES", EntitiesField)},
                                  amr=self.a.amr, keep_events=1)
    dev_ee_set = ACE2005Dataset(path=self.a.dev_ee,
                                fields={"words": ("WORDS", WordsField),
                                        "pos-tags": ("POSTAGS", PosTagsField),
                                        "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField),
                                        colcc: ("ADJM", AdjMatrixField),
                                        "golden-event-mentions": ("LABEL", LabelField),
                                        "all-events": ("EVENT", EventsField),
                                        "all-entities": ("ENTITIES", EntitiesField)},
                                amr=self.a.amr, keep_events=0)
    test_ee_set = ACE2005Dataset(path=self.a.test_ee,
                                 fields={"words": ("WORDS", WordsField),
                                         "pos-tags": ("POSTAGS", PosTagsField),
                                         "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField),
                                         colcc: ("ADJM", AdjMatrixField),
                                         "golden-event-mentions": ("LABEL", LabelField),
                                         "all-events": ("EVENT", EventsField),
                                         "all-entities": ("ENTITIES", EntitiesField)},
                                 amr=self.a.amr, keep_events=0)

    if self.a.webd:
        pretrained_embedding = Vectors(self.a.webd, ".", unk_init=partial(torch.nn.init.uniform_, a=-0.15, b=0.15))
        WordsField.build_vocab(train_ee_set.WORDS, dev_ee_set.WORDS, vectors=pretrained_embedding)
    else:
        WordsField.build_vocab(train_ee_set.WORDS, dev_ee_set.WORDS)
    PosTagsField.build_vocab(train_ee_set.POSTAGS, dev_ee_set.POSTAGS)
    EntityLabelsField.build_vocab(train_ee_set.ENTITYLABELS, dev_ee_set.ENTITYLABELS)
    LabelField.build_vocab(train_ee_set.LABEL, dev_ee_set.LABEL)
    EventsField.build_vocab(train_ee_set.EVENT, dev_ee_set.EVENT)

    consts.O_LABEL = LabelField.vocab.stoi[consts.O_LABEL_NAME]
    # print("O label is", consts.O_LABEL)
    consts.ROLE_O_LABEL = EventsField.vocab.stoi[consts.ROLE_O_LABEL_NAME]
    # print("O label for AE is", consts.ROLE_O_LABEL)

    self.a.label_weight = torch.ones([len(LabelField.vocab.itos)]) * 5
    self.a.label_weight[consts.O_LABEL] = 1.0
    self.a.arg_weight = torch.ones([len(EventsField.vocab.itos)]) * 5

    # add role mask
    self.a.role_mask = event_role_mask(self.a.test_ee, self.a.train_ee, self.a.dev_ee,
                                       LabelField.vocab.stoi, EventsField.vocab.stoi, self.device)

    # print('self.a.hps', self.a.hps)
    if not self.a.hps_path:
        self.a.hps = eval(self.a.hps)
    if "wemb_size" not in self.a.hps:
        self.a.hps["wemb_size"] = len(WordsField.vocab.itos)
    if "pemb_size" not in self.a.hps:
        self.a.hps["pemb_size"] = len(PosTagsField.vocab.itos)
    if "psemb_size" not in self.a.hps:
        self.a.hps["psemb_size"] = max([train_ee_set.longest(), dev_ee_set.longest(), test_ee_set.longest()]) + 2
    if "eemb_size" not in self.a.hps:
        self.a.hps["eemb_size"] = len(EntityLabelsField.vocab.itos)
    if "oc" not in self.a.hps:
        self.a.hps["oc"] = len(LabelField.vocab.itos)
    if "ae_oc" not in self.a.hps:
        self.a.hps["ae_oc"] = len(EventsField.vocab.itos)

    tester = self.get_tester(LabelField.vocab.itos, EventsField.vocab.itos)

    if self.a.finetune:
        log('init model from ' + self.a.finetune)
        model = load_ee_model(self.a.hps, self.a.finetune, WordsField.vocab.vectors, self.device)
        log('model loaded, there are %i sets of params' % len(model.parameters_requires_grads()))
    else:
        model = load_ee_model(self.a.hps, None, WordsField.vocab.vectors, self.device)
        log('model created from scratch, there are %i sets of params' % len(model.parameters_requires_grads()))

    self.a.word_i2s = WordsField.vocab.itos
    self.a.label_i2s = LabelField.vocab.itos
    self.a.role_i2s = EventsField.vocab.itos

    writer = SummaryWriter(os.path.join(self.a.out, "exp"))
    self.a.writer = writer

    # train_iter = BucketIterator(train_ee_set, batch_size=self.a.batch, train=True, shuffle=False, device=-1,
    #                             sort_key=lambda x: len(x.POSTAGS))
    # dev_iter = BucketIterator(dev_ee_set, batch_size=self.a.batch, train=False, shuffle=False, device=-1,
    #                           sort_key=lambda x: len(x.POSTAGS))
    test_iter = BucketIterator(test_ee_set, batch_size=self.a.batch, train=False, shuffle=False, device=-1,
                               sort_key=lambda x: len(x.POSTAGS))

    print("\nStarting testing ...\n")

    # testing phase
    test_loss, test_ed_p, test_ed_r, test_ed_f1, \
        test_ae_p, test_ae_r, test_ae_f1 = run_over_data(data_iter=test_iter,
                                                         optimizer=None,
                                                         model=model,
                                                         need_backward=False,
                                                         MAX_STEP=ceil(len(test_ee_set) / self.a.batch),
                                                         tester=tester,
                                                         hyps=model.hyperparams,
                                                         device=model.device,
                                                         maxnorm=self.a.maxnorm,
                                                         word_i2s=self.a.word_i2s,
                                                         label_i2s=self.a.label_i2s,
                                                         role_i2s=self.a.role_i2s,
                                                         weight=self.a.label_weight,
                                                         arg_weight=self.a.arg_weight,
                                                         save_output=os.path.join(self.a.out, "test_final.txt"),
                                                         role_mask=self.a.role_mask)
    print("\nFinally test loss: ", test_loss,
          "\ntest ed p: ", test_ed_p, " test ed r: ", test_ed_r, " test ed f1: ", test_ed_f1,
          "\ntest ae p: ", test_ae_p, " test ae r: ", test_ae_r, " test ae f1: ", test_ae_f1)
def run(self):
    print("Running on", self.a.device)
    self.set_device(self.a.device)

    np.random.seed(self.a.seed)
    torch.manual_seed(self.a.seed)
    torch.backends.cudnn.benchmark = True

    #################### loading event extraction dataset ####################
    if self.a.train_ee:
        log('loading event extraction corpus from %s' % self.a.train_ee)

    # both for grounding and ee
    WordsField = Field(lower=True, include_lengths=True, batch_first=True)
    PosTagsField = Field(lower=True, batch_first=True)
    EntityLabelsField = MultiTokenField(lower=False, batch_first=True)
    AdjMatrixField = SparseField(sequential=False, use_vocab=False, batch_first=True)
    EntitiesField = EntityField(lower=False, batch_first=True, use_vocab=False)
    # only for ee
    LabelField = Field(lower=False, batch_first=True, pad_token='0', unk_token=None)
    EventsField = EventField(lower=False, batch_first=True)
    SENTIDField = SparseField(sequential=False, use_vocab=False, batch_first=True)

    if self.a.amr:
        colcc = 'simple-parsing'
    else:
        colcc = 'combined-parsing'
    print(colcc)

    train_ee_set = ACE2005Dataset(path=self.a.train_ee,
                                  fields={"sentence_id": ("SENTID", SENTIDField),
                                          "words": ("WORDS", WordsField),
                                          "pos-tags": ("POSTAGS", PosTagsField),
                                          "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField),
                                          colcc: ("ADJM", AdjMatrixField),
                                          "golden-event-mentions": ("LABEL", LabelField),
                                          "all-events": ("EVENT", EventsField),
                                          "all-entities": ("ENTITIES", EntitiesField)},
                                  amr=self.a.amr, keep_events=1)
    dev_ee_set = ACE2005Dataset(path=self.a.dev_ee,
                                fields={"sentence_id": ("SENTID", SENTIDField),
                                        "words": ("WORDS", WordsField),
                                        "pos-tags": ("POSTAGS", PosTagsField),
                                        "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField),
                                        colcc: ("ADJM", AdjMatrixField),
                                        "golden-event-mentions": ("LABEL", LabelField),
                                        "all-events": ("EVENT", EventsField),
                                        "all-entities": ("ENTITIES", EntitiesField)},
                                amr=self.a.amr, keep_events=0)
    test_ee_set = ACE2005Dataset(path=self.a.test_ee,
                                 fields={"sentence_id": ("SENTID", SENTIDField),
                                         "words": ("WORDS", WordsField),
                                         "pos-tags": ("POSTAGS", PosTagsField),
                                         "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField),
                                         colcc: ("ADJM", AdjMatrixField),
                                         "golden-event-mentions": ("LABEL", LabelField),
                                         "all-events": ("EVENT", EventsField),
                                         "all-entities": ("ENTITIES", EntitiesField)},
                                 amr=self.a.amr, keep_events=0)

    if self.a.webd:
        pretrained_embedding = Vectors(self.a.webd, ".", unk_init=partial(torch.nn.init.uniform_, a=-0.15, b=0.15))
        LabelField.build_vocab(train_ee_set.LABEL, dev_ee_set.LABEL, vectors=pretrained_embedding)
        EventsField.build_vocab(train_ee_set.EVENT, dev_ee_set.EVENT, vectors=pretrained_embedding)
    else:
        LabelField.build_vocab(train_ee_set.LABEL, dev_ee_set.LABEL)
        EventsField.build_vocab(train_ee_set.EVENT, dev_ee_set.EVENT)

    # add role mask
    self.a.role_mask = event_role_mask(self.a.train_ee, self.a.dev_ee,
                                       LabelField.vocab.stoi, EventsField.vocab.stoi, self.device)

    #################### loading SR dataset ####################
    # both for grounding and sr
    if self.a.train_sr:
        log('loading corpus from %s' % self.a.train_sr)

    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

    vocab_noun = Vocab(os.path.join(self.a.vocab, 'vocab_situation_noun.pkl'), load=True)
    vocab_role = Vocab(os.path.join(self.a.vocab, 'vocab_situation_role.pkl'), load=True)
    vocab_verb = Vocab(os.path.join(self.a.vocab, 'vocab_situation_verb.pkl'), load=True)

    # train_sr_loader = imsitu_loader(self.a.image_dir, self.vocab_noun, self.vocab_role, self.vocab_verb,
    #                                 self.a.imsitu_ontology_file,
    #                                 self.a.train_sr, self.a.verb_mapping_file, self.a.role_mapping_file,
    #                                 self.a.object_class_map_file, self.a.object_detection_pkl_file,
    #                                 self.a.object_detection_threshold,
    #                                 transform, self.a.batch, shuffle=self.a.shuffle, num_workers=1)  # self.a.shuffle
    # dev_sr_loader = imsitu_loader(self.a.image_dir, self.vocab_noun, self.vocab_role, self.vocab_verb,
    #                               self.a.imsitu_ontology_file,
    #                               self.a.dev_sr, self.a.verb_mapping_file, self.a.role_mapping_file,
    #                               self.a.object_class_map_file, self.a.object_detection_pkl_file,
    #                               self.a.object_detection_threshold,
    #                               transform, self.a.batch, shuffle=self.a.shuffle, num_workers=1)
    # test_sr_loader = imsitu_loader(self.a.image_dir, self.vocab_noun, self.vocab_role, self.vocab_verb,
    #                                self.a.imsitu_ontology_file,
    #                                self.a.test_sr, self.a.verb_mapping_file, self.a.role_mapping_file,
    #                                self.a.object_class_map_file, self.a.object_detection_pkl_file,
    #                                self.a.object_detection_threshold,
    #                                transform, self.a.batch, shuffle=self.a.shuffle, num_workers=1)

    train_sr_set = ImSituDataset(self.a.image_dir, vocab_noun, vocab_role, vocab_verb,
                                 LabelField.vocab.stoi, EventsField.vocab.stoi,
                                 self.a.imsitu_ontology_file,
                                 self.a.train_sr, self.a.verb_mapping_file,
                                 self.a.object_class_map_file, self.a.object_detection_pkl_file,
                                 self.a.object_detection_threshold,
                                 transform, filter_irrelevant_verbs=self.a.filter_irrelevant_verbs,
                                 load_object=self.a.add_object, filter_place=self.a.filter_place)
    dev_sr_set = ImSituDataset(self.a.image_dir, vocab_noun, vocab_role, vocab_verb,
                               LabelField.vocab.stoi, EventsField.vocab.stoi,
                               self.a.imsitu_ontology_file,
                               self.a.dev_sr, self.a.verb_mapping_file,
                               self.a.object_class_map_file, self.a.object_detection_pkl_file,
                               self.a.object_detection_threshold,
                               transform, filter_irrelevant_verbs=self.a.filter_irrelevant_verbs,
                               load_object=self.a.add_object, filter_place=self.a.filter_place)
    test_sr_set = ImSituDataset(self.a.image_dir, vocab_noun, vocab_role, vocab_verb,
                                LabelField.vocab.stoi, EventsField.vocab.stoi,
                                self.a.imsitu_ontology_file,
                                self.a.test_sr, self.a.verb_mapping_file,
                                self.a.object_class_map_file, self.a.object_detection_pkl_file,
                                self.a.object_detection_threshold,
                                transform, filter_irrelevant_verbs=self.a.filter_irrelevant_verbs,
                                load_object=self.a.add_object, filter_place=self.a.filter_place)

    #################### loading grounding dataset ####################
    if self.a.train_grounding:
        log('loading grounding corpus from %s' % self.a.train_grounding)

    # only for grounding
    IMAGEIDField = SparseField(sequential=False, use_vocab=False, batch_first=True)
    SENTIDField = SparseField(sequential=False, use_vocab=False, batch_first=True)
    # IMAGEField = SparseField(sequential=False, use_vocab=False, batch_first=True)

    train_grounding_set = GroundingDataset(path=self.a.train_grounding,
                                           img_dir=self.a.img_dir_grounding,
                                           fields={"id": ("IMAGEID", IMAGEIDField),
                                                   "sentence_id": ("SENTID", SENTIDField),
                                                   "words": ("WORDS", WordsField),
                                                   "pos-tags": ("POSTAGS", PosTagsField),
                                                   "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField),
                                                   colcc: ("ADJM", AdjMatrixField),
                                                   "all-entities": ("ENTITIES", EntitiesField),
                                                   # "image": ("IMAGE", IMAGEField),
                                                   },
                                           transform=transform,
                                           amr=self.a.amr,
                                           load_object=self.a.add_object,
                                           object_ontology_file=self.a.object_class_map_file,
                                           object_detection_pkl_file=self.a.object_detection_pkl_file_g,
                                           object_detection_threshold=self.a.object_detection_threshold,
                                           )
    dev_grounding_set = GroundingDataset(path=self.a.dev_grounding,
                                         img_dir=self.a.img_dir_grounding,
                                         fields={"id": ("IMAGEID", IMAGEIDField),
                                                 "sentence_id": ("SENTID", SENTIDField),
                                                 "words": ("WORDS", WordsField),
                                                 "pos-tags": ("POSTAGS", PosTagsField),
                                                 "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField),
                                                 colcc: ("ADJM", AdjMatrixField),
                                                 "all-entities": ("ENTITIES", EntitiesField),
                                                 # "image": ("IMAGE", IMAGEField),
                                                 },
                                         transform=transform,
                                         amr=self.a.amr,
                                         load_object=self.a.add_object,
                                         object_ontology_file=self.a.object_class_map_file,
                                         object_detection_pkl_file=self.a.object_detection_pkl_file_g,
                                         object_detection_threshold=self.a.object_detection_threshold,
                                         )
    test_grounding_set = GroundingDataset(path=self.a.test_grounding,
                                          img_dir=self.a.img_dir_grounding,
                                          fields={"id": ("IMAGEID", IMAGEIDField),
                                                  "sentence_id": ("SENTID", SENTIDField),
                                                  "words": ("WORDS", WordsField),
                                                  "pos-tags": ("POSTAGS", PosTagsField),
                                                  "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField),
                                                  colcc: ("ADJM", AdjMatrixField),
                                                  "all-entities": ("ENTITIES", EntitiesField),
                                                  # "image": ("IMAGE", IMAGEField),
                                                  },
                                          transform=transform,
                                          amr=self.a.amr,
                                          load_object=self.a.add_object,
                                          object_ontology_file=self.a.object_class_map_file,
                                          object_detection_pkl_file=self.a.object_detection_pkl_file_g,
                                          object_detection_threshold=self.a.object_detection_threshold,
                                          )

    #################### build vocabulary ####################
    if self.a.webd:
        pretrained_embedding = Vectors(self.a.webd, ".", unk_init=partial(torch.nn.init.uniform_, a=-0.15, b=0.15))
        WordsField.build_vocab(train_ee_set.WORDS, dev_ee_set.WORDS,
                               train_grounding_set.WORDS, dev_grounding_set.WORDS,
                               vectors=pretrained_embedding)
    else:
        WordsField.build_vocab(train_ee_set.WORDS, dev_ee_set.WORDS,
                               train_grounding_set.WORDS, dev_grounding_set.WORDS)
    PosTagsField.build_vocab(train_ee_set.POSTAGS, dev_ee_set.POSTAGS,
                             train_grounding_set.POSTAGS, dev_grounding_set.POSTAGS)
    EntityLabelsField.build_vocab(train_ee_set.ENTITYLABELS, dev_ee_set.ENTITYLABELS,
                                  train_grounding_set.ENTITYLABELS, dev_grounding_set.ENTITYLABELS)

    consts.O_LABEL = LabelField.vocab.stoi[consts.O_LABEL_NAME]
    # print("O label is", consts.O_LABEL)
    consts.ROLE_O_LABEL = EventsField.vocab.stoi[consts.ROLE_O_LABEL_NAME]
    # print("O label for AE is", consts.ROLE_O_LABEL)

    dev_ee_set1 = ACE2005Dataset(path=self.a.dev_ee,
                                 fields={"sentence_id": ("SENTID", SENTIDField),
                                         "words": ("WORDS", WordsField),
                                         "pos-tags": ("POSTAGS", PosTagsField),
                                         "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField),
                                         colcc: ("ADJM", AdjMatrixField),
                                         "golden-event-mentions": ("LABEL", LabelField),
                                         "all-events": ("EVENT", EventsField),
                                         "all-entities": ("ENTITIES", EntitiesField)},
                                 amr=self.a.amr, keep_events=1, only_keep=True)
    test_ee_set1 = ACE2005Dataset(path=self.a.test_ee,
                                  fields={"sentence_id": ("SENTID", SENTIDField),
                                          "words": ("WORDS", WordsField),
                                          "pos-tags": ("POSTAGS", PosTagsField),
                                          "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField),
                                          colcc: ("ADJM", AdjMatrixField),
                                          "golden-event-mentions": ("LABEL", LabelField),
                                          "all-events": ("EVENT", EventsField),
                                          "all-entities": ("ENTITIES", EntitiesField)},
                                  amr=self.a.amr, keep_events=1, only_keep=True)

    print("train set length", len(train_ee_set))
    print("dev set length", len(dev_ee_set))
    print("dev set 1/1 length", len(dev_ee_set1))
    print("test set length", len(test_ee_set))
    print("test set 1/1 length", len(test_ee_set1))

    # sr model initialization
    if not self.a.sr_hps_path:
        self.a.sr_hps = eval(self.a.sr_hps)
    embeddingMatrix_noun = torch.FloatTensor(np.load(self.a.wnebd)).to(self.device)
    embeddingMatrix_verb = torch.FloatTensor(np.load(self.a.wvebd)).to(self.device)
    embeddingMatrix_role = torch.FloatTensor(np.load(self.a.wrebd)).to(self.device)
    if "wvemb_size" not in self.a.sr_hps:
        self.a.sr_hps["wvemb_size"] = len(vocab_verb.id2word)
    if "wremb_size" not in self.a.sr_hps:
        self.a.sr_hps["wremb_size"] = len(vocab_role.id2word)
    if "wnemb_size" not in self.a.sr_hps:
        self.a.sr_hps["wnemb_size"] = len(vocab_noun.id2word)

    self.a.ee_label_weight = torch.ones([len(LabelField.vocab.itos)]) * 5
    self.a.ee_label_weight[consts.O_LABEL] = 1.0
    self.a.ee_arg_weight = torch.ones([len(EventsField.vocab.itos)]) * 5

    self.a.ee_hps = eval(self.a.ee_hps)
    if "wemb_size" not in self.a.ee_hps:
        self.a.ee_hps["wemb_size"] = len(WordsField.vocab.itos)
    if "pemb_size" not in self.a.ee_hps:
        self.a.ee_hps["pemb_size"] = len(PosTagsField.vocab.itos)
    if "psemb_size" not in self.a.ee_hps:
        # self.a.ee_hps["psemb_size"] = max([train_grounding_set.longest(), dev_grounding_set.longest(), test_grounding_set.longest()]) + 2
        self.a.ee_hps["psemb_size"] = max([train_ee_set.longest(), dev_ee_set.longest(), test_ee_set.longest(),
                                           train_grounding_set.longest(), dev_grounding_set.longest(),
                                           test_grounding_set.longest()]) + 2
    if "eemb_size" not in self.a.ee_hps:
        self.a.ee_hps["eemb_size"] = len(EntityLabelsField.vocab.itos)
    if "oc" not in self.a.ee_hps:
        self.a.ee_hps["oc"] = len(LabelField.vocab.itos)
    if "ae_oc" not in self.a.ee_hps:
        self.a.ee_hps["ae_oc"] = len(EventsField.vocab.itos)
    if "oc" not in self.a.sr_hps:
        self.a.sr_hps["oc"] = len(LabelField.vocab.itos)
    if "ae_oc" not in self.a.sr_hps:
        self.a.sr_hps["ae_oc"] = len(EventsField.vocab.itos)

    ee_tester = EDTester(LabelField.vocab.itos, EventsField.vocab.itos, self.a.ignore_time_test)
    sr_tester = SRTester()
    g_tester = GroundingTester()
    j_tester = JointTester(self.a.ignore_place_sr_test, self.a.ignore_time_test)

    ace_classifier = ACEClassifier(2 * self.a.ee_hps["lstm_dim"], self.a.ee_hps["oc"],
                                   self.a.ee_hps["ae_oc"], self.device)

    if self.a.finetune_ee:
        log('init ee model from ' + self.a.finetune_ee)
        ee_model = load_ee_model(self.a.ee_hps, self.a.finetune_ee, WordsField.vocab.vectors, self.device,
                                 ace_classifier)
        log('ee model loaded, there are %i sets of params' % len(ee_model.parameters_requires_grads()))
    else:
        ee_model = load_ee_model(self.a.ee_hps, None, WordsField.vocab.vectors, self.device, ace_classifier)
        log('ee model created from scratch, there are %i sets of params' % len(ee_model.parameters_requires_grads()))

    if self.a.finetune_sr:
        log('init sr model from ' + self.a.finetune_sr)
        sr_model = load_sr_model(self.a.sr_hps, embeddingMatrix_noun, embeddingMatrix_verb, embeddingMatrix_role,
                                 self.a.finetune_sr, self.device, ace_classifier,
                                 add_object=self.a.add_object, load_partial=True)
        log('sr model loaded, there are %i sets of params' % len(sr_model.parameters_requires_grads()))
    else:
        sr_model = load_sr_model(self.a.sr_hps, embeddingMatrix_noun, embeddingMatrix_verb, embeddingMatrix_role,
                                 None, self.device, ace_classifier,
                                 add_object=self.a.add_object, load_partial=True)
        log('sr model created from scratch, there are %i sets of params' % len(sr_model.parameters_requires_grads()))

    model = GroundingModel(ee_model, sr_model, self.get_device())
    # ee_model = torch.nn.DataParallel(ee_model)
    # sr_model = torch.nn.DataParallel(sr_model)
    # model = torch.nn.DataParallel(model)

    if self.a.optimizer == "adadelta":
        optimizer_constructor = partial(torch.optim.Adadelta, params=model.parameters_requires_grads(),
                                        weight_decay=self.a.l2decay)
    elif self.a.optimizer == "adam":
        optimizer_constructor = partial(torch.optim.Adam, params=model.parameters_requires_grads(),
                                        weight_decay=self.a.l2decay)
    else:
        optimizer_constructor = partial(torch.optim.SGD, params=model.parameters_requires_grads(),
                                        weight_decay=self.a.l2decay, momentum=0.9)
    log('optimizer in use: %s' % str(self.a.optimizer))

    if not os.path.exists(self.a.out):
        os.mkdir(self.a.out)
    with open(os.path.join(self.a.out, "word.vec"), "wb") as f:
        pickle.dump(WordsField.vocab, f)
    with open(os.path.join(self.a.out, "pos.vec"), "wb") as f:
        pickle.dump(PosTagsField.vocab.stoi, f)
    with open(os.path.join(self.a.out, "entity.vec"), "wb") as f:
        pickle.dump(EntityLabelsField.vocab.stoi, f)
    with open(os.path.join(self.a.out, "label.vec"), "wb") as f:
        pickle.dump(LabelField.vocab.stoi, f)
    with open(os.path.join(self.a.out, "role.vec"), "wb") as f:
        pickle.dump(EventsField.vocab.stoi, f)
    with open(os.path.join(self.a.out, "ee_hyps.json"), "w") as f:
        json.dump(self.a.ee_hps, f)
    with open(os.path.join(self.a.out, "sr_hyps.json"), "w") as f:
        json.dump(self.a.sr_hps, f)

    log('init complete\n')

    # ee mappings
    self.a.ee_word_i2s = WordsField.vocab.itos
    self.a.ee_label_i2s = LabelField.vocab.itos
    self.a.ee_role_i2s = EventsField.vocab.itos
    # sr mappings
    self.a.sr_word_i2s = vocab_noun.id2word
    self.a.sr_label_i2s = vocab_verb.id2word  # LabelField.vocab.itos
    self.a.sr_role_i2s = vocab_role.id2word

    writer = SummaryWriter(os.path.join(self.a.out, "exp"))
    self.a.writer = writer

    joint_train(model_ee=ee_model,
                model_sr=sr_model,
                model_g=model,
                train_set_g=train_grounding_set,
                dev_set_g=dev_grounding_set,
                test_set_g=test_grounding_set,
                train_set_ee=train_ee_set,
                dev_set_ee=dev_ee_set,
                test_set_ee=test_ee_set,
                train_set_sr=train_sr_set,
                dev_set_sr=dev_sr_set,
                test_set_sr=test_sr_set,
                optimizer_constructor=optimizer_constructor,
                epochs=self.a.epochs,
                ee_tester=ee_tester,
                sr_tester=sr_tester,
                g_tester=g_tester,
                j_tester=j_tester,
                parser=self.a,
                other_testsets={"dev ee 1/1": dev_ee_set1,
                                "test ee 1/1": test_ee_set1},
                transform=transform,
                vocab_objlabel=vocab_noun.word2id)
    log('Done!')
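
# ---------------------------------------------------------------------------
# joint_train receives optimizer_constructor rather than a ready optimizer:
# everything except the learning rate is bound up front with functools.partial,
# so the trainer can (re)instantiate the optimizer whenever it picks a new lr,
# e.g. when restarting after a bad epoch. A self-contained sketch of the same
# pattern; the names below are illustrative only.
# ---------------------------------------------------------------------------
from functools import partial

import torch


def _demo_deferred_optimizer():
    model = torch.nn.Linear(4, 2)
    # bind a list rather than the model.parameters() generator, so the
    # constructor can safely be called more than once
    constructor = partial(torch.optim.SGD, params=list(model.parameters()),
                          weight_decay=1e-5, momentum=0.9)
    for lr in (0.1, 0.01):  # the trainer chooses the lr later
        optimizer = constructor(lr=lr)
        print(type(optimizer).__name__, optimizer.param_groups[0]["lr"])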
]]),
    torch.FloatTensor([1] * 24),
    torch.Size([BATCH_SIZE, ET, SEQ_LEN, SEQ_LEN])).to_dense().to(device)

input = torch.randn(BATCH_SIZE, SEQ_LEN, D).to(device)
label = torch.LongTensor([0, 1, 0, 1, 0, 1, 0, 1]).to(device)

cc = GraphConvolution(in_features=D, out_features=D, edge_types=ET, device=device, use_bn=True)
oo = BottledOrthogonalLinear(in_features=D, out_features=CLASSN).to(device)
optimizer = Adadelta(list(cc.parameters()) + list(oo.parameters()))

# fit the toy batch until the loss stops changing
aloss = 1e9
df = 1e9
while df > 1e-7:
    optimizer.zero_grad()  # clear gradients from the previous step before backward
    output = oo(cc(input, adj)).view(BATCH_SIZE * SEQ_LEN, CLASSN)
    loss = F.cross_entropy(output, label)
    df = abs(aloss - loss.item())
    aloss = loss.item()
    loss.backward()
    optimizer.step()
log(aloss)
log(F.softmax(output, dim=1))  # output is (BATCH_SIZE * SEQ_LEN, CLASSN), so softmax over the class dim
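
# ---------------------------------------------------------------------------
# The adjacency driving GraphConvolution above is dense with shape
# [batch, edge_types, seq_len, seq_len]: one 0/1 matrix per edge type per
# example. A minimal sketch of building such a tensor from
# (edge_type, head, dependent) triples, e.g. from a dependency parse; the
# helper is illustrative, not part of this repo.
# ---------------------------------------------------------------------------
import torch


def build_typed_adj(edges, batch_size, edge_types, seq_len, device="cpu"):
    """edges: one list per example of (edge_type, head_idx, dep_idx) triples."""
    adj = torch.zeros(batch_size, edge_types, seq_len, seq_len, device=device)
    for b, triples in enumerate(edges):
        for et, h, d in triples:
            adj[b, et, h, d] = 1.0
    return adj

# two edge types over a 3-token sentence:
# build_typed_adj([[(0, 0, 1), (1, 1, 2)]], batch_size=1, edge_types=2, seq_len=3)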