Пример #1
0
    def run(self):
        """Build the event-extraction datasets and vocabularies, then train.

        All settings come from ``self.a`` (the parsed command-line options).
        The method:

        * builds the input fields and loads train/dev/test sets from
          ``self.a.train_ee``, ``self.a.dev_ee`` and ``self.a.test_ee``;
        * builds the WORDS/LABEL/EVENT/POSTAGS/ENTITYLABELS vocabularies from
          train + dev (attaching ``self.a.webd`` vectors to WORDS, LABEL and
          EVENT when it is set);
        * loads extra dev/test views (``keep_events=1, only_keep=True``) that
          are handed to ``ee_train`` as "dev 1/1" and "test 1/1";
        * fills missing entries of ``self.a.hps``, creates or fine-tunes the
          model, and pickles the vocabularies plus ``ee_hyps.json`` into
          ``self.a.out``;
        * calls ``ee_train``.

        Side effects: seeds numpy/torch, sets ``consts.O_LABEL`` and
        ``consts.ROLE_O_LABEL``, creates ``self.a.out`` if needed, and stores
        ``label_weight``, ``arg_weight``, ``role_mask``, ``hps``,
        ``word_i2s``, ``label_i2s``, ``role_i2s`` and ``writer`` on ``self.a``.
        """
        # '%d' would truncate a fractional learning rate (0.001 printed as 0);
        # '%s' prints the value as given.
        print("Running on", self.a.device, ', optimizer=', self.a.optimizer,
              ', lr=%s' % self.a.lr)
        self.set_device(self.a.device)

        np.random.seed(self.a.seed)
        torch.manual_seed(self.a.seed)
        torch.backends.cudnn.benchmark = True
        # torch.backends.cudnn.deterministic = False

        # create training set
        if self.a.train_ee:
            log('loading event extraction corpus from %s' % self.a.train_ee)

        WordsField = Field(lower=True, include_lengths=True, batch_first=True)
        PosTagsField = Field(lower=True, batch_first=True)
        EntityLabelsField = MultiTokenField(lower=False, batch_first=True)
        AdjMatrixField = SparseField(sequential=False,
                                     use_vocab=False,
                                     batch_first=True)
        LabelField = Field(lower=False,
                           batch_first=True,
                           pad_token='0',
                           unk_token=None)
        EventsField = EventField(lower=False, batch_first=True)
        EntitiesField = EntityField(lower=False,
                                    batch_first=True,
                                    use_vocab=False)

        # Column holding the adjacency (parse) structure in the corpus files.
        colcc = 'amr-colcc' if self.a.amr else 'stanford-colcc'
        print(colcc)

        def ee_fields():
            """Return a fresh corpus-column -> (attribute, field) mapping."""
            return {
                "words": ("WORDS", WordsField),
                "pos-tags": ("POSTAGS", PosTagsField),
                "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField),
                colcc: ("ADJM", AdjMatrixField),
                "golden-event-mentions": ("LABEL", LabelField),
                "all-events": ("EVENT", EventsField),
                "all-entities": ("ENTITIES", EntitiesField),
            }

        def load_ee_set(path, keep_events, **extra):
            """Load one ACE2005Dataset; ``extra`` forwards e.g. only_keep."""
            # A new fields dict per dataset, as the separate literals did
            # before, in case the dataset mutates the mapping it receives.
            return ACE2005Dataset(path=path,
                                  fields=ee_fields(),
                                  amr=self.a.amr,
                                  keep_events=keep_events,
                                  **extra)

        train_ee_set = load_ee_set(self.a.train_ee, keep_events=1)
        dev_ee_set = load_ee_set(self.a.dev_ee, keep_events=0)
        test_ee_set = load_ee_set(self.a.test_ee, keep_events=0)

        # Vocabularies are built from train + dev only; pretrained vectors are
        # attached to WORDS/LABEL/EVENT when a vector file is given.
        vector_kwargs = {}
        if self.a.webd:
            vector_kwargs["vectors"] = Vectors(self.a.webd,
                                               ".",
                                               unk_init=partial(
                                                   torch.nn.init.uniform_,
                                                   a=-0.15,
                                                   b=0.15))
        WordsField.build_vocab(train_ee_set.WORDS, dev_ee_set.WORDS,
                               **vector_kwargs)
        LabelField.build_vocab(train_ee_set.LABEL, dev_ee_set.LABEL,
                               **vector_kwargs)
        EventsField.build_vocab(train_ee_set.EVENT, dev_ee_set.EVENT,
                                **vector_kwargs)
        PosTagsField.build_vocab(train_ee_set.POSTAGS, dev_ee_set.POSTAGS)
        EntityLabelsField.build_vocab(train_ee_set.ENTITYLABELS,
                                      dev_ee_set.ENTITYLABELS)

        # Record the vocabulary indices of the O_LABEL_NAME / ROLE_O_LABEL_NAME
        # entries in the shared ``consts`` module.
        consts.O_LABEL = LabelField.vocab.stoi[consts.O_LABEL_NAME]
        consts.ROLE_O_LABEL = EventsField.vocab.stoi[consts.ROLE_O_LABEL_NAME]

        # Extra evaluation views, loaded after the vocabularies are built.
        dev_ee_set1 = load_ee_set(self.a.dev_ee, keep_events=1, only_keep=True)
        test_ee_set1 = load_ee_set(self.a.test_ee, keep_events=1,
                                   only_keep=True)
        print("train set length", len(train_ee_set))

        print("dev set length", len(dev_ee_set))
        print("dev set 1/1 length", len(dev_ee_set1))

        print("test set length", len(test_ee_set))
        print("test set 1/1 length", len(test_ee_set1))

        # Every trigger label / argument role gets weight 5, except the O
        # label / O role, which gets weight 1.
        self.a.label_weight = torch.ones([len(LabelField.vocab.itos)]) * 5
        self.a.label_weight[consts.O_LABEL] = 1.0
        self.a.arg_weight = torch.ones([len(EventsField.vocab.itos)]) * 5
        self.a.arg_weight[consts.ROLE_O_LABEL] = 1.0
        # add role mask
        # self.a.role_mask = event_role_mask(self.a.test_ee, self.a.train_ee, self.a.dev_ee, LabelField.vocab.stoi, EventsField.vocab.stoi, self.device)
        self.a.role_mask = None

        # NOTE(review): eval() executes arbitrary code taken from the hps
        # option; if it is always a dict literal, ast.literal_eval is safer.
        self.a.hps = eval(self.a.hps)
        # Fill in sizes that were not given explicitly (computed lazily, so
        # e.g. longest() only runs when psemb_size is missing).
        if "wemb_size" not in self.a.hps:
            self.a.hps["wemb_size"] = len(WordsField.vocab.itos)
        if "pemb_size" not in self.a.hps:
            self.a.hps["pemb_size"] = len(PosTagsField.vocab.itos)
        if "psemb_size" not in self.a.hps:
            self.a.hps["psemb_size"] = max([
                train_ee_set.longest(),
                dev_ee_set.longest(),
                test_ee_set.longest()
            ]) + 2
        if "eemb_size" not in self.a.hps:
            self.a.hps["eemb_size"] = len(EntityLabelsField.vocab.itos)
        if "oc" not in self.a.hps:
            self.a.hps["oc"] = len(LabelField.vocab.itos)
        if "ae_oc" not in self.a.hps:
            self.a.hps["ae_oc"] = len(EventsField.vocab.itos)

        tester = self.get_tester(LabelField.vocab.itos, EventsField.vocab.itos)

        if self.a.finetune:
            log('init model from ' + self.a.finetune)
            model = load_ee_model(self.a.hps, self.a.finetune,
                                  WordsField.vocab.vectors, self.device)
            origin = 'model loaded'
        else:
            model = load_ee_model(self.a.hps, None, WordsField.vocab.vectors,
                                  self.device)
            origin = 'model created from scratch'
        log('%s, there are %i sets of params' %
            (origin, len(model.parameters_requires_grads())))
        log(model.parameters_requires_grads())

        # Adadelta / Adam by name; anything else falls back to SGD + momentum.
        named_optimizers = {
            "adadelta": torch.optim.Adadelta,
            "adam": torch.optim.Adam,
        }
        if self.a.optimizer in named_optimizers:
            optimizer_constructor = partial(
                named_optimizers[self.a.optimizer],
                params=model.parameters_requires_grads(),
                weight_decay=self.a.l2decay)
        else:
            optimizer_constructor = partial(
                torch.optim.SGD,
                params=model.parameters_requires_grads(),
                weight_decay=self.a.l2decay,
                momentum=0.9)

        log('optimizer in use: %s' % str(self.a.optimizer))

        # makedirs(exist_ok=True) also creates missing parents and avoids the
        # race between an exists() check and mkdir().
        os.makedirs(self.a.out, exist_ok=True)
        pickled_outputs = (
            ("word.vec", WordsField.vocab),
            ("pos.vec", PosTagsField.vocab.stoi),
            ("entity.vec", EntityLabelsField.vocab.stoi),
            ("label_s2i.vec", LabelField.vocab.stoi),
            ("role_s2i.vec", EventsField.vocab.stoi),
            ("label_i2s.vec", LabelField.vocab.itos),
            ("role_i2s.vec", EventsField.vocab.itos),
        )
        for file_name, payload in pickled_outputs:
            with open(os.path.join(self.a.out, file_name), "wb") as f:
                pickle.dump(payload, f)
        with open(os.path.join(self.a.out, "ee_hyps.json"), "w") as f:
            json.dump(self.a.hps, f)

        log('init complete\n')

        self.a.word_i2s = WordsField.vocab.itos
        self.a.label_i2s = LabelField.vocab.itos
        self.a.role_i2s = EventsField.vocab.itos
        self.a.writer = SummaryWriter(os.path.join(self.a.out, "exp"))

        ee_train(model=model,
                 train_set=train_ee_set,
                 dev_set=dev_ee_set,
                 test_set=test_ee_set,
                 optimizer_constructor=optimizer_constructor,
                 epochs=self.a.epochs,
                 tester=tester,
                 parser=self.a,
                 other_testsets={
                     "dev 1/1": dev_ee_set1,
                     "test 1/1": test_ee_set1,
                 },
                 role_mask=self.a.role_mask)
        log('Done!')
Пример #2
0
    def run(self):
        """Evaluate a situation-recognition model on ``self.a.test_sr``.

        All settings come from ``self.a``. The method builds the LABEL and
        EVENT vocabularies from the text corpus ``self.a.train_ee`` (with
        ``self.a.webd`` vectors), loads the situation noun/role/verb
        vocabularies from ``self.a.vocab`` and the embedding matrices
        ``self.a.wnebd`` / ``wvebd`` / ``wrebd``, and fills missing entries of
        ``self.a.hps``. It then creates or loads the model with
        ``load_sr_model``, runs ``run_over_data_sr`` once over the test set
        (no optimizer, ``need_backward=False``), and prints the returned
        metrics. ``save_output`` is pointed at
        ``<self.a.out>/test_final.txt``.

        Side effects: seeds numpy/torch, creates ``self.a.out`` if needed, and
        stores ``hps``, ``word_i2s``, ``acelabel_i2s``, ``acerole_i2s``,
        ``label_i2s``, ``role_i2s`` and ``writer`` on ``self.a``.
        """
        print("Running on", self.a.device)
        self.set_device(self.a.device)

        np.random.seed(self.a.seed)
        torch.manual_seed(self.a.seed)
        torch.backends.cudnn.benchmark = True

        # Text-side fields. Only the LABEL and EVENT vocabularies are built
        # below; the other fields are registered with the dataset only.
        WordsField = Field(lower=True, include_lengths=True, batch_first=True)
        PosTagsField = Field(lower=True, batch_first=True)
        EntityLabelsField = MultiTokenField(lower=False, batch_first=True)
        AdjMatrixField = SparseField(sequential=False,
                                     use_vocab=False,
                                     batch_first=True)
        EntitiesField = EntityField(lower=False,
                                    batch_first=True,
                                    use_vocab=False)
        LabelField = Field(lower=False,
                           batch_first=True,
                           pad_token='0',
                           unk_token=None)
        EventsField = EventField(lower=False, batch_first=True)
        SENTIDField = SparseField(sequential=False,
                                  use_vocab=False,
                                  batch_first=True)
        colcc = 'combined-parsing'
        train_ee_set = ACE2005Dataset(
            path=self.a.train_ee,
            fields={
                "sentence_id": ("SENTID", SENTIDField),
                "words": ("WORDS", WordsField),
                "pos-tags": ("POSTAGS", PosTagsField),
                "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField),
                colcc: ("ADJM", AdjMatrixField),
                "golden-event-mentions": ("LABEL", LabelField),
                "all-events": ("EVENT", EventsField),
                "all-entities": ("ENTITIES", EntitiesField)
            },
            amr=False,
            keep_events=1)
        # NOTE(review): self.a.webd is used unconditionally here (no None
        # check), so a vector file must always be supplied.
        pretrained_embedding = Vectors(self.a.webd,
                                       ".",
                                       unk_init=partial(torch.nn.init.uniform_,
                                                        a=-0.15,
                                                        b=0.15))
        LabelField.build_vocab(train_ee_set.LABEL,
                               vectors=pretrained_embedding)
        EventsField.build_vocab(train_ee_set.EVENT,
                                vectors=pretrained_embedding)

        # create testing set (train/dev SR sets are not loaded here)
        if self.a.test_sr:
            log('loading corpus from %s' % self.a.test_sr)

        # NOTE(review): RandomHorizontalFlip / RandomCrop are random
        # augmentations applied to the *test* images, so evaluation inputs
        # depend on RNG state. A Resize + CenterCrop pipeline is the usual
        # deterministic choice; confirm before changing reported numbers.
        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(224),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])

        def load_situation_vocab(kind):
            """Load ``vocab_situation_<kind>.pkl`` from ``self.a.vocab``."""
            return Vocab(os.path.join(self.a.vocab,
                                      'vocab_situation_%s.pkl' % kind),
                         load=True)

        vocab_noun = load_situation_vocab('noun')
        vocab_role = load_situation_vocab('role')
        vocab_verb = load_situation_vocab('verb')

        test_sr_set = ImSituDataset(
            self.a.image_dir,
            vocab_noun,
            vocab_role,
            vocab_verb,
            EventsField.vocab.stoi,
            LabelField.vocab.stoi,
            self.a.imsitu_ontology_file,
            self.a.test_sr,
            self.a.verb_mapping_file,
            self.a.object_class_map_file,
            self.a.object_detection_pkl_file,
            self.a.object_detection_threshold,
            transform,
            filter_irrelevant_verbs=self.a.filter_irrelevant_verbs,
            load_object=self.a.add_object,
            filter_place=self.a.filter_place)

        def load_embedding_matrix(path):
            """Load a numpy array from ``path`` as a float tensor on device."""
            return torch.FloatTensor(np.load(path)).to(self.device)

        embeddingMatrix_noun = load_embedding_matrix(self.a.wnebd)
        embeddingMatrix_verb = load_embedding_matrix(self.a.wvebd)
        embeddingMatrix_role = load_embedding_matrix(self.a.wrebd)

        if not self.a.hps_path:
            # NOTE(review): eval() executes arbitrary code taken from the hps
            # option; if it is always a dict literal, ast.literal_eval is
            # safer.
            self.a.hps = eval(self.a.hps)
        # With the text ontology the verb/role tables are sized from the text
        # vocabularies (LABEL/EVENT); otherwise from the situation vocabularies.
        if self.a.textontology:
            verb_table_size = len(LabelField.vocab.stoi)
            role_table_size = len(EventsField.vocab.itos)
        else:
            verb_table_size = len(vocab_verb.id2word)
            role_table_size = len(vocab_role.id2word)
        hps_defaults = {
            "wvemb_size": verb_table_size,
            "wremb_size": role_table_size,
            "wnemb_size": len(vocab_noun.id2word),
            "oc": len(LabelField.vocab.itos),
            "ae_oc": len(EventsField.vocab.itos),
        }
        for key, value in hps_defaults.items():
            if key not in self.a.hps:
                self.a.hps[key] = value

        tester = self.get_tester()

        # Verb/role embeddings come from the same source used to size them.
        if self.a.textontology:
            verb_embeddings = LabelField.vocab.vectors
            role_embeddings = EventsField.vocab.vectors
        else:
            verb_embeddings = embeddingMatrix_verb
            role_embeddings = embeddingMatrix_role

        if self.a.finetune:
            log('init model from ' + self.a.finetune)
            model = load_sr_model(self.a.hps,
                                  embeddingMatrix_noun,
                                  verb_embeddings,
                                  role_embeddings,
                                  self.a.finetune,
                                  self.device,
                                  add_object=self.a.add_object)
            log('sr model loaded, there are %i sets of params' %
                len(model.parameters_requires_grads()))
        else:
            model = load_sr_model(self.a.hps,
                                  embeddingMatrix_noun,
                                  verb_embeddings,
                                  role_embeddings,
                                  None,
                                  self.device,
                                  add_object=self.a.add_object)
            log('sr model created from scratch, there are %i sets of params'
                % len(model.parameters_requires_grads()))

        log('init complete\n')

        # makedirs(exist_ok=True) also creates missing parents and avoids the
        # race between an exists() check and mkdir().
        os.makedirs(self.a.out, exist_ok=True)

        self.a.word_i2s = vocab_noun.id2word
        self.a.acelabel_i2s = LabelField.vocab.itos
        self.a.acerole_i2s = EventsField.vocab.itos
        self.a.label_i2s = vocab_verb.id2word
        self.a.role_i2s = vocab_role.id2word
        self.a.writer = SummaryWriter(os.path.join(self.a.out, "exp"))

        test_iter = torch.utils.data.DataLoader(dataset=test_sr_set,
                                                batch_size=self.a.batch,
                                                shuffle=False,
                                                num_workers=2,
                                                collate_fn=image_collate_fn)

        verb_roles = test_sr_set.get_verb_role_mapping()

        visualize_path = (self.a.visualize_path
                          if 'visualize_path' in self.a else None)

        (test_loss, test_verb_p, test_verb_r, test_verb_f1,
         test_role_p, test_role_r, test_role_f1,
         test_noun_p, test_noun_r, test_noun_f1,
         test_triple_p, test_triple_r, test_triple_f1,
         test_noun_p_relaxed, test_noun_r_relaxed, test_noun_f1_relaxed,
         test_triple_p_relaxed, test_triple_r_relaxed,
         test_triple_f1_relaxed) = run_over_data_sr(
             data_iter=test_iter,
             optimizer=None,
             model=model,
             need_backward=False,
             MAX_STEP=ceil(len(test_sr_set) / self.a.batch),
             tester=tester,
             hyps=model.hyperparams,
             device=model.device,
             maxnorm=self.a.maxnorm,
             word_i2s=self.a.word_i2s,
             label_i2s=self.a.label_i2s,
             role_i2s=self.a.role_i2s,
             verb_roles=verb_roles,
             load_object=self.a.add_object,
             visualize_path=visualize_path,
             save_output=os.path.join(self.a.out, "test_final.txt"))
        print("\nFinally test loss: ", test_loss, "\ntest verb p: ",
              test_verb_p, " test verb r: ", test_verb_r, " test verb f1: ",
              test_verb_f1, "\ntest role p: ", test_role_p, " test role r: ",
              test_role_r, " test role f1: ", test_role_f1, "\ntest noun p: ",
              test_noun_p, " test noun r: ", test_noun_r, " test noun f1: ",
              test_noun_f1, "\ntest triple p: ", test_triple_p,
              " test triple r: ", test_triple_r, " test triple f1: ",
              test_triple_f1, "\ntest noun p relaxed: ", test_noun_p_relaxed,
              " test noun r relaxed: ", test_noun_r_relaxed,
              " test noun f1 relaxed: ", test_noun_f1_relaxed,
              "\ntest triple p relaxed: ", test_triple_p_relaxed,
              " test triple r relaxed: ", test_triple_r_relaxed,
              " test triple f1 relaxed: ", test_triple_f1_relaxed)
Пример #3
0
    def run(self):
        print("Running on", self.a.device)
        self.set_device(self.a.device)

        np.random.seed(self.a.seed)
        torch.manual_seed(self.a.seed)
        torch.backends.cudnn.benchmark = True

        # create training set
        if self.a.test_ee:
            log('loading event extraction corpus from %s' % self.a.test_ee)

        WordsField = Field(lower=True, include_lengths=True, batch_first=True)
        PosTagsField = Field(lower=True, batch_first=True)
        EntityLabelsField = MultiTokenField(lower=False, batch_first=True)
        AdjMatrixField = SparseField(sequential=False,
                                     use_vocab=False,
                                     batch_first=True)
        LabelField = Field(lower=False,
                           batch_first=True,
                           pad_token='0',
                           unk_token=None)
        EventsField = EventField(lower=False, batch_first=True)
        EntitiesField = EntityField(lower=False,
                                    batch_first=True,
                                    use_vocab=False)
        SENTIDField = SparseField(sequential=False,
                                  use_vocab=False,
                                  batch_first=True)
        if self.a.amr:
            colcc = 'simple-parsing'
        else:
            colcc = 'combined-parsing'
        print(colcc)

        train_ee_set = ACE2005Dataset(
            path=self.a.train_ee,
            fields={
                "sentence_id": ("SENTID", SENTIDField),
                "words": ("WORDS", WordsField),
                "pos-tags": ("POSTAGS", PosTagsField),
                "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField),
                colcc: ("ADJM", AdjMatrixField),
                "golden-event-mentions": ("LABEL", LabelField),
                "all-events": ("EVENT", EventsField),
                "all-entities": ("ENTITIES", EntitiesField)
            },
            amr=self.a.amr,
            keep_events=1)

        dev_ee_set = ACE2005Dataset(path=self.a.dev_ee,
                                    fields={
                                        "sentence_id": ("SENTID", SENTIDField),
                                        "words": ("WORDS", WordsField),
                                        "pos-tags": ("POSTAGS", PosTagsField),
                                        "golden-entity-mentions":
                                        ("ENTITYLABELS", EntityLabelsField),
                                        colcc: ("ADJM", AdjMatrixField),
                                        "golden-event-mentions":
                                        ("LABEL", LabelField),
                                        "all-events": ("EVENT", EventsField),
                                        "all-entities":
                                        ("ENTITIES", EntitiesField)
                                    },
                                    amr=self.a.amr,
                                    keep_events=0)

        test_ee_set = ACE2005Dataset(path=self.a.test_ee,
                                     fields={
                                         "sentence_id":
                                         ("SENTID", SENTIDField),
                                         "words": ("WORDS", WordsField),
                                         "pos-tags": ("POSTAGS", PosTagsField),
                                         "golden-entity-mentions":
                                         ("ENTITYLABELS", EntityLabelsField),
                                         colcc: ("ADJM", AdjMatrixField),
                                         "golden-event-mentions":
                                         ("LABEL", LabelField),
                                         "all-events": ("EVENT", EventsField),
                                         "all-entities":
                                         ("ENTITIES", EntitiesField)
                                     },
                                     amr=self.a.amr,
                                     keep_events=0)

        if self.a.load_grounding:
            ####################    loading grounding dataset   ####################
            if self.a.train_grounding:
                log('loading grounding corpus from %s' %
                    self.a.train_grounding)

            # only for grounding
            IMAGEIDField = SparseField(sequential=False,
                                       use_vocab=False,
                                       batch_first=True)
            # IMAGEField = SparseField(sequential=False, use_vocab=False, batch_first=True)

            transform = transforms.Compose([
                transforms.Resize(256),
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(224),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406),
                                     (0.229, 0.224, 0.225))
            ])

            train_grounding_set = GroundingDataset(
                path=self.a.train_grounding,
                img_dir=self.a.img_dir_grounding,
                fields={
                    "id": ("IMAGEID", IMAGEIDField),
                    "sentence_id": ("SENTID", SENTIDField),
                    "words": ("WORDS", WordsField),
                    "pos-tags": ("POSTAGS", PosTagsField),
                    "golden-entity-mentions":
                    ("ENTITYLABELS", EntityLabelsField),
                    colcc: ("ADJM", AdjMatrixField),
                    "all-entities": ("ENTITIES", EntitiesField),
                    # "image": ("IMAGE", IMAGEField),
                },
                transform=transform,
                amr=self.a.amr,
                load_object=self.a.add_object,
                object_ontology_file=self.a.object_class_map_file,
                object_detection_pkl_file=self.a.object_detection_pkl_file_g,
                object_detection_threshold=self.a.object_detection_threshold,
            )

            dev_grounding_set = GroundingDataset(
                path=self.a.dev_grounding,
                img_dir=self.a.img_dir_grounding,
                fields={
                    "id": ("IMAGEID", IMAGEIDField),
                    "sentence_id": ("SENTID", SENTIDField),
                    "words": ("WORDS", WordsField),
                    "pos-tags": ("POSTAGS", PosTagsField),
                    "golden-entity-mentions":
                    ("ENTITYLABELS", EntityLabelsField),
                    colcc: ("ADJM", AdjMatrixField),
                    "all-entities": ("ENTITIES", EntitiesField),
                    # "image": ("IMAGE", IMAGEField),
                },
                transform=transform,
                amr=self.a.amr,
                load_object=self.a.add_object,
                object_ontology_file=self.a.object_class_map_file,
                object_detection_pkl_file=self.a.object_detection_pkl_file_g,
                object_detection_threshold=self.a.object_detection_threshold,
            )

            # test_grounding_set = GroundingDataset(path=self.a.test_grounding,
            #                                       img_dir=self.a.img_dir_grounding,
            #                                       fields={"id": ("IMAGEID", IMAGEIDField),
            #                                               "sentence_id": ("SENTID", SENTIDField),
            #                                               "words": ("WORDS", WordsField),
            #                                               "pos-tags": ("POSTAGS", PosTagsField),
            #                                               "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField),
            #                                               colcc: ("ADJM", AdjMatrixField),
            #                                               "all-entities": ("ENTITIES", EntitiesField),
            #                                               # "image": ("IMAGE", IMAGEField),
            #                                               },
            #                                       transform=transform,
            #                                       amr=self.a.amr,
            #                                       load_object=self.a.add_object,
            #                                       object_ontology_file=self.a.object_class_map_file,
            #                                       object_detection_pkl_file=self.a.object_detection_pkl_file_g,
            #                                       object_detection_threshold=self.a.object_detection_threshold,
            #                                       )

            ####################    build vocabulary   ####################

            if self.a.webd:
                pretrained_embedding = Vectors(self.a.webd,
                                               ".",
                                               unk_init=partial(
                                                   torch.nn.init.uniform_,
                                                   a=-0.15,
                                                   b=0.15))
                WordsField.build_vocab(train_ee_set.WORDS,
                                       dev_ee_set.WORDS,
                                       train_grounding_set.WORDS,
                                       dev_grounding_set.WORDS,
                                       vectors=pretrained_embedding)
            else:
                WordsField.build_vocab(train_ee_set.WORDS, dev_ee_set.WORDS,
                                       train_grounding_set.WORDS,
                                       dev_grounding_set.WORDS)
            PosTagsField.build_vocab(train_ee_set.POSTAGS, dev_ee_set.POSTAGS,
                                     train_grounding_set.POSTAGS,
                                     dev_grounding_set.POSTAGS)
            EntityLabelsField.build_vocab(train_ee_set.ENTITYLABELS,
                                          dev_ee_set.ENTITYLABELS,
                                          train_grounding_set.ENTITYLABELS,
                                          dev_grounding_set.ENTITYLABELS)
        else:
            if self.a.webd:
                pretrained_embedding = Vectors(self.a.webd,
                                               ".",
                                               unk_init=partial(
                                                   torch.nn.init.uniform_,
                                                   a=-0.15,
                                                   b=0.15))
                WordsField.build_vocab(train_ee_set.WORDS,
                                       dev_ee_set.WORDS,
                                       vectors=pretrained_embedding)
            else:
                WordsField.build_vocab(train_ee_set.WORDS, dev_ee_set.WORDS)
            PosTagsField.build_vocab(train_ee_set.POSTAGS, dev_ee_set.POSTAGS)
            EntityLabelsField.build_vocab(train_ee_set.ENTITYLABELS,
                                          dev_ee_set.ENTITYLABELS)

        LabelField.build_vocab(train_ee_set.LABEL, dev_ee_set.LABEL)
        EventsField.build_vocab(train_ee_set.EVENT, dev_ee_set.EVENT)
        consts.O_LABEL = LabelField.vocab.stoi[consts.O_LABEL_NAME]
        # print("O label is", consts.O_LABEL)
        consts.ROLE_O_LABEL = EventsField.vocab.stoi[consts.ROLE_O_LABEL_NAME]
        # print("O label for AE is", consts.ROLE_O_LABEL)

        self.a.label_weight = torch.ones([len(LabelField.vocab.itos)]) * 5
        self.a.label_weight[consts.O_LABEL] = 1.0
        self.a.arg_weight = torch.ones([len(EventsField.vocab.itos)]) * 5
        # add role mask
        self.a.role_mask = event_role_mask(self.a.train_ee, self.a.dev_ee,
                                           LabelField.vocab.stoi,
                                           EventsField.vocab.stoi, self.device)
        # print('self.a.hps', self.a.hps)
        if not self.a.hps_path:
            self.a.hps = eval(self.a.hps)
        if "wemb_size" not in self.a.hps:
            self.a.hps["wemb_size"] = len(WordsField.vocab.itos)
        if "pemb_size" not in self.a.hps:
            self.a.hps["pemb_size"] = len(PosTagsField.vocab.itos)
        if "psemb_size" not in self.a.hps:
            self.a.hps["psemb_size"] = max([
                train_ee_set.longest(),
                dev_ee_set.longest(),
                test_ee_set.longest()
            ]) + 2
        if "eemb_size" not in self.a.hps:
            self.a.hps["eemb_size"] = len(EntityLabelsField.vocab.itos)
        if "oc" not in self.a.hps:
            self.a.hps["oc"] = len(LabelField.vocab.itos)
        if "ae_oc" not in self.a.hps:
            self.a.hps["ae_oc"] = len(EventsField.vocab.itos)

        tester = self.get_tester(LabelField.vocab.itos, EventsField.vocab.itos)
        visualizer = EDVisualizer(self.a.gt_voa_text)

        if self.a.finetune:
            log('init model from ' + self.a.finetune)
            model = load_ee_model(self.a.hps, self.a.finetune,
                                  WordsField.vocab.vectors, self.device)
            log('model loaded, there are %i sets of params' %
                len(model.parameters_requires_grads()))
        else:
            model = load_ee_model(self.a.hps, None, WordsField.vocab.vectors,
                                  self.device)
            log('model created from scratch, there are %i sets of params' %
                len(model.parameters_requires_grads()))

        self.a.word_i2s = WordsField.vocab.itos
        self.a.label_i2s = LabelField.vocab.itos
        self.a.role_i2s = EventsField.vocab.itos
        writer = SummaryWriter(os.path.join(self.a.out, "exp"))
        self.a.writer = writer

        # train_iter = BucketIterator(train_ee_set, batch_size=self.a.batch,
        #                             train=True, shuffle=False, device=-1,
        #                             sort_key=lambda x: len(x.POSTAGS))
        # dev_iter = BucketIterator(dev_ee_set, batch_size=self.a.batch, train=False,
        #                           shuffle=False, device=-1,
        #                           sort_key=lambda x: len(x.POSTAGS))
        # test_iter = BucketIterator(test_ee_set, batch_size=self.a.batch, train=False,
        #                            shuffle=False, device=-1,
        #                            sort_key=lambda x: len(x.POSTAGS))
        test_m2e2_set = ACE2005Dataset(
            path=self.a.gt_voa_text,
            fields={
                "sentence_id": ("SENTID", SENTIDField),
                "words": ("WORDS", WordsField),
                "pos-tags": ("POSTAGS", PosTagsField),
                "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField),
                colcc: ("ADJM", AdjMatrixField),
                "golden-event-mentions": ("LABEL", LabelField),
                "all-events": ("EVENT", EventsField),
                "all-entities": ("ENTITIES", EntitiesField)
            },
            amr=self.a.amr,
            keep_events=self.a.keep_events)
        # test_m2e2_set = M2E2Dataset(path=gt_voa_text,
        #                             img_dir=voa_image_dir,
        #                             fields={"image": ("IMAGEID", IMAGEIDField),
        #                                     "sentence_id": ("SENTID", SENTIDField),
        #                                     "words": ("WORDS", WordsField),
        #                                     "pos-tags": ("POSTAGS", PosTagsField),
        #                                     "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField),
        #                                     colcc: ("ADJM", AdjMatrixField),
        #                                     "all-entities": ("ENTITIES", EntitiesField),
        #                                     # "image": ("IMAGE", IMAGEField),
        #                                     "golden-event-mentions": ("LABEL", LabelField),
        #                                     "all-events": ("EVENT", EventsField),
        #                                     },
        #                             transform=transform,
        #                             amr=self.a.amr,
        #                             load_object=self.a.add_object,
        #                             object_ontology_file=self.a.object_class_map_file,
        #                             object_detection_pkl_file=self.a.object_detection_pkl_file,
        #                             object_detection_threshold=self.a.object_detection_threshold,
        #                             )
        test_m2e2_iter = BucketIterator(test_m2e2_set,
                                        batch_size=1,
                                        train=False,
                                        shuffle=False,
                                        device=-1,
                                        sort_key=lambda x: len(x.POSTAGS))

        print("\nStarting testing ...\n")

        # Testing Phrase
        test_loss, test_ed_p, test_ed_r, test_ed_f1, \
        test_ae_p, test_ae_r, test_ae_f1 = run_over_data(data_iter=test_m2e2_iter,
                                                         optimizer=None,
                                                         model=model,
                                                         need_backward=False,
                                                         MAX_STEP=len(test_m2e2_iter),
                                                         tester=tester,
                                                         visualizer=visualizer,
                                                         hyps=model.hyperparams,
                                                         device=model.device,
                                                         maxnorm=self.a.maxnorm,
                                                         word_i2s=self.a.word_i2s,
                                                         label_i2s=self.a.label_i2s,
                                                         role_i2s=self.a.role_i2s,
                                                         weight=self.a.label_weight,
                                                         arg_weight=self.a.arg_weight,
                                                         save_output=os.path.join(
                                                             self.a.out,
                                                             "test_final.txt"),
                                                         role_mask=self.a.role_mask)

        print("\nFinally test loss: ", test_loss, "\ntest ed p: ", test_ed_p,
              " test ed r: ", test_ed_r, " test ed f1: ", test_ed_f1,
              "\ntest ae p: ", test_ae_p, " test ae r: ", test_ae_r,
              " test ae f1: ", test_ae_f1)
Пример #4
0
    def run(self):
        """Train the situation-recognition (SR) model on imSitu data.

        Steps, in order:
          1. Build the ACE event-type (``LabelField``) and argument-role
             (``EventsField``) vocabularies from ``self.a.train_ee`` and attach
             the pretrained vectors loaded from ``self.a.webd``.
          2. Load the imSitu train/dev/test splits, which receive those ACE
             ``stoi`` mappings.
          3. Load the noun/verb/role embedding matrices (``np.load`` of
             ``self.a.wnebd`` / ``wvebd`` / ``wrebd``) onto ``self.device``.
          4. Fill in missing size entries of ``self.a.hps`` and build the SR
             model, from ``self.a.finetune`` when it is set, otherwise from
             scratch.
          5. Write the ACE vocabularies (pickles) and ``sr_hyps.json`` into
             ``self.a.out``, then hand everything to ``sr_train``.
        """
        print("Running on", self.a.device)
        self.set_device(self.a.device)

        np.random.seed(self.a.seed)
        torch.manual_seed(self.a.seed)
        torch.backends.cudnn.benchmark = True

        # ---- text event vocab and ee_role vocab (from the ACE train split) ----
        WordsField = Field(lower=True, include_lengths=True, batch_first=True)
        PosTagsField = Field(lower=True, batch_first=True)
        EntityLabelsField = MultiTokenField(lower=False, batch_first=True)
        AdjMatrixField = SparseField(sequential=False,
                                     use_vocab=False,
                                     batch_first=True)
        EntitiesField = EntityField(lower=False,
                                    batch_first=True,
                                    use_vocab=False)
        # only for ee
        LabelField = Field(lower=False,
                           batch_first=True,
                           pad_token='0',
                           unk_token=None)
        EventsField = EventField(lower=False, batch_first=True)
        colcc = 'stanford-colcc'
        train_ee_set = ACE2005Dataset(path=self.a.train_ee,
                                      fields={
                                          "words": ("WORDS", WordsField),
                                          "pos-tags":
                                          ("POSTAGS", PosTagsField),
                                          "golden-entity-mentions":
                                          ("ENTITYLABELS", EntityLabelsField),
                                          colcc: ("ADJM", AdjMatrixField),
                                          "golden-event-mentions":
                                          ("LABEL", LabelField),
                                          "all-events": ("EVENT", EventsField),
                                          "all-entities":
                                          ("ENTITIES", EntitiesField)
                                      },
                                      amr=False,
                                      keep_events=1)
        pretrained_embedding = Vectors(self.a.webd,
                                       ".",
                                       unk_init=partial(torch.nn.init.uniform_,
                                                        a=-0.15,
                                                        b=0.15))
        LabelField.build_vocab(train_ee_set.LABEL,
                               vectors=pretrained_embedding)
        EventsField.build_vocab(train_ee_set.EVENT,
                                vectors=pretrained_embedding)
        # NOTE(review): consts.O_LABEL / consts.ROLE_O_LABEL are not set in this
        # runner -- confirm the SR training path does not depend on them.

        # ---- imSitu situation-recognition splits ----
        if self.a.train_sr:
            log('loading corpus from %s' % self.a.train_sr)

        def make_sr_set(split_path):
            """Build one imSitu split; all splits share vocabularies and options."""
            return ImSituDataset(
                self.a.image_dir,
                self.vocab_noun,
                self.vocab_role,
                self.vocab_verb,
                LabelField.vocab.stoi,
                EventsField.vocab.stoi,
                self.a.imsitu_ontology_file,
                split_path,
                self.a.verb_mapping_file,
                self.a.object_class_map_file,
                self.a.object_detection_pkl_file,
                self.a.object_detection_threshold,
                self.transform,
                filter_irrelevant_verbs=self.a.filter_irrelevant_verbs,
                load_object=self.a.add_object,
                filter_place=self.a.filter_place)

        train_sr_set = make_sr_set(self.a.train_sr)
        dev_sr_set = make_sr_set(self.a.dev_sr)
        test_sr_set = make_sr_set(self.a.test_sr)

        def load_embedding(npy_path):
            """Load a saved numpy matrix as a float tensor on ``self.device``."""
            return torch.FloatTensor(np.load(npy_path)).to(self.device)

        embeddingMatrix_noun = load_embedding(self.a.wnebd)
        embeddingMatrix_verb = load_embedding(self.a.wvebd)
        embeddingMatrix_role = load_embedding(self.a.wrebd)

        # ---- hyper-parameters ----
        # SECURITY(review): eval() executes arbitrary code contained in the hps
        # string. If hps is always a dict literal, ast.literal_eval is a safe
        # drop-in -- confirm before switching.
        self.a.hps = eval(self.a.hps)

        if self.a.textontology:
            # Verb/role inputs reuse the ACE vocabularies and their pretrained
            # vectors. Sizes use itos so they match the rows of vocab.vectors.
            verb_vocab_size = len(LabelField.vocab.itos)
            role_vocab_size = len(EventsField.vocab.itos)
            verb_embeddings = LabelField.vocab.vectors
            role_embeddings = EventsField.vocab.vectors
        else:
            verb_vocab_size = len(self.vocab_verb.id2word)
            role_vocab_size = len(self.vocab_role.id2word)
            verb_embeddings = embeddingMatrix_verb
            role_embeddings = embeddingMatrix_role

        # Insertion order matters: it determines key order in sr_hyps.json.
        size_defaults = (
            ("wvemb_size", verb_vocab_size),
            ("wremb_size", role_vocab_size),
            ("wnemb_size", len(self.vocab_noun.id2word)),
            ("oc", len(LabelField.vocab.itos)),
            ("ae_oc", len(EventsField.vocab.itos)),
        )
        for key, value in size_defaults:
            # Only fill entries the caller did not set explicitly.
            if key not in self.a.hps:
                self.a.hps[key] = value

        tester = self.get_tester()

        # ---- model ----
        # A falsy finetune value (None / "") means "train from scratch".
        finetune_path = self.a.finetune or None
        if finetune_path:
            log('init model from ' + finetune_path)
        model = load_sr_model(self.a.hps,
                              embeddingMatrix_noun,
                              verb_embeddings,
                              role_embeddings,
                              finetune_path,
                              self.device,
                              add_object=self.a.add_object)
        if finetune_path:
            log('sr model loaded, there are %i sets of params' %
                len(model.parameters_requires_grads()))
        else:
            log('sr model created from scratch, there are %i sets of params'
                % len(model.parameters_requires_grads()))

        # ---- optimizer (constructed later by sr_train) ----
        trainable_params = model.parameters_requires_grads()
        if self.a.optimizer == "adadelta":
            optimizer_constructor = partial(torch.optim.Adadelta,
                                            params=trainable_params,
                                            weight_decay=self.a.l2decay)
        elif self.a.optimizer == "adam":
            optimizer_constructor = partial(torch.optim.Adam,
                                            params=trainable_params,
                                            weight_decay=self.a.l2decay)
        else:
            # Any other optimizer name falls back to SGD with momentum.
            optimizer_constructor = partial(torch.optim.SGD,
                                            params=trainable_params,
                                            weight_decay=self.a.l2decay,
                                            momentum=0.9)

        log('optimizer in use: %s' % str(self.a.optimizer))

        log('init complete\n')

        # makedirs (unlike mkdir) also creates missing parent directories.
        os.makedirs(self.a.out, exist_ok=True)

        # ---- persist vocabularies for later evaluation / inference ----
        self.a.word_i2s = self.vocab_noun.id2word
        self.a.acelabel_i2s = LabelField.vocab.itos
        self.a.acerole_i2s = EventsField.vocab.itos
        vocab_dumps = (
            ("label_s2i.vec", LabelField.vocab.stoi),
            ("role_s2i.vec", EventsField.vocab.stoi),
            ("label_i2s.vec", LabelField.vocab.itos),
            ("role_i2s.vec", EventsField.vocab.itos),
        )
        for file_name, vocab_obj in vocab_dumps:
            with open(os.path.join(self.a.out, file_name), "wb") as f:
                pickle.dump(vocab_obj, f)
        # Label/role index->string maps used during SR training are the imSitu
        # verb/role vocabularies, not the ACE ones dumped above.
        self.a.label_i2s = self.vocab_verb.id2word
        self.a.role_i2s = self.vocab_role.id2word

        writer = SummaryWriter(os.path.join(self.a.out, "exp"))
        self.a.writer = writer

        with open(os.path.join(self.a.out, "sr_hyps.json"), "w") as f:
            json.dump(self.a.hps, f)

        sr_train(
            model=model,
            train_set=train_sr_set,
            dev_set=dev_sr_set,
            test_set=test_sr_set,
            optimizer_constructor=optimizer_constructor,
            epochs=self.a.epochs,
            tester=tester,
            parser=self.a,
            other_testsets={})
        log('Done!')
Пример #5
0
    def run(self):
        print("Running on", self.a.device)
        self.set_device(self.a.device)

        np.random.seed(self.a.seed)
        torch.manual_seed(self.a.seed)
        torch.backends.cudnn.benchmark = True

        ####################    loading event extraction dataset   ####################
        if self.a.test_ee:
            log('testing event extraction corpus from %s' % self.a.test_ee)
        if self.a.test_ee:
            log('testing event extraction corpus from %s' % self.a.test_ee)

        # both for grounding and ee
        WordsField = Field(lower=True, include_lengths=True, batch_first=True)
        PosTagsField = Field(lower=True, batch_first=True)
        EntityLabelsField = MultiTokenField(lower=False, batch_first=True)
        AdjMatrixField = SparseField(sequential=False, use_vocab=False, batch_first=True)
        EntitiesField = EntityField(lower=False, batch_first=True, use_vocab=False)
        # only for ee
        LabelField = Field(lower=False, batch_first=True, pad_token='0', unk_token=None)
        EventsField = EventField(lower=False, batch_first=True)
        SENTIDField = SparseField(sequential=False, use_vocab=False, batch_first=True)

        if self.a.amr:
            colcc = 'simple-parsing'
        else:
            colcc = 'combined-parsing'
        print(colcc)

        train_ee_set = ACE2005Dataset(path=self.a.train_ee,
                                   fields={"sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField),
                                           "pos-tags": ("POSTAGS", PosTagsField),
                                           "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField),
                                           colcc: ("ADJM", AdjMatrixField),
                                           "golden-event-mentions": ("LABEL", LabelField),
                                           "all-events": ("EVENT", EventsField),
                                           "all-entities": ("ENTITIES", EntitiesField)},
                                   amr=self.a.amr, keep_events=1)

        dev_ee_set = ACE2005Dataset(path=self.a.dev_ee,
                                 fields={"sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField),
                                         "pos-tags": ("POSTAGS", PosTagsField),
                                         "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField),
                                         colcc: ("ADJM", AdjMatrixField),
                                         "golden-event-mentions": ("LABEL", LabelField),
                                         "all-events": ("EVENT", EventsField),
                                         "all-entities": ("ENTITIES", EntitiesField)},
                                 amr=self.a.amr, keep_events=0)

        # test_ee_set = ACE2005Dataset(path=self.a.test_ee,
        #                           fields={"sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField),
        #                                   "pos-tags": ("POSTAGS", PosTagsField),
        #                                   "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField),
        #                                   colcc: ("ADJM", AdjMatrixField),
        #                                   "golden-event-mentions": ("LABEL", LabelField),
        #                                   "all-events": ("EVENT", EventsField),
        #                                   "all-entities": ("ENTITIES", EntitiesField)},
        #                           amr=self.a.amr, keep_events=0)

        print('self.a.train_ee', self.a.train_ee)
        LabelField.build_vocab(train_ee_set.LABEL, dev_ee_set.LABEL)
        print('LabelField.vocab.stoi', LabelField.vocab.stoi)
        EventsField.build_vocab(train_ee_set.EVENT, dev_ee_set.EVENT)
        print('EventsField.vocab.stoi', EventsField.vocab.stoi)
        print('len(EventsField.vocab.itos)', len(EventsField.vocab.itos))
        print('len(EventsField.vocab.stoi)', len(EventsField.vocab.stoi))

        ####################    loading SR dataset   ####################
        # both for grounding and sr
        if self.a.train_sr:
            log('loading corpus from %s' % self.a.train_sr)

        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(224),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406),
                                 (0.229, 0.224, 0.225))])

        vocab_noun = Vocab(os.path.join(self.a.vocab, 'vocab_situation_noun.pkl'), load=True)
        vocab_role = Vocab(os.path.join(self.a.vocab, 'vocab_situation_role.pkl'), load=True)
        vocab_verb = Vocab(os.path.join(self.a.vocab, 'vocab_situation_verb.pkl'), load=True)

        # only need get_role_mask() and sr_mapping()
        train_sr_set = ImSituDataset(self.a.image_dir, vocab_noun, vocab_role, vocab_verb,
                                     EventsField.vocab.stoi, LabelField.vocab.stoi,
                                     self.a.imsitu_ontology_file,
                                     self.a.train_sr, self.a.verb_mapping_file,
                                     None, None,
                                     0,
                                     transform, filter_irrelevant_verbs=self.a.filter_irrelevant_verbs,
                                     load_object=False, filter_place=self.a.filter_place)


        ####################    loading grounding dataset   ####################
        if self.a.train_grounding:
            log('loading grounding corpus from %s' % self.a.train_grounding)

        # only for grounding
        IMAGEIDField = SparseField(sequential=False, use_vocab=False, batch_first=True)
        SENTIDField = SparseField(sequential=False, use_vocab=False, batch_first=True)
        # IMAGEField = SparseField(sequential=False, use_vocab=False, batch_first=True)

        train_grounding_set = GroundingDataset(path=self.a.train_grounding,
                                               img_dir=None,
                                               fields={"id": ("IMAGEID", IMAGEIDField),
                                                       "sentence_id": ("SENTID", SENTIDField),
                                                       "words": ("WORDS", WordsField),
                                                       "pos-tags": ("POSTAGS", PosTagsField),
                                                       "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField),
                                                       colcc: ("ADJM", AdjMatrixField),
                                                       "all-entities": ("ENTITIES", EntitiesField),
                                                       # "image": ("IMAGE", IMAGEField),
                                                       },
                                               transform=transform,
                                               amr=self.a.amr)

        dev_grounding_set = GroundingDataset(path=self.a.dev_grounding,
                                             img_dir=None,
                                             fields={"id": ("IMAGEID", IMAGEIDField),
                                                     "sentence_id": ("SENTID", SENTIDField),
                                                     "words": ("WORDS", WordsField),
                                                     "pos-tags": ("POSTAGS", PosTagsField),
                                                     "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField),
                                                     colcc: ("ADJM", AdjMatrixField),
                                                     "all-entities": ("ENTITIES", EntitiesField),
                                                     # "image": ("IMAGE", IMAGEField),
                                                     },
                                             transform=transform,
                                             amr=self.a.amr)

        # test_grounding_set = GroundingDataset(path=self.a.test_grounding,
        #                                       img_dir=None,
        #                                       fields={"id": ("IMAGEID", IMAGEIDField),
        #                                               "sentence_id": ("SENTID", SENTIDField),
        #                                               "words": ("WORDS", WordsField),
        #                                               "pos-tags": ("POSTAGS", PosTagsField),
        #                                               "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField),
        #                                               colcc: ("ADJM", AdjMatrixField),
        #                                               "all-entities": ("ENTITIES", EntitiesField),
        #                                               # "image": ("IMAGE", IMAGEField),
        #                                               },
        #                                       transform=transform,
        #                                       amr=self.a.amr)

        ####################    build vocabulary   ####################

        if self.a.webd:
            pretrained_embedding = Vectors(self.a.webd, ".", unk_init=partial(torch.nn.init.uniform_, a=-0.15, b=0.15))
            WordsField.build_vocab(train_ee_set.WORDS, dev_ee_set.WORDS, train_grounding_set.WORDS, dev_grounding_set.WORDS, vectors=pretrained_embedding)
        else:
            WordsField.build_vocab(train_ee_set.WORDS, dev_ee_set.WORDS, train_grounding_set.WORDS, dev_grounding_set.WORDS)
        PosTagsField.build_vocab(train_ee_set.POSTAGS, dev_ee_set.POSTAGS, train_grounding_set.POSTAGS, dev_grounding_set.POSTAGS)
        EntityLabelsField.build_vocab(train_ee_set.ENTITYLABELS, dev_ee_set.ENTITYLABELS,  train_grounding_set.ENTITYLABELS, dev_grounding_set.ENTITYLABELS)

        consts.O_LABEL = LabelField.vocab.stoi[consts.O_LABEL_NAME]
        # print("O label is", consts.O_LABEL)
        consts.ROLE_O_LABEL = EventsField.vocab.stoi[consts.ROLE_O_LABEL_NAME]
        # print("O label for AE is", consts.ROLE_O_LABEL)

        # dev_ee_set1 = ACE2005Dataset(path=self.a.dev_ee,
        #                           fields={"sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField),
        #                                   "pos-tags": ("POSTAGS", PosTagsField),
        #                                   "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField),
        #                                   colcc: ("ADJM", AdjMatrixField),
        #                                   "golden-event-mentions": ("LABEL", LabelField),
        #                                   "all-events": ("EVENT", EventsField),
        #                                   "all-entities": ("ENTITIES", EntitiesField)},
        #                           amr=self.a.amr, keep_events=1, only_keep=True)
        #
        # test_ee_set1 = ACE2005Dataset(path=self.a.test_ee,
        #                            fields={"sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField),
        #                                    "pos-tags": ("POSTAGS", PosTagsField),
        #                                    "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField),
        #                                    colcc: ("ADJM", AdjMatrixField),
        #                                    "golden-event-mentions": ("LABEL", LabelField),
        #                                    "all-events": ("EVENT", EventsField),
        #                                    "all-entities": ("ENTITIES", EntitiesField)},
        #                            amr=self.a.amr, keep_events=1, only_keep=True)
        # print("train set length", len(train_ee_set))
        #
        # print("dev set length", len(dev_ee_set))
        # print("dev set 1/1 length", len(dev_ee_set1))
        #
        # print("test set length", len(test_ee_set))
        # print("test set 1/1 length", len(test_ee_set1))

        # sr model initialization
        if not self.a.sr_hps_path:
            self.a.sr_hps = eval(self.a.sr_hps)
        embeddingMatrix_noun = torch.FloatTensor(np.load(self.a.wnebd)).to(self.device)
        embeddingMatrix_verb = torch.FloatTensor(np.load(self.a.wvebd)).to(self.device)
        embeddingMatrix_role = torch.FloatTensor(np.load(self.a.wrebd)).to(self.device)
        if "wvemb_size" not in self.a.sr_hps:
            self.a.sr_hps["wvemb_size"] = len(vocab_verb.id2word)
        if "wremb_size" not in self.a.sr_hps:
            self.a.sr_hps["wremb_size"] = len(vocab_role.id2word)
        if "wnemb_size" not in self.a.sr_hps:
            self.a.sr_hps["wnemb_size"] = len(vocab_noun.id2word)
        # if "ae_oc" not in self.a.sr_hps:
        #     self.a.sr_hps["ae_oc"] = len(vocab_role.id2word)

        # self.a.ee_label_weight = torch.ones([len(LabelField.vocab.itos)]) * 5
        # self.a.ee_label_weight[consts.O_LABEL] = 1.0
        # self.a.ee_arg_weight = torch.ones([len(EventsField.vocab.itos)]) * 5
        # if not self.a.ee_hps_path:
        #     self.a.ee_hps = eval(self.a.ee_hps)
        # if "wemb_size" not in self.a.ee_hps:
        #     self.a.ee_hps["wemb_size"] = len(WordsField.vocab.itos)
        # if "pemb_size" not in self.a.ee_hps:
        #     self.a.ee_hps["pemb_size"] = len(PosTagsField.vocab.itos)
        # if "psemb_size" not in self.a.ee_hps:
        #     # self.a.ee_hps["psemb_size"] = max([train_grounding_set.longest(), dev_grounding_set.longest(), test_grounding_set.longest()]) + 2
        #     self.a.ee_hps["psemb_size"] = max([train_ee_set.longest(), dev_ee_set.longest(), test_ee_set.longest(), train_grounding_set.longest(), dev_grounding_set.longest(), test_grounding_set.longest()]) + 2
        # if "eemb_size" not in self.a.ee_hps:
        #     self.a.ee_hps["eemb_size"] = len(EntityLabelsField.vocab.itos)
        # if "oc" not in self.a.ee_hps:
        #     self.a.ee_hps["oc"] = len(LabelField.vocab.itos)
        # if "ae_oc" not in self.a.ee_hps:
        #     self.a.ee_hps["ae_oc"] = len(EventsField.vocab.itos)
        if "oc" not in self.a.sr_hps:
            self.a.sr_hps["oc"] = len(LabelField.vocab.itos)
        if "ae_oc" not in self.a.sr_hps:
            self.a.sr_hps["ae_oc"] = len(EventsField.vocab.itos)



        ace_classifier = ACEClassifier(self.a.sr_hps["wemb_dim"], self.a.sr_hps["oc"], self.a.sr_hps["ae_oc"],
                                       self.device)

        ee_model = None
        # # if self.a.score_ee:
        # if  self.a.finetune_ee:
        #     log('init ee model from ' + self.a.finetune_ee)
        #     ee_model = load_ee_model(self.a.ee_hps, self.a.finetune_ee, WordsField.vocab.vectors, self.device, ace_classifier)
        #     log('ee model loaded, there are %i sets of params' % len(ee_model.parameters_requires_grads()))
        # else:
        #     ee_model = load_ee_model(self.a.ee_hps, None, WordsField.vocab.vectors, self.device, ace_classifier)
        #     log('ee model created from scratch, there are %i sets of params' % len(ee_model.parameters_requires_grads()))

        # if self.a.score_sr:
        if self.a.finetune_sr:
            log('init sr model from ' + self.a.finetune_sr)
            sr_model = load_sr_model(self.a.sr_hps, embeddingMatrix_noun, embeddingMatrix_verb, embeddingMatrix_role, self.a.finetune_sr, self.device, ace_classifier, add_object=self.a.add_object)
            log('sr model loaded, there are %i sets of params' % len(sr_model.parameters_requires_grads()))
        else:
            sr_model = load_sr_model(self.a.sr_hps, embeddingMatrix_noun, embeddingMatrix_verb, embeddingMatrix_role, None, self.device, ace_classifier, add_object=self.a.add_object)
            log('sr model created from scratch, there are %i sets of params' % len(sr_model.parameters_requires_grads()))

        model = GroundingModel(ee_model, sr_model, self.get_device())
        # ee_model = torch.nn.DataParallel(ee_model)
        # sr_model = torch.nn.DataParallel(sr_model)
        # model = torch.nn.DataParallel(model)

        # if self.a.optimizer == "adadelta":
        #     optimizer_constructor = partial(torch.optim.Adadelta, params=model.parameters_requires_grads(),
        #                                     weight_decay=self.a.l2decay)
        # elif self.a.optimizer == "adam":
        #     optimizer_constructor = partial(torch.optim.Adam, params=model.parameters_requires_grads(),
        #                                     weight_decay=self.a.l2decay)
        # else:
        #     optimizer_constructor = partial(torch.optim.SGD, params=model.parameters_requires_grads(),
        #                                     weight_decay=self.a.l2decay,
        #                                     momentum=0.9)

        # log('optimizer in use: %s' % str(self.a.optimizer))

        if not os.path.exists(self.a.out):
            os.mkdir(self.a.out)
        # with open(os.path.join(self.a.out, "word.vec"), "wb") as f:
        #     pickle.dump(WordsField.vocab, f)
        # with open(os.path.join(self.a.out, "pos.vec"), "wb") as f:
        #     pickle.dump(PosTagsField.vocab.stoi, f)
        # with open(os.path.join(self.a.out, "entity.vec"), "wb") as f:
        #     pickle.dump(EntityLabelsField.vocab.stoi, f)
        with open(os.path.join(self.a.out, "label.vec"), "wb") as f:
            pickle.dump(LabelField.vocab.stoi, f)
        with open(os.path.join(self.a.out, "role.vec"), "wb") as f:
            pickle.dump(EventsField.vocab.stoi, f)

        log('init complete\n')

        # # ee mappings
        # self.a.ee_word_i2s = WordsField.vocab.itos
        # self.a.ee_label_i2s = LabelField.vocab.itos
        # self.a.ee_role_i2s = EventsField.vocab.itos
        # self.a.ee_role_mask = None
        # if self.a.apply_ee_role_mask:
        #     self.a.ee_role_mask = event_role_mask(self.a.train_ee, self.a.dev_ee, LabelField.vocab.stoi, EventsField.vocab.stoi, self.device)
        # sr mappings
        self.a.sr_word_i2s = vocab_noun.id2word
        self.a.sr_label_i2s = vocab_verb.id2word  # LabelField.vocab.itos
        self.a.sr_role_i2s = vocab_role.id2word
        self.a.role_masks = train_sr_set.get_role_mask().to_dense().to(self.device)
        writer = SummaryWriter(os.path.join(self.a.out, "exp"))
        self.a.writer = writer

        # loading testing data
        # voa_text = self.a.test_voa_text
        voa_image_dir = self.a.test_voa_image
        gt_voa_image = self.a.gt_voa_image
        gt_voa_text = self.a.gt_voa_text
        gt_voa_align =self.a.gt_voa_align

        sr_verb_mapping, sr_role_mapping = train_sr_set.get_sr_mapping()

        test_m2e2_set = M2E2Dataset(path=gt_voa_text,
                                    img_dir=voa_image_dir,
                                    fields={"image": ("IMAGEID", IMAGEIDField),
                                          "sentence_id": ("SENTID", SENTIDField),
                                          "words": ("WORDS", WordsField),
                                          "pos-tags": ("POSTAGS", PosTagsField),
                                          "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField),
                                          colcc: ("ADJM", AdjMatrixField),
                                          "all-entities": ("ENTITIES", EntitiesField),
                                          # "image": ("IMAGE", IMAGEField),
                                          "golden-event-mentions": ("LABEL", LabelField),
                                          "all-events": ("EVENT", EventsField),
                                          },
                                    transform=transform,
                                    amr=self.a.amr,
                                    load_object=self.a.add_object,
                                    object_ontology_file=self.a.object_class_map_file,
                                    object_detection_pkl_file=self.a.object_detection_pkl_file,
                                    object_detection_threshold=self.a.object_detection_threshold,
                                    keep_events=self.a.keep_events,
                                    )

        object_results, object_label, object_detection_threshold = test_m2e2_set.get_object_results()

        # build batch on cpu
        test_m2e2_iter = BucketIterator(test_m2e2_set, batch_size=1, train=False,
                                   shuffle=False, device=-1,
                                   sort_key=lambda x: len(x.POSTAGS))

        # scores = 0.0
        # now_bad = 0
        # restart_used = 0
        print("\nStarting testing...\n")
        # lr = parser.lr
        # optimizer = optimizer_constructor(lr=lr)

        # ee_tester = EDTester(LabelField.vocab.itos, EventsField.vocab.itos, self.a.ignore_time_test)
        # sr_tester = SRTester()
        # g_tester = GroundingTester()
        j_tester = JointTester(self.a.ignore_place_sr_test, self.a.ignore_time_test)
        # if self.a.visual_voa_ee_path is not None:
        #     ee_visualizer = EDVisualizer(self.a.gt_voa_text)
        # else:
        #     ee_visualizer = None
        image_gt = json.load(open(gt_voa_image))

        # all_y = []
        # all_y_ = []
        # all_events = []
        # all_events_ = []

        vision_result = dict()
        # if self.a.visual_voa_g_path is not None and not os.path.exists(self.a.visual_voa_g_path):
        #     os.makedirs(self.a.visual_voa_g_path, exist_ok=True)
        # if self.a.visual_voa_ee_path is not None and not os.path.exists(self.a.visual_voa_ee_path):
        #     os.makedirs(self.a.visual_voa_ee_path, exist_ok=True)
        if self.a.visual_voa_sr_path is not None and not os.path.exists(self.a.visual_voa_sr_path):
            os.makedirs(self.a.visual_voa_sr_path, exist_ok=True)
        # grounding_writer = open(self.a.visual_voa_g_path, 'w')
        doc_done = set()
        with torch.no_grad():
            model.eval()
            for batch in test_m2e2_iter:
                vision_result = joint_test_batch(
                    model_g=model,
                    batch_g=batch,
                    device=self.device,
                    transform=transform,
                    img_dir=voa_image_dir,
                    # ee_hyps=self.a.ee_hps,
                    # ee_word_i2s=self.a.ee_word_i2s,
                    # ee_label_i2s=self.a.ee_label_i2s,
                    # ee_role_i2s=self.a.ee_role_i2s,
                    # ee_tester=ee_tester,
                    # ee_visualizer=ee_visualizer,
                    sr_noun_i2s=self.a.sr_word_i2s,
                    sr_verb_i2s=self.a.sr_label_i2s,
                    sr_role_i2s=self.a.sr_role_i2s,
                    # sr_tester=sr_tester,
                    role_masks=self.a.role_masks,
                    # ee_role_mask=self.a.ee_role_mask,
                    # j_tester=j_tester,
                    image_gt=image_gt,
                    verb2type=sr_verb_mapping,
                    role2role=sr_role_mapping,
                    vision_result=vision_result,
                    # all_y=all_y,
                    # all_y_=all_y_,
                    # all_events=all_events,
                    # all_events_=all_events_,
                    # visual_g_path=self.a.visual_voa_g_path,
                    # visual_ee_path=self.a.visual_voa_ee_path,
                    load_object=self.a.add_object,
                    object_results=object_results,
                    object_label=object_label,
                    object_detection_threshold=object_detection_threshold,
                    vocab_objlabel=vocab_noun.word2id,
                    # apply_ee_role_mask=self.a.apply_ee_role_mask
                    keep_events_sr=self.a.keep_events_sr,
                    doc_done=doc_done,
                )

        print('vision_result size', len(vision_result))
        # pickle.dump(vision_result, open(os.path.join(self.a.out, 'vision_result.pkl'), 'w'))

        # ep, er, ef = ee_tester.calculate_report(all_y, all_y_, transform=True)
        # ap, ar, af = ee_tester.calculate_sets(all_events, all_events_)
        # if self.a.visual_voa_ee_path is not None:
        #     ee_visualizer.rewrite_brat(self.a.visual_voa_ee_path, self.a.visual_voa_ee_gt_ann)
        #
        # print('text ep, er, ef', ep, er, ef)
        # print('text ap, ar, af', ap, ar, af)

        evt_p, evt_r, evt_f1, role_scores = j_tester.calculate_report(
            vision_result, voa_image_dir, self.a.visual_voa_sr_path, self.a.add_object,
            keep_events_sr=self.a.keep_events_sr
        )#consts.O_LABEL, consts.ROLE_O_LABEL)

        print('image event ep, er, ef \n', evt_p, '\n', evt_r, '\n', evt_f1)
        # if not self.a.add_object:
        #     print('image att_iou ap, ar, af', role_scores['role_att_iou_p'], role_scores['role_att_iou_r'],
        #           role_scores['role_att_iou_f1'])
        #     print('image att_hit ap, ar, af', role_scores['role_att_hit_p'], role_scores['role_att_hit_r'],
        #           role_scores['role_att_hit_f1'])
        #     print('image att_cor ap, ar, af', role_scores['role_att_cor_p'], role_scores['role_att_cor_r'],
        #           role_scores['role_att_cor_f1'])
        # else:
        #     print('image obj_iou ap, ar, af', role_scores['role_obj_iou_p'], role_scores['role_obj_iou_r'],
        #           role_scores['role_obj_iou_f1'])
        #     print('image obj_iou_union ap, ar, af', role_scores['role_obj_iou_union_p'], role_scores['role_obj_iou_union_r'],
        #           role_scores['role_obj_iou_union_f1'])
        for key in role_scores:
            print(key)
            for key_ in role_scores[key]:
                print(key_, role_scores[key][key_])
Пример #6
0
    def run(self):
        """Evaluate an event-extraction model on the ACE2005-style test corpus.

        Steps:
          * Load the train, dev and test corpora through one shared set of
            field objects.
          * Build the vocabularies from the train and dev splits only.
          * Fill in any hyper-parameters missing from ``self.a.hps`` using the
            vocabulary sizes.
          * Load the model from ``self.a.finetune`` when it is set; otherwise
            create a fresh model.
          * Run a single pass over the test split (no optimizer, no backward).
          * Print the loss plus the ``ed`` and ``ae`` precision, recall and F1
            returned by ``run_over_data``.

        Side effects:
          * Sets ``consts.O_LABEL`` and ``consts.ROLE_O_LABEL``.
          * Stores ``label_weight``, ``arg_weight``, ``role_mask``,
            ``word_i2s``, ``label_i2s``, ``role_i2s`` and a ``SummaryWriter``
            on ``self.a``.
          * Passes ``<self.a.out>/test_final.txt`` as ``save_output`` to
            ``run_over_data``.
        """
        print("Running on", self.a.device)
        self.set_device(self.a.device)
        args = self.a

        # Seed the numpy and torch RNGs; cudnn autotuning is left enabled.
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.backends.cudnn.benchmark = True

        if args.test_ee:
            log('loading event extraction corpus from %s' % args.test_ee)

        # One field object per column, shared by all three splits so that the
        # vocabularies built below apply to train, dev and test alike.
        words_field = Field(lower=True, include_lengths=True, batch_first=True)
        postags_field = Field(lower=True, batch_first=True)
        entity_labels_field = MultiTokenField(lower=False, batch_first=True)
        adjm_field = SparseField(sequential=False, use_vocab=False, batch_first=True)
        label_field = Field(lower=False, batch_first=True, pad_token='0', unk_token=None)
        events_field = EventField(lower=False, batch_first=True)
        entities_field = EntityField(lower=False, batch_first=True, use_vocab=False)

        # Corpus key that feeds the ADJM field; selected by ``self.a.amr``.
        colcc = 'amr-colcc' if args.amr else 'stanford-colcc'
        print(colcc)

        def corpus(path, keep_events):
            # Build a fresh mapping for every dataset:
            # corpus key -> (example attribute name, field).
            column_fields = {
                "words": ("WORDS", words_field),
                "pos-tags": ("POSTAGS", postags_field),
                "golden-entity-mentions": ("ENTITYLABELS", entity_labels_field),
                colcc: ("ADJM", adjm_field),
                "golden-event-mentions": ("LABEL", label_field),
                "all-events": ("EVENT", events_field),
                "all-entities": ("ENTITIES", entities_field),
            }
            return ACE2005Dataset(path=path, fields=column_fields,
                                  amr=args.amr, keep_events=keep_events)

        train_ee_set = corpus(args.train_ee, 1)
        dev_ee_set = corpus(args.dev_ee, 0)
        test_ee_set = corpus(args.test_ee, 0)

        # Vocabularies are built from the train and dev splits only.
        # Pretrained vectors are attached to the word vocabulary when
        # ``self.a.webd`` is set.
        word_vocab_kwargs = {}
        if args.webd:
            word_vocab_kwargs["vectors"] = Vectors(
                args.webd, ".",
                unk_init=partial(torch.nn.init.uniform_, a=-0.15, b=0.15))
        words_field.build_vocab(train_ee_set.WORDS, dev_ee_set.WORDS, **word_vocab_kwargs)
        for field, column in ((postags_field, "POSTAGS"),
                              (entity_labels_field, "ENTITYLABELS"),
                              (label_field, "LABEL"),
                              (events_field, "EVENT")):
            field.build_vocab(getattr(train_ee_set, column), getattr(dev_ee_set, column))

        # Publish the indices of the "O" trigger label and the "O" role
        # through the shared consts module.
        consts.O_LABEL = label_field.vocab.stoi[consts.O_LABEL_NAME]
        consts.ROLE_O_LABEL = events_field.vocab.stoi[consts.ROLE_O_LABEL_NAME]

        # Class weights: every trigger label gets 5 except "O" (1.0).
        # Every role gets 5; there is no "O" override for roles.
        args.label_weight = torch.ones([len(label_field.vocab.itos)]) * 5
        args.label_weight[consts.O_LABEL] = 1.0
        args.arg_weight = torch.ones([len(events_field.vocab.itos)]) * 5

        # NOTE(review): this call passes the test, train and dev paths (three
        # corpora); confirm it matches event_role_mask's signature.
        args.role_mask = event_role_mask(args.test_ee, args.train_ee, args.dev_ee,
                                         label_field.vocab.stoi,
                                         events_field.vocab.stoi, self.device)

        if not args.hps_path:
            # SECURITY NOTE(review): eval() executes the raw ``hps`` string.
            # This is acceptable only for trusted input; ast.literal_eval would
            # be safer if the value is always a plain dict literal.
            args.hps = eval(args.hps)

        # Fill in hyper-parameters the caller did not provide. Each value is
        # computed lazily, and only when its key is missing.
        hps_defaults = (
            ("wemb_size", lambda: len(words_field.vocab.itos)),
            ("pemb_size", lambda: len(postags_field.vocab.itos)),
            ("psemb_size", lambda: max([train_ee_set.longest(),
                                        dev_ee_set.longest(),
                                        test_ee_set.longest()]) + 2),
            ("eemb_size", lambda: len(entity_labels_field.vocab.itos)),
            ("oc", lambda: len(label_field.vocab.itos)),
            ("ae_oc", lambda: len(events_field.vocab.itos)),
        )
        for hp_name, compute in hps_defaults:
            if hp_name not in args.hps:
                args.hps[hp_name] = compute()

        tester = self.get_tester(label_field.vocab.itos, events_field.vocab.itos)

        # A falsy ``finetune`` means "no checkpoint": the model is created
        # from scratch.
        checkpoint = args.finetune or None
        if checkpoint:
            log('init model from ' + checkpoint)
        model = load_ee_model(args.hps, checkpoint, words_field.vocab.vectors, self.device)
        param_groups = len(model.parameters_requires_grads())
        if checkpoint:
            log('model loaded, there are %i sets of params' % param_groups)
        else:
            log('model created from scratch, there are %i sets of params' % param_groups)

        args.word_i2s = words_field.vocab.itos
        args.label_i2s = label_field.vocab.itos
        args.role_i2s = events_field.vocab.itos
        args.writer = SummaryWriter(os.path.join(args.out, "exp"))

        # Batches are assembled on the CPU (device=-1), in a fixed order.
        test_iter = BucketIterator(test_ee_set, batch_size=args.batch, train=False,
                                   shuffle=False, device=-1,
                                   sort_key=lambda ex: len(ex.POSTAGS))

        print("\nStarting testing ...\n")

        # Testing phase: a single evaluation pass over the test split.
        (test_loss, test_ed_p, test_ed_r, test_ed_f1,
         test_ae_p, test_ae_r, test_ae_f1) = run_over_data(
            data_iter=test_iter,
            optimizer=None,
            model=model,
            need_backward=False,
            MAX_STEP=ceil(len(test_ee_set) / args.batch),
            tester=tester,
            hyps=model.hyperparams,
            device=model.device,
            maxnorm=args.maxnorm,
            word_i2s=args.word_i2s,
            label_i2s=args.label_i2s,
            role_i2s=args.role_i2s,
            weight=args.label_weight,
            arg_weight=args.arg_weight,
            save_output=os.path.join(args.out, "test_final.txt"),
            role_mask=args.role_mask,
        )

        print("\nFinally test loss: ", test_loss,
              "\ntest ed p: ", test_ed_p,
              " test ed r: ", test_ed_r,
              " test ed f1: ", test_ed_f1,
              "\ntest ae p: ", test_ae_p,
              " test ae r: ", test_ae_r,
              " test ae f1: ", test_ae_f1)
Пример #7
0
    def run(self):
        print("Running on", self.a.device)
        self.set_device(self.a.device)

        np.random.seed(self.a.seed)
        torch.manual_seed(self.a.seed)
        torch.backends.cudnn.benchmark = True

        ####################    loading event extraction dataset   ####################
        if self.a.train_ee:
            log('loading event extraction corpus from %s' % self.a.train_ee)

        # both for grounding and ee
        WordsField = Field(lower=True, include_lengths=True, batch_first=True)
        PosTagsField = Field(lower=True, batch_first=True)
        EntityLabelsField = MultiTokenField(lower=False, batch_first=True)
        AdjMatrixField = SparseField(sequential=False, use_vocab=False, batch_first=True)
        EntitiesField = EntityField(lower=False, batch_first=True, use_vocab=False)
        # only for ee
        LabelField = Field(lower=False, batch_first=True, pad_token='0', unk_token=None)
        EventsField = EventField(lower=False, batch_first=True)
        SENTIDField = SparseField(sequential=False, use_vocab=False, batch_first=True)

        if self.a.amr:
            colcc = 'simple-parsing'
        else:
            colcc = 'combined-parsing'
        print(colcc)

        train_ee_set = ACE2005Dataset(path=self.a.train_ee,
                                   fields={"sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField),
                                           "pos-tags": ("POSTAGS", PosTagsField),
                                           "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField),
                                           colcc: ("ADJM", AdjMatrixField),
                                           "golden-event-mentions": ("LABEL", LabelField),
                                           "all-events": ("EVENT", EventsField),
                                           "all-entities": ("ENTITIES", EntitiesField)},
                                   amr=self.a.amr, keep_events=1)

        dev_ee_set = ACE2005Dataset(path=self.a.dev_ee,
                                 fields={"sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField),
                                         "pos-tags": ("POSTAGS", PosTagsField),
                                         "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField),
                                         colcc: ("ADJM", AdjMatrixField),
                                         "golden-event-mentions": ("LABEL", LabelField),
                                         "all-events": ("EVENT", EventsField),
                                         "all-entities": ("ENTITIES", EntitiesField)},
                                 amr=self.a.amr, keep_events=0)

        test_ee_set = ACE2005Dataset(path=self.a.test_ee,
                                  fields={"sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField),
                                          "pos-tags": ("POSTAGS", PosTagsField),
                                          "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField),
                                          colcc: ("ADJM", AdjMatrixField),
                                          "golden-event-mentions": ("LABEL", LabelField),
                                          "all-events": ("EVENT", EventsField),
                                          "all-entities": ("ENTITIES", EntitiesField)},
                                  amr=self.a.amr, keep_events=0)

        if self.a.webd:
            pretrained_embedding = Vectors(self.a.webd, ".", unk_init=partial(torch.nn.init.uniform_, a=-0.15, b=0.15))
            LabelField.build_vocab(train_ee_set.LABEL, dev_ee_set.LABEL, vectors=pretrained_embedding)
            EventsField.build_vocab(train_ee_set.EVENT, dev_ee_set.EVENT, vectors=pretrained_embedding)
        else:
            LabelField.build_vocab(train_ee_set.LABEL, dev_ee_set.LABEL)
            EventsField.build_vocab(train_ee_set.EVENT, dev_ee_set.EVENT)

        # add role mask
        self.a.role_mask = event_role_mask(self.a.train_ee, self.a.dev_ee, LabelField.vocab.stoi,
                                           EventsField.vocab.stoi, self.device)

        ####################    loading SR dataset   ####################
        # both for grounding and sr
        if self.a.train_sr:
            log('loading corpus from %s' % self.a.train_sr)

        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(224),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406),
                                 (0.229, 0.224, 0.225))])

        vocab_noun = Vocab(os.path.join(self.a.vocab, 'vocab_situation_noun.pkl'), load=True)
        vocab_role = Vocab(os.path.join(self.a.vocab, 'vocab_situation_role.pkl'), load=True)
        vocab_verb = Vocab(os.path.join(self.a.vocab, 'vocab_situation_verb.pkl'), load=True)

        # train_sr_loader = imsitu_loader(self.a.image_dir, self.vocab_noun, self.vocab_role, self.vocab_verb, self.a.imsitu_ontology_file,
        #                             self.a.train_sr, self.a.verb_mapping_file, self.a.role_mapping_file,
        #                             self.a.object_class_map_file, self.a.object_detection_pkl_file,
        #                             self.a.object_detection_threshold,
        #                             transform, self.a.batch, shuffle=self.a.shuffle, num_workers=1)  #self.a.shuffle
        # dev_sr_loader = imsitu_loader(self.a.image_dir, self.vocab_noun, self.vocab_role, self.vocab_verb, self.a.imsitu_ontology_file,
        #                             self.a.dev_sr, self.a.verb_mapping_file, self.a.role_mapping_file,
        #                             self.a.object_class_map_file, self.a.object_detection_pkl_file,
        #                             self.a.object_detection_threshold,
        #                             transform, self.a.batch, shuffle=self.a.shuffle, num_workers=1)
        # test_sr_loader = imsitu_loader(self.a.image_dir, self.vocab_noun, self.vocab_role, self.vocab_verb, self.a.imsitu_ontology_file,
        #                             self.a.test_sr, self.a.verb_mapping_file, self.a.role_mapping_file,
        #                             self.a.object_class_map_file, self.a.object_detection_pkl_file,
        #                             self.a.object_detection_threshold,
        #                             transform, self.a.batch, shuffle=self.a.shuffle, num_workers=1)
        train_sr_set = ImSituDataset(self.a.image_dir, vocab_noun, vocab_role, vocab_verb,
                                     LabelField.vocab.stoi, EventsField.vocab.stoi,
                                     self.a.imsitu_ontology_file,
                                     self.a.train_sr, self.a.verb_mapping_file,
                                     self.a.object_class_map_file, self.a.object_detection_pkl_file,
                                     self.a.object_detection_threshold,
                                     transform, filter_irrelevant_verbs=self.a.filter_irrelevant_verbs,
                                     load_object=self.a.add_object, filter_place=self.a.filter_place)
        dev_sr_set = ImSituDataset(self.a.image_dir, vocab_noun, vocab_role, vocab_verb,
                                   LabelField.vocab.stoi, EventsField.vocab.stoi,
                                   self.a.imsitu_ontology_file,
                                   self.a.dev_sr, self.a.verb_mapping_file,
                                   self.a.object_class_map_file, self.a.object_detection_pkl_file,
                                   self.a.object_detection_threshold,
                                   transform, filter_irrelevant_verbs=self.a.filter_irrelevant_verbs,
                                   load_object=self.a.add_object, filter_place=self.a.filter_place)
        test_sr_set = ImSituDataset(self.a.image_dir, vocab_noun, vocab_role, vocab_verb,
                                    LabelField.vocab.stoi, EventsField.vocab.stoi,
                                    self.a.imsitu_ontology_file,
                                    self.a.test_sr, self.a.verb_mapping_file,
                                    self.a.object_class_map_file, self.a.object_detection_pkl_file,
                                    self.a.object_detection_threshold,
                                    transform, filter_irrelevant_verbs=self.a.filter_irrelevant_verbs,
                                    load_object=self.a.add_object, filter_place=self.a.filter_place)


        ####################    loading grounding dataset   ####################
        if self.a.train_grounding:
            log('loading grounding corpus from %s' % self.a.train_grounding)

        # only for grounding
        IMAGEIDField = SparseField(sequential=False, use_vocab=False, batch_first=True)
        SENTIDField = SparseField(sequential=False, use_vocab=False, batch_first=True)
        # IMAGEField = SparseField(sequential=False, use_vocab=False, batch_first=True)

        train_grounding_set = GroundingDataset(path=self.a.train_grounding,
                                               img_dir=self.a.img_dir_grounding,
                                               fields={"id": ("IMAGEID", IMAGEIDField),
                                                       "sentence_id": ("SENTID", SENTIDField),
                                                       "words": ("WORDS", WordsField),
                                                       "pos-tags": ("POSTAGS", PosTagsField),
                                                       "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField),
                                                       colcc: ("ADJM", AdjMatrixField),
                                                       "all-entities": ("ENTITIES", EntitiesField),
                                                       # "image": ("IMAGE", IMAGEField),
                                                       },
                                               transform=transform,
                                               amr=self.a.amr,
                                               load_object=self.a.add_object,
                                               object_ontology_file=self.a.object_class_map_file,
                                               object_detection_pkl_file=self.a.object_detection_pkl_file_g,
                                               object_detection_threshold=self.a.object_detection_threshold,
                                               )

        dev_grounding_set = GroundingDataset(path=self.a.dev_grounding,
                                             img_dir=self.a.img_dir_grounding,
                                             fields={"id": ("IMAGEID", IMAGEIDField),
                                                     "sentence_id": ("SENTID", SENTIDField),
                                                     "words": ("WORDS", WordsField),
                                                     "pos-tags": ("POSTAGS", PosTagsField),
                                                     "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField),
                                                     colcc: ("ADJM", AdjMatrixField),
                                                     "all-entities": ("ENTITIES", EntitiesField),
                                                     # "image": ("IMAGE", IMAGEField),
                                                     },
                                             transform=transform,
                                             amr=self.a.amr,
                                             load_object=self.a.add_object,
                                             object_ontology_file=self.a.object_class_map_file,
                                             object_detection_pkl_file=self.a.object_detection_pkl_file_g,
                                             object_detection_threshold=self.a.object_detection_threshold,
                                             )

        test_grounding_set = GroundingDataset(path=self.a.test_grounding,
                                              img_dir=self.a.img_dir_grounding,
                                              fields={"id": ("IMAGEID", IMAGEIDField),
                                                      "sentence_id": ("SENTID", SENTIDField),
                                                      "words": ("WORDS", WordsField),
                                                      "pos-tags": ("POSTAGS", PosTagsField),
                                                      "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField),
                                                      colcc: ("ADJM", AdjMatrixField),
                                                      "all-entities": ("ENTITIES", EntitiesField),
                                                      # "image": ("IMAGE", IMAGEField),
                                                      },
                                              transform=transform,
                                              amr=self.a.amr,
                                              load_object=self.a.add_object,
                                              object_ontology_file=self.a.object_class_map_file,
                                              object_detection_pkl_file=self.a.object_detection_pkl_file_g,
                                              object_detection_threshold=self.a.object_detection_threshold,
                                              )

        ####################    build vocabulary   ####################

        if self.a.webd:
            pretrained_embedding = Vectors(self.a.webd, ".", unk_init=partial(torch.nn.init.uniform_, a=-0.15, b=0.15))
            WordsField.build_vocab(train_ee_set.WORDS, dev_ee_set.WORDS, train_grounding_set.WORDS, dev_grounding_set.WORDS, vectors=pretrained_embedding)
        else:
            WordsField.build_vocab(train_ee_set.WORDS, dev_ee_set.WORDS, train_grounding_set.WORDS, dev_grounding_set.WORDS)
        PosTagsField.build_vocab(train_ee_set.POSTAGS, dev_ee_set.POSTAGS, train_grounding_set.POSTAGS, dev_grounding_set.POSTAGS)
        EntityLabelsField.build_vocab(train_ee_set.ENTITYLABELS, dev_ee_set.ENTITYLABELS,  train_grounding_set.ENTITYLABELS, dev_grounding_set.ENTITYLABELS)

        consts.O_LABEL = LabelField.vocab.stoi[consts.O_LABEL_NAME]
        # print("O label is", consts.O_LABEL)
        consts.ROLE_O_LABEL = EventsField.vocab.stoi[consts.ROLE_O_LABEL_NAME]
        # print("O label for AE is", consts.ROLE_O_LABEL)

        dev_ee_set1 = ACE2005Dataset(path=self.a.dev_ee,
                                  fields={"sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField),
                                          "pos-tags": ("POSTAGS", PosTagsField),
                                          "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField),
                                          colcc: ("ADJM", AdjMatrixField),
                                          "golden-event-mentions": ("LABEL", LabelField),
                                          "all-events": ("EVENT", EventsField),
                                          "all-entities": ("ENTITIES", EntitiesField)},
                                  amr=self.a.amr, keep_events=1, only_keep=True)

        test_ee_set1 = ACE2005Dataset(path=self.a.test_ee,
                                   fields={"sentence_id": ("SENTID", SENTIDField), "words": ("WORDS", WordsField),
                                           "pos-tags": ("POSTAGS", PosTagsField),
                                           "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField),
                                           colcc: ("ADJM", AdjMatrixField),
                                           "golden-event-mentions": ("LABEL", LabelField),
                                           "all-events": ("EVENT", EventsField),
                                           "all-entities": ("ENTITIES", EntitiesField)},
                                   amr=self.a.amr, keep_events=1, only_keep=True)
        print("train set length", len(train_ee_set))

        print("dev set length", len(dev_ee_set))
        print("dev set 1/1 length", len(dev_ee_set1))

        print("test set length", len(test_ee_set))
        print("test set 1/1 length", len(test_ee_set1))

        # sr model initialization
        if not self.a.sr_hps_path:
            self.a.sr_hps = eval(self.a.sr_hps)
        embeddingMatrix_noun = torch.FloatTensor(np.load(self.a.wnebd)).to(self.device)
        embeddingMatrix_verb = torch.FloatTensor(np.load(self.a.wvebd)).to(self.device)
        embeddingMatrix_role = torch.FloatTensor(np.load(self.a.wrebd)).to(self.device)
        if "wvemb_size" not in self.a.sr_hps:
            self.a.sr_hps["wvemb_size"] = len(vocab_verb.id2word)
        if "wremb_size" not in self.a.sr_hps:
            self.a.sr_hps["wremb_size"] = len(vocab_role.id2word)
        if "wnemb_size" not in self.a.sr_hps:
            self.a.sr_hps["wnemb_size"] = len(vocab_noun.id2word)

        self.a.ee_label_weight = torch.ones([len(LabelField.vocab.itos)]) * 5
        self.a.ee_label_weight[consts.O_LABEL] = 1.0
        self.a.ee_arg_weight = torch.ones([len(EventsField.vocab.itos)]) * 5
        self.a.ee_hps = eval(self.a.ee_hps)
        if "wemb_size" not in self.a.ee_hps:
            self.a.ee_hps["wemb_size"] = len(WordsField.vocab.itos)
        if "pemb_size" not in self.a.ee_hps:
            self.a.ee_hps["pemb_size"] = len(PosTagsField.vocab.itos)
        if "psemb_size" not in self.a.ee_hps:
            # self.a.ee_hps["psemb_size"] = max([train_grounding_set.longest(), dev_grounding_set.longest(), test_grounding_set.longest()]) + 2
            self.a.ee_hps["psemb_size"] = max([train_ee_set.longest(), dev_ee_set.longest(), test_ee_set.longest(), train_grounding_set.longest(), dev_grounding_set.longest(), test_grounding_set.longest()]) + 2
        if "eemb_size" not in self.a.ee_hps:
            self.a.ee_hps["eemb_size"] = len(EntityLabelsField.vocab.itos)
        if "oc" not in self.a.ee_hps:
            self.a.ee_hps["oc"] = len(LabelField.vocab.itos)
        if "ae_oc" not in self.a.ee_hps:
            self.a.ee_hps["ae_oc"] = len(EventsField.vocab.itos)
        if "oc" not in self.a.sr_hps:
            self.a.sr_hps["oc"] = len(LabelField.vocab.itos)
        if "ae_oc" not in self.a.sr_hps:
            self.a.sr_hps["ae_oc"] = len(EventsField.vocab.itos)

        ee_tester = EDTester(LabelField.vocab.itos, EventsField.vocab.itos, self.a.ignore_time_test)
        sr_tester = SRTester()
        g_tester = GroundingTester()
        j_tester = JointTester(self.a.ignore_place_sr_test, self.a.ignore_time_test)

        ace_classifier = ACEClassifier(2 * self.a.ee_hps["lstm_dim"], self.a.ee_hps["oc"], self.a.ee_hps["ae_oc"], self.device)

        if self.a.finetune_ee:
            log('init ee model from ' + self.a.finetune_ee)
            ee_model = load_ee_model(self.a.ee_hps, self.a.finetune_ee, WordsField.vocab.vectors, self.device, ace_classifier)
            log('ee model loaded, there are %i sets of params' % len(ee_model.parameters_requires_grads()))
        else:
            ee_model = load_ee_model(self.a.ee_hps, None, WordsField.vocab.vectors, self.device, ace_classifier)
            log('ee model created from scratch, there are %i sets of params' % len(ee_model.parameters_requires_grads()))

        if self.a.finetune_sr:
            log('init sr model from ' + self.a.finetune_sr)
            sr_model = load_sr_model(self.a.sr_hps, embeddingMatrix_noun, embeddingMatrix_verb, embeddingMatrix_role, self.a.finetune_sr, self.device, ace_classifier, add_object=self.a.add_object, load_partial=True)
            log('sr model loaded, there are %i sets of params' % len(sr_model.parameters_requires_grads()))
        else:
            sr_model = load_sr_model(self.a.sr_hps, embeddingMatrix_noun, embeddingMatrix_verb, embeddingMatrix_role, None, self.device, ace_classifier, add_object=self.a.add_object, load_partial=True)
            log('sr model created from scratch, there are %i sets of params' % len(sr_model.parameters_requires_grads()))

        model = GroundingModel(ee_model, sr_model, self.get_device())
        # ee_model = torch.nn.DataParallel(ee_model)
        # sr_model = torch.nn.DataParallel(sr_model)
        # model = torch.nn.DataParallel(model)

        if self.a.optimizer == "adadelta":
            optimizer_constructor = partial(torch.optim.Adadelta, params=model.parameters_requires_grads(),
                                            weight_decay=self.a.l2decay)
        elif self.a.optimizer == "adam":
            optimizer_constructor = partial(torch.optim.Adam, params=model.parameters_requires_grads(),
                                            weight_decay=self.a.l2decay)
        else:
            optimizer_constructor = partial(torch.optim.SGD, params=model.parameters_requires_grads(),
                                            weight_decay=self.a.l2decay,
                                            momentum=0.9)

        log('optimizer in use: %s' % str(self.a.optimizer))

        if not os.path.exists(self.a.out):
            os.mkdir(self.a.out)
        with open(os.path.join(self.a.out, "word.vec"), "wb") as f:
            pickle.dump(WordsField.vocab, f)
        with open(os.path.join(self.a.out, "pos.vec"), "wb") as f:
            pickle.dump(PosTagsField.vocab.stoi, f)
        with open(os.path.join(self.a.out, "entity.vec"), "wb") as f:
            pickle.dump(EntityLabelsField.vocab.stoi, f)
        with open(os.path.join(self.a.out, "label.vec"), "wb") as f:
            pickle.dump(LabelField.vocab.stoi, f)
        with open(os.path.join(self.a.out, "role.vec"), "wb") as f:
            pickle.dump(EventsField.vocab.stoi, f)
        with open(os.path.join(self.a.out, "ee_hyps.json"), "w") as f:
            json.dump(self.a.ee_hps, f)
        with open(os.path.join(self.a.out, "sr_hyps.json"), "w") as f:
            json.dump(self.a.sr_hps, f)

        log('init complete\n')

        # ee mappings
        self.a.ee_word_i2s = WordsField.vocab.itos
        self.a.ee_label_i2s = LabelField.vocab.itos
        self.a.ee_role_i2s = EventsField.vocab.itos
        # sr mappings
        self.a.sr_word_i2s = vocab_noun.id2word
        self.a.sr_label_i2s = vocab_verb.id2word  # LabelField.vocab.itos
        self.a.sr_role_i2s = vocab_role.id2word
        writer = SummaryWriter(os.path.join(self.a.out, "exp"))
        self.a.writer = writer

        joint_train(
            model_ee=ee_model,
            model_sr=sr_model,
            model_g=model,
            train_set_g=train_grounding_set,
            dev_set_g=dev_grounding_set,
            test_set_g=test_grounding_set,
            train_set_ee=train_ee_set,
            dev_set_ee=dev_ee_set,
            test_set_ee=test_ee_set,
            train_set_sr=train_sr_set,
            dev_set_sr=dev_sr_set,
            test_set_sr=test_sr_set,
            optimizer_constructor=optimizer_constructor,
            epochs=self.a.epochs,
            ee_tester=ee_tester,
            sr_tester=sr_tester,
            g_tester=g_tester,
            j_tester=j_tester,
            parser=self.a,
            other_testsets={
                "dev ee 1/1": dev_ee_set1,
                "test ee 1/1": test_ee_set1,
            },
            transform=transform,
            vocab_objlabel=vocab_noun.word2id
        )
        log('Done!')