示例#1
0
    def run(self):
        """Set up data, vocabularies and models, then run grounding training.

        Steps, in order:
          1. select the device and seed NumPy / PyTorch from ``self.a``;
          2. build the train/dev/test ``GroundingDataset`` splits;
          3. build the word, POS-tag and entity-label vocabularies from the
             train and dev splits (optionally attaching pretrained vectors);
          4. fill in missing hyper-parameters, then create the SR and EE
             models and wrap them in a ``GroundingModel``;
          5. write vocabularies and hyper-parameters under ``self.a.out``;
          6. call ``grounding_train``.

        Side effects: mutates ``self.a`` (``sr_hps``, ``ee_hps``, ``word_i2s``,
        ``label_i2s``, ``role_i2s``, ``writer``) and creates files under
        ``self.a.out``.
        """
        print("Running on", self.a.device)
        self.set_device(self.a.device)

        np.random.seed(self.a.seed)
        torch.manual_seed(self.a.seed)
        torch.backends.cudnn.benchmark = True

        # create training set
        if self.a.train:
            log('loading corpus from %s' % self.a.train)

        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(224),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

        IMAGEIDField = SparseField(sequential=False,
                                   use_vocab=False,
                                   batch_first=True)
        SENTIDField = SparseField(sequential=False,
                                  use_vocab=False,
                                  batch_first=True)
        WordsField = Field(lower=True, include_lengths=True, batch_first=True)
        PosTagsField = Field(lower=True, batch_first=True)
        EntityLabelsField = MultiTokenField(lower=False, batch_first=True)
        AdjMatrixField = SparseField(sequential=False,
                                     use_vocab=False,
                                     batch_first=True)
        EntitiesField = EntityField(lower=False,
                                    batch_first=True,
                                    use_vocab=False)

        # Name of the input column that is mapped to the ADJM field.
        if self.a.amr:
            colcc = 'amr-colcc'
        else:
            colcc = 'stanford-colcc'
        print(colcc)

        def build_dataset(path):
            """Create a GroundingDataset for ``path``.

            A fresh ``fields`` dict is built on every call so that the three
            splits never share one mapping object.
            """
            return GroundingDataset(
                path=path,
                img_dir=self.a.img_dir,
                fields={
                    "id": ("IMAGEID", IMAGEIDField),
                    "sentence_id": ("SENTID", SENTIDField),
                    "words": ("WORDS", WordsField),
                    "pos-tags": ("POSTAGS", PosTagsField),
                    "golden-entity-mentions": ("ENTITYLABELS",
                                               EntityLabelsField),
                    colcc: ("ADJM", AdjMatrixField),
                    "all-entities": ("ENTITIES", EntitiesField),
                },
                transform=transform,
                amr=self.a.amr,
                load_object=self.a.add_object,
                object_ontology_file=self.a.object_class_map_file,
                object_detection_pkl_file=self.a.object_detection_pkl_file,
                object_detection_threshold=self.a.object_detection_threshold,
            )

        train_set = build_dataset(self.a.train)
        dev_set = build_dataset(self.a.dev)
        test_set = build_dataset(self.a.test)

        # Vocabularies are built from the train and dev splits only.
        if self.a.webd:
            pretrained_embedding = Vectors(self.a.webd,
                                           ".",
                                           unk_init=partial(
                                               torch.nn.init.uniform_,
                                               a=-0.15,
                                               b=0.15))
            WordsField.build_vocab(train_set.WORDS,
                                   dev_set.WORDS,
                                   vectors=pretrained_embedding)
        else:
            WordsField.build_vocab(train_set.WORDS, dev_set.WORDS)
        PosTagsField.build_vocab(train_set.POSTAGS, dev_set.POSTAGS)
        EntityLabelsField.build_vocab(train_set.ENTITYLABELS,
                                      dev_set.ENTITYLABELS)

        # sr model initialization
        # NOTE(review): eval() executes arbitrary code; the hyper-parameter
        # strings must come from a trusted source. If they are always plain
        # dict literals, ast.literal_eval would be a safer choice — confirm.
        self.a.sr_hps = eval(self.a.sr_hps)
        vocab_noun = Vocab(os.path.join(self.a.vocab,
                                        'vocab_situation_noun.pkl'),
                           load=True)
        vocab_role = Vocab(os.path.join(self.a.vocab,
                                        'vocab_situation_role.pkl'),
                           load=True)
        vocab_verb = Vocab(os.path.join(self.a.vocab,
                                        'vocab_situation_verb.pkl'),
                           load=True)
        embeddingMatrix_noun = torch.FloatTensor(np.load(self.a.wnebd)).to(
            self.device)
        embeddingMatrix_verb = torch.FloatTensor(np.load(self.a.wvebd)).to(
            self.device)
        embeddingMatrix_role = torch.FloatTensor(np.load(self.a.wrebd)).to(
            self.device)

        # Only fill in sizes the user did not specify explicitly.
        if "wvemb_size" not in self.a.sr_hps:
            self.a.sr_hps["wvemb_size"] = len(vocab_verb.id2word)
        if "wremb_size" not in self.a.sr_hps:
            self.a.sr_hps["wremb_size"] = len(vocab_role.id2word)
        if "wnemb_size" not in self.a.sr_hps:
            self.a.sr_hps["wnemb_size"] = len(vocab_noun.id2word)
        if "ae_oc" not in self.a.sr_hps:
            self.a.sr_hps["ae_oc"] = len(vocab_role.id2word)

        # NOTE(review): same eval() caveat as for sr_hps above.
        self.a.ee_hps = eval(self.a.ee_hps)
        if "wemb_size" not in self.a.ee_hps:
            self.a.ee_hps["wemb_size"] = len(WordsField.vocab.itos)
        if "pemb_size" not in self.a.ee_hps:
            self.a.ee_hps["pemb_size"] = len(PosTagsField.vocab.itos)
        if "psemb_size" not in self.a.ee_hps:
            # Longest sequence over all three splits, plus 2.
            self.a.ee_hps["psemb_size"] = max(
                dataset.longest()
                for dataset in (train_set, dev_set, test_set)) + 2
        if "eemb_size" not in self.a.ee_hps:
            self.a.ee_hps["eemb_size"] = len(EntityLabelsField.vocab.itos)
        if "oc" not in self.a.ee_hps:
            # Hard-coded default; its origin is not visible here — TODO confirm.
            self.a.ee_hps["oc"] = 36
        if "ae_oc" not in self.a.ee_hps:
            # Hard-coded default; its origin is not visible here — TODO confirm.
            self.a.ee_hps["ae_oc"] = 20

        tester = self.get_tester()

        if self.a.finetune_sr:
            log('init sr model from ' + self.a.finetune_sr)
            sr_model = load_sr_model(self.a.sr_hps, embeddingMatrix_noun,
                                     embeddingMatrix_verb,
                                     embeddingMatrix_role, self.a.finetune_sr,
                                     self.device)
            log('sr model loaded, there are %i sets of params' %
                len(sr_model.parameters_requires_grads()))
        else:
            sr_model = load_sr_model(self.a.sr_hps, embeddingMatrix_noun,
                                     embeddingMatrix_verb,
                                     embeddingMatrix_role, None, self.device)
            log('sr model created from scratch, there are %i sets of params' %
                len(sr_model.parameters_requires_grads()))

        if self.a.finetune_ee:
            log('init model from ' + self.a.finetune_ee)
            ee_model = load_ee_model(self.a.ee_hps, self.a.finetune_ee,
                                     WordsField.vocab.vectors, self.device)
            log('model loaded, there are %i sets of params' %
                len(ee_model.parameters_requires_grads()))
        else:
            ee_model = load_ee_model(self.a.ee_hps, None,
                                     WordsField.vocab.vectors, self.device)
            log('model created from scratch, there are %i sets of params' %
                len(ee_model.parameters_requires_grads()))

        model = GroundingModel(ee_model, sr_model, self.get_device())

        # Any optimizer name other than "adadelta"/"adam" falls back to SGD
        # with momentum.
        optimizer_kwargs = {
            "params": model.parameters_requires_grads(),
            "weight_decay": self.a.l2decay,
        }
        if self.a.optimizer == "adadelta":
            optimizer_constructor = partial(torch.optim.Adadelta,
                                            **optimizer_kwargs)
        elif self.a.optimizer == "adam":
            optimizer_constructor = partial(torch.optim.Adam,
                                            **optimizer_kwargs)
        else:
            optimizer_constructor = partial(torch.optim.SGD,
                                            momentum=0.9,
                                            **optimizer_kwargs)

        log('optimizer in use: %s' % str(self.a.optimizer))

        # makedirs(exist_ok=True) also creates missing parent directories and
        # avoids the check-then-create race of exists() + mkdir().
        os.makedirs(self.a.out, exist_ok=True)
        with open(os.path.join(self.a.out, "word.vec"), "wb") as f:
            pickle.dump(WordsField.vocab, f)
        with open(os.path.join(self.a.out, "pos.vec"), "wb") as f:
            pickle.dump(PosTagsField.vocab.stoi, f)
        with open(os.path.join(self.a.out, "entity.vec"), "wb") as f:
            pickle.dump(EntityLabelsField.vocab.stoi, f)
        with open(os.path.join(self.a.out, "ee_hyps.json"), "w") as f:
            json.dump(self.a.ee_hps, f)
        with open(os.path.join(self.a.out, "sr_hyps.json"), "w") as f:
            json.dump(self.a.sr_hps, f)

        log('init complete\n')

        # Index-to-string mappings handed to training through the parser.
        # word_i2s is the text vocabulary; a preceding assignment from the noun
        # vocabulary was always overwritten by this value and has been removed.
        self.a.label_i2s = vocab_verb.id2word
        self.a.role_i2s = vocab_role.id2word
        self.a.word_i2s = WordsField.vocab.itos
        writer = SummaryWriter(os.path.join(self.a.out, "exp"))
        self.a.writer = writer

        grounding_train(
            model=model,
            train_set=train_set,
            dev_set=dev_set,
            test_set=test_set,
            optimizer_constructor=optimizer_constructor,
            epochs=self.a.epochs,
            tester=tester,
            parser=self.a,
            other_testsets={},
            transform=transform,
            vocab_objlabel=vocab_noun.word2id)
        log('Done!')
示例#2
0
    def run(self):
        """Set up all datasets, vocabularies and models, then run joint training.

        Steps, in order:
          1. select the device and seed NumPy / PyTorch from ``self.a``;
          2. build the ACE2005 (event extraction) splits and the label / role
             vocabularies, and compute ``self.a.role_mask``;
          3. build the ImSitu (situation recognition) splits;
          4. build the grounding splits;
          5. build the word, POS-tag and entity-label vocabularies from the
             train and dev splits of both the EE and grounding data;
          6. fill in missing hyper-parameters, then create the EE, SR and
             grounding models, which share one ``ACEClassifier``;
          7. write vocabularies and hyper-parameters under ``self.a.out``;
          8. call ``joint_train``.

        Side effects: mutates ``self.a`` (hyper-parameters, weights, mappings,
        ``role_mask``, ``writer``), sets ``consts.O_LABEL`` and
        ``consts.ROLE_O_LABEL``, and creates files under ``self.a.out``.
        """
        print("Running on", self.a.device)
        self.set_device(self.a.device)

        np.random.seed(self.a.seed)
        torch.manual_seed(self.a.seed)
        torch.backends.cudnn.benchmark = True

        ####################    loading event extraction dataset   ####################
        if self.a.train_ee:
            log('loading event extraction corpus from %s' % self.a.train_ee)

        # both for grounding and ee
        WordsField = Field(lower=True, include_lengths=True, batch_first=True)
        PosTagsField = Field(lower=True, batch_first=True)
        EntityLabelsField = MultiTokenField(lower=False, batch_first=True)
        AdjMatrixField = SparseField(sequential=False, use_vocab=False, batch_first=True)
        EntitiesField = EntityField(lower=False, batch_first=True, use_vocab=False)
        # only for ee
        LabelField = Field(lower=False, batch_first=True, pad_token='0', unk_token=None)
        EventsField = EventField(lower=False, batch_first=True)
        SENTIDField = SparseField(sequential=False, use_vocab=False, batch_first=True)

        # Name of the input column that is mapped to the ADJM field.
        if self.a.amr:
            colcc = 'simple-parsing'
        else:
            colcc = 'combined-parsing'
        print(colcc)

        def build_ee_dataset(path, **dataset_kwargs):
            """Create an ACE2005Dataset for ``path`` with a fresh field mapping.

            Field objects are looked up when the helper is called, so a call
            made after ``SENTIDField`` is re-created below uses the new
            instance.
            """
            return ACE2005Dataset(
                path=path,
                fields={
                    "sentence_id": ("SENTID", SENTIDField),
                    "words": ("WORDS", WordsField),
                    "pos-tags": ("POSTAGS", PosTagsField),
                    "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField),
                    colcc: ("ADJM", AdjMatrixField),
                    "golden-event-mentions": ("LABEL", LabelField),
                    "all-events": ("EVENT", EventsField),
                    "all-entities": ("ENTITIES", EntitiesField),
                },
                amr=self.a.amr,
                **dataset_kwargs)

        train_ee_set = build_ee_dataset(self.a.train_ee, keep_events=1)
        dev_ee_set = build_ee_dataset(self.a.dev_ee, keep_events=0)
        test_ee_set = build_ee_dataset(self.a.test_ee, keep_events=0)

        # Load the pretrained vectors once; they are reused for the word
        # vocabulary further down instead of being read from disk a second time.
        pretrained_embedding = None
        if self.a.webd:
            pretrained_embedding = Vectors(self.a.webd, ".", unk_init=partial(torch.nn.init.uniform_, a=-0.15, b=0.15))
            LabelField.build_vocab(train_ee_set.LABEL, dev_ee_set.LABEL, vectors=pretrained_embedding)
            EventsField.build_vocab(train_ee_set.EVENT, dev_ee_set.EVENT, vectors=pretrained_embedding)
        else:
            LabelField.build_vocab(train_ee_set.LABEL, dev_ee_set.LABEL)
            EventsField.build_vocab(train_ee_set.EVENT, dev_ee_set.EVENT)

        # add role mask
        self.a.role_mask = event_role_mask(self.a.train_ee, self.a.dev_ee, LabelField.vocab.stoi,
                                           EventsField.vocab.stoi, self.device)

        ####################    loading SR dataset   ####################
        # both for grounding and sr
        if self.a.train_sr:
            log('loading corpus from %s' % self.a.train_sr)

        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(224),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406),
                                 (0.229, 0.224, 0.225)),
        ])

        vocab_noun = Vocab(os.path.join(self.a.vocab, 'vocab_situation_noun.pkl'), load=True)
        vocab_role = Vocab(os.path.join(self.a.vocab, 'vocab_situation_role.pkl'), load=True)
        vocab_verb = Vocab(os.path.join(self.a.vocab, 'vocab_situation_verb.pkl'), load=True)

        def build_sr_dataset(split_file):
            """Create an ImSituDataset for one split file."""
            return ImSituDataset(self.a.image_dir, vocab_noun, vocab_role, vocab_verb,
                                 LabelField.vocab.stoi, EventsField.vocab.stoi,
                                 self.a.imsitu_ontology_file,
                                 split_file, self.a.verb_mapping_file,
                                 self.a.object_class_map_file, self.a.object_detection_pkl_file,
                                 self.a.object_detection_threshold,
                                 transform, filter_irrelevant_verbs=self.a.filter_irrelevant_verbs,
                                 load_object=self.a.add_object, filter_place=self.a.filter_place)

        train_sr_set = build_sr_dataset(self.a.train_sr)
        dev_sr_set = build_sr_dataset(self.a.dev_sr)
        test_sr_set = build_sr_dataset(self.a.test_sr)

        ####################    loading grounding dataset   ####################
        if self.a.train_grounding:
            log('loading grounding corpus from %s' % self.a.train_grounding)

        # only for grounding
        IMAGEIDField = SparseField(sequential=False, use_vocab=False, batch_first=True)
        # NOTE: SENTIDField is re-created here with the same arguments. Every
        # dataset built from this point on (the grounding splits and the "1/1"
        # ACE evaluation sets) uses this second instance.
        SENTIDField = SparseField(sequential=False, use_vocab=False, batch_first=True)

        def build_grounding_dataset(path):
            """Create a GroundingDataset for ``path`` with a fresh field mapping."""
            return GroundingDataset(
                path=path,
                img_dir=self.a.img_dir_grounding,
                fields={
                    "id": ("IMAGEID", IMAGEIDField),
                    "sentence_id": ("SENTID", SENTIDField),
                    "words": ("WORDS", WordsField),
                    "pos-tags": ("POSTAGS", PosTagsField),
                    "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField),
                    colcc: ("ADJM", AdjMatrixField),
                    "all-entities": ("ENTITIES", EntitiesField),
                },
                transform=transform,
                amr=self.a.amr,
                load_object=self.a.add_object,
                object_ontology_file=self.a.object_class_map_file,
                object_detection_pkl_file=self.a.object_detection_pkl_file_g,
                object_detection_threshold=self.a.object_detection_threshold,
            )

        train_grounding_set = build_grounding_dataset(self.a.train_grounding)
        dev_grounding_set = build_grounding_dataset(self.a.dev_grounding)
        test_grounding_set = build_grounding_dataset(self.a.test_grounding)

        ####################    build vocabulary   ####################

        # Vocabularies are built from the train and dev splits only.
        if pretrained_embedding is not None:
            WordsField.build_vocab(train_ee_set.WORDS, dev_ee_set.WORDS, train_grounding_set.WORDS, dev_grounding_set.WORDS, vectors=pretrained_embedding)
        else:
            WordsField.build_vocab(train_ee_set.WORDS, dev_ee_set.WORDS, train_grounding_set.WORDS, dev_grounding_set.WORDS)
        PosTagsField.build_vocab(train_ee_set.POSTAGS, dev_ee_set.POSTAGS, train_grounding_set.POSTAGS, dev_grounding_set.POSTAGS)
        EntityLabelsField.build_vocab(train_ee_set.ENTITYLABELS, dev_ee_set.ENTITYLABELS, train_grounding_set.ENTITYLABELS, dev_grounding_set.ENTITYLABELS)

        # Publish the vocabulary indices of the two "O" labels through consts.
        consts.O_LABEL = LabelField.vocab.stoi[consts.O_LABEL_NAME]
        consts.ROLE_O_LABEL = EventsField.vocab.stoi[consts.ROLE_O_LABEL_NAME]

        # Extra evaluation sets, passed to joint_train as "dev ee 1/1" and
        # "test ee 1/1".
        dev_ee_set1 = build_ee_dataset(self.a.dev_ee, keep_events=1, only_keep=True)
        test_ee_set1 = build_ee_dataset(self.a.test_ee, keep_events=1, only_keep=True)
        print("train set length", len(train_ee_set))

        print("dev set length", len(dev_ee_set))
        print("dev set 1/1 length", len(dev_ee_set1))

        print("test set length", len(test_ee_set))
        print("test set 1/1 length", len(test_ee_set1))

        # sr model initialization
        # NOTE(review): eval() executes arbitrary code; the hyper-parameter
        # strings must come from a trusted source.
        # NOTE(review): when sr_hps_path is set, sr_hps is used as-is below and
        # therefore must already be a dict — presumably loaded elsewhere; confirm.
        if not self.a.sr_hps_path:
            self.a.sr_hps = eval(self.a.sr_hps)
        embeddingMatrix_noun = torch.FloatTensor(np.load(self.a.wnebd)).to(self.device)
        embeddingMatrix_verb = torch.FloatTensor(np.load(self.a.wvebd)).to(self.device)
        embeddingMatrix_role = torch.FloatTensor(np.load(self.a.wrebd)).to(self.device)

        # Only fill in sizes the user did not specify explicitly.
        if "wvemb_size" not in self.a.sr_hps:
            self.a.sr_hps["wvemb_size"] = len(vocab_verb.id2word)
        if "wremb_size" not in self.a.sr_hps:
            self.a.sr_hps["wremb_size"] = len(vocab_role.id2word)
        if "wnemb_size" not in self.a.sr_hps:
            self.a.sr_hps["wnemb_size"] = len(vocab_noun.id2word)

        # Every label gets weight 5 except the "O" label, which gets 1.
        self.a.ee_label_weight = torch.ones([len(LabelField.vocab.itos)]) * 5
        self.a.ee_label_weight[consts.O_LABEL] = 1.0
        self.a.ee_arg_weight = torch.ones([len(EventsField.vocab.itos)]) * 5
        # NOTE(review): same eval() caveat as for sr_hps above.
        self.a.ee_hps = eval(self.a.ee_hps)
        if "wemb_size" not in self.a.ee_hps:
            self.a.ee_hps["wemb_size"] = len(WordsField.vocab.itos)
        if "pemb_size" not in self.a.ee_hps:
            self.a.ee_hps["pemb_size"] = len(PosTagsField.vocab.itos)
        if "psemb_size" not in self.a.ee_hps:
            # Longest sequence over all EE and grounding splits, plus 2.
            self.a.ee_hps["psemb_size"] = max(
                dataset.longest()
                for dataset in (train_ee_set, dev_ee_set, test_ee_set,
                                train_grounding_set, dev_grounding_set, test_grounding_set)) + 2
        if "eemb_size" not in self.a.ee_hps:
            self.a.ee_hps["eemb_size"] = len(EntityLabelsField.vocab.itos)
        if "oc" not in self.a.ee_hps:
            self.a.ee_hps["oc"] = len(LabelField.vocab.itos)
        if "ae_oc" not in self.a.ee_hps:
            self.a.ee_hps["ae_oc"] = len(EventsField.vocab.itos)
        if "oc" not in self.a.sr_hps:
            self.a.sr_hps["oc"] = len(LabelField.vocab.itos)
        if "ae_oc" not in self.a.sr_hps:
            self.a.sr_hps["ae_oc"] = len(EventsField.vocab.itos)

        ee_tester = EDTester(LabelField.vocab.itos, EventsField.vocab.itos, self.a.ignore_time_test)
        sr_tester = SRTester()
        g_tester = GroundingTester()
        j_tester = JointTester(self.a.ignore_place_sr_test, self.a.ignore_time_test)

        # The same classifier instance is passed to both the EE and SR models.
        ace_classifier = ACEClassifier(2 * self.a.ee_hps["lstm_dim"], self.a.ee_hps["oc"], self.a.ee_hps["ae_oc"], self.device)

        if self.a.finetune_ee:
            log('init ee model from ' + self.a.finetune_ee)
            ee_model = load_ee_model(self.a.ee_hps, self.a.finetune_ee, WordsField.vocab.vectors, self.device, ace_classifier)
            log('ee model loaded, there are %i sets of params' % len(ee_model.parameters_requires_grads()))
        else:
            ee_model = load_ee_model(self.a.ee_hps, None, WordsField.vocab.vectors, self.device, ace_classifier)
            log('ee model created from scratch, there are %i sets of params' % len(ee_model.parameters_requires_grads()))

        if self.a.finetune_sr:
            log('init sr model from ' + self.a.finetune_sr)
            sr_model = load_sr_model(self.a.sr_hps, embeddingMatrix_noun, embeddingMatrix_verb, embeddingMatrix_role, self.a.finetune_sr, self.device, ace_classifier, add_object=self.a.add_object, load_partial=True)
            log('sr model loaded, there are %i sets of params' % len(sr_model.parameters_requires_grads()))
        else:
            sr_model = load_sr_model(self.a.sr_hps, embeddingMatrix_noun, embeddingMatrix_verb, embeddingMatrix_role, None, self.device, ace_classifier, add_object=self.a.add_object, load_partial=True)
            log('sr model created from scratch, there are %i sets of params' % len(sr_model.parameters_requires_grads()))

        model = GroundingModel(ee_model, sr_model, self.get_device())

        # Any optimizer name other than "adadelta"/"adam" falls back to SGD
        # with momentum.
        optimizer_kwargs = {
            "params": model.parameters_requires_grads(),
            "weight_decay": self.a.l2decay,
        }
        if self.a.optimizer == "adadelta":
            optimizer_constructor = partial(torch.optim.Adadelta, **optimizer_kwargs)
        elif self.a.optimizer == "adam":
            optimizer_constructor = partial(torch.optim.Adam, **optimizer_kwargs)
        else:
            optimizer_constructor = partial(torch.optim.SGD, momentum=0.9, **optimizer_kwargs)

        log('optimizer in use: %s' % str(self.a.optimizer))

        # makedirs(exist_ok=True) also creates missing parent directories and
        # avoids the check-then-create race of exists() + mkdir().
        os.makedirs(self.a.out, exist_ok=True)
        with open(os.path.join(self.a.out, "word.vec"), "wb") as f:
            pickle.dump(WordsField.vocab, f)
        with open(os.path.join(self.a.out, "pos.vec"), "wb") as f:
            pickle.dump(PosTagsField.vocab.stoi, f)
        with open(os.path.join(self.a.out, "entity.vec"), "wb") as f:
            pickle.dump(EntityLabelsField.vocab.stoi, f)
        with open(os.path.join(self.a.out, "label.vec"), "wb") as f:
            pickle.dump(LabelField.vocab.stoi, f)
        with open(os.path.join(self.a.out, "role.vec"), "wb") as f:
            pickle.dump(EventsField.vocab.stoi, f)
        with open(os.path.join(self.a.out, "ee_hyps.json"), "w") as f:
            json.dump(self.a.ee_hps, f)
        with open(os.path.join(self.a.out, "sr_hyps.json"), "w") as f:
            json.dump(self.a.sr_hps, f)

        log('init complete\n')

        # ee mappings
        self.a.ee_word_i2s = WordsField.vocab.itos
        self.a.ee_label_i2s = LabelField.vocab.itos
        self.a.ee_role_i2s = EventsField.vocab.itos
        # sr mappings
        self.a.sr_word_i2s = vocab_noun.id2word
        self.a.sr_label_i2s = vocab_verb.id2word
        self.a.sr_role_i2s = vocab_role.id2word
        writer = SummaryWriter(os.path.join(self.a.out, "exp"))
        self.a.writer = writer

        joint_train(
            model_ee=ee_model,
            model_sr=sr_model,
            model_g=model,
            train_set_g=train_grounding_set,
            dev_set_g=dev_grounding_set,
            test_set_g=test_grounding_set,
            train_set_ee=train_ee_set,
            dev_set_ee=dev_ee_set,
            test_set_ee=test_ee_set,
            train_set_sr=train_sr_set,
            dev_set_sr=dev_sr_set,
            test_set_sr=test_sr_set,
            optimizer_constructor=optimizer_constructor,
            epochs=self.a.epochs,
            ee_tester=ee_tester,
            sr_tester=sr_tester,
            g_tester=g_tester,
            j_tester=j_tester,
            parser=self.a,
            other_testsets={
                "dev ee 1/1": dev_ee_set1,
                "test ee 1/1": test_ee_set1,
            },
            transform=transform,
            vocab_objlabel=vocab_noun.word2id
        )
        log('Done!')