def run(self):
        print("Running on", self.a.device)
        self.set_device(self.a.device)

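        # fix the random seeds so runs are reproducible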
        np.random.seed(self.a.seed)
        torch.manual_seed(self.a.seed)

        # create training set
        if self.a.train:
            log('loading corpus from %s' % self.a.train)

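        # torchtext-style fields describing how each raw column is processed;
        # MultiTokenField, SparseField, EventField and EntityField are
        # project-specific Field subclasses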
        WordsField = Field(lower=True, include_lengths=True, batch_first=True)
        PosTagsField = Field(lower=True, batch_first=True)
        EntityLabelsField = MultiTokenField(lower=False, batch_first=True)
        AdjMatrixField = SparseField(sequential=False,
                                     use_vocab=False,
                                     batch_first=True)
        LabelField = Field(lower=False,
                           batch_first=True,
                           pad_token=None,
                           unk_token=None)
        EventsField = EventField(lower=False, batch_first=True)
        EntitiesField = EntityField(lower=False,
                                    batch_first=True,
                                    use_vocab=False)

        # every split shares the same raw-column -> (attribute, Field) mapping
        fields = {
            "words": ("WORDS", WordsField),
            "pos-tags": ("POSTAGS", PosTagsField),
            "golden-entity-mentions": ("ENTITYLABELS", EntityLabelsField),
            "stanford-colcc": ("ADJM", AdjMatrixField),
            "golden-event-mentions": ("LABEL", LabelField),
            "all-events": ("EVENT", EventsField),
            "all-entities": ("ENTITIES", EntitiesField),
        }

        # keep_events=1 drops training sentences without any annotated event
        train_set = ACE2005Dataset(path=self.a.train,
                                   fields=fields,
                                   keep_events=1)

        dev_set = ACE2005Dataset(path=self.a.dev,
                                 fields=fields,
                                 keep_events=0)

        test_set = ACE2005Dataset(path=self.a.test,
                                  fields=fields,
                                  keep_events=0)

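        # optionally attach pretrained word vectors; out-of-vocabulary words
        # are initialized uniformly in [-0.15, 0.15]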
        if self.a.webd:
            pretrained_embedding = Vectors(self.a.webd,
                                           ".",
                                           unk_init=partial(
                                               torch.nn.init.uniform_,
                                               a=-0.15,
                                               b=0.15))
            WordsField.build_vocab(train_set.WORDS,
                                   dev_set.WORDS,
                                   vectors=pretrained_embedding)
        else:
            WordsField.build_vocab(train_set.WORDS, dev_set.WORDS)
        PosTagsField.build_vocab(train_set.POSTAGS, dev_set.POSTAGS)
        EntityLabelsField.build_vocab(train_set.ENTITYLABELS,
                                      dev_set.ENTITYLABELS)
        LabelField.build_vocab(train_set.LABEL, dev_set.LABEL)
        EventsField.build_vocab(train_set.EVENT, dev_set.EVENT)

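        # remember the "null" class indices: "O" for trigger labels and
        # "OTHER" for argument roles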
        consts.O_LABEL = LabelField.vocab.stoi["O"]
        consts.ROLE_O_LABEL = EventsField.vocab.stoi["OTHER"]

        # "1/1" evaluation subsets: keep_events=1 with only_keep=True selects
        # the single-event sentences
        dev_set1 = ACE2005Dataset(path=self.a.dev,
                                  fields=fields,
                                  keep_events=1,
                                  only_keep=True)

        test_set1 = ACE2005Dataset(path=self.a.test,
                                   fields=fields,
                                   keep_events=1,
                                   only_keep=True)

        print("dev set length", len(dev_set))
        print("dev set 1/1 length", len(dev_set1))

        print("test set length", len(test_set))
        print("test set 1/1 length", len(test_set1))

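        # weight every non-O trigger class 5x relative to the O class in the loss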
        self.a.label_weight = torch.ones([len(LabelField.vocab.itos)]) * 5
        self.a.label_weight[consts.O_LABEL] = 1.0

        print(self.a.hps)
        # hps may arrive from the CLI as a dict literal string; parse it
        # safely with ast.literal_eval (requires `import ast`)
        if not self.a.hps:
            self.a.hps = {}
        elif isinstance(self.a.hps, str):
            self.a.hps = ast.literal_eval(self.a.hps)

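        # fill in defaults for any hyper-parameter not supplied by the user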
        if "wemb_size" not in self.a.hps:
            self.a.hps["wemb_size"] = len(WordsField.vocab.itos)
        if "wemb_dim" not in self.a.hps:
            self.a.hps["wemb_dim"] = 300
        if "wemb_dp" not in self.a.hps:
            self.a.hps["wemb_dp"] = 0.5
        if "pemb_size" not in self.a.hps:
            self.a.hps["pemb_size"] = len(PosTagsField.vocab.itos)
        if "psemb_size" not in self.a.hps:
            self.a.hps["psemb_size"] = max(
                [train_set.longest(),
                 dev_set.longest(),
                 test_set.longest()]) + 2
        if "psemb_dim" not in self.a.hps:
            self.a.hps["psemb_dim"] = 50
        if "eemb_size" not in self.a.hps:
            self.a.hps["eemb_size"] = len(EntityLabelsField.vocab.itos)
        if "oc" not in self.a.hps:
            self.a.hps["oc"] = len(LabelField.vocab.itos)
        if "ae_oc" not in self.a.hps:
            self.a.hps["ae_oc"] = len(EventsField.vocab.itos)

        self.a.hps["pemb_dim"] = 50
        self.a.hps["psemb_dim"] = 50
        self.a.hps["eemb_dim"] = 50
        self.a.hps["lstm_dim"] = 50
        self.a.hps["psemb_dp"] = 0.5
        self.a.hps["pemb_dp"] = 0.5
        self.a.hps["eemb_dp"] = 0.5
        self.a.hps["lstm_dp"] = 0.5
        self.a.hps["wemb_ft"] = False
        self.a.hps["lstm_layers"] = 1
        self.a.hps["gcn_dp"] = 0.5
        self.a.hps["gcn_use_bn"] = False
        self.a.hps["gcn_layers"] = 3
        self.a.hps["gcn_et"] = 3
        self.a.hps["use_highway"] = True
        self.a.hps["sa_dim"] = 300
        self.a.hps["loss_alpha"] = 5

        tester = self.get_tester(LabelField.vocab.itos)

        if self.a.finetune:
            log('init model from ' + self.a.finetune)
            model = self.load_model(self.a.finetune)
            log('model loaded, there are %i sets of params' %
                len(model.parameters_requires_grads()))
        else:
            model = self.load_model(None)
            log('model created from scratch, there are %i sets of params' %
                len(model.parameters_requires_grads()))

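        # the optimizer is built lazily inside train(); SGD with momentum 0.9
        # serves as the fallback choice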
        if self.a.optimizer == "adadelta":
            optimizer_constructor = partial(
                torch.optim.Adadelta,
                params=model.parameters_requires_grads(),
                weight_decay=self.a.l2decay)
        elif self.a.optimizer == "adam":
            optimizer_constructor = partial(
                torch.optim.Adam,
                params=model.parameters_requires_grads(),
                weight_decay=self.a.l2decay)
        else:
            optimizer_constructor = partial(
                torch.optim.SGD,
                params=model.parameters_requires_grads(),
                weight_decay=self.a.l2decay,
                momentum=0.9)

        log('optimizer in use: %s' % str(self.a.optimizer))

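        # persist the vocabularies so they can be reloaded together with a
        # saved model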
        os.makedirs(self.a.out, exist_ok=True)
        with open(os.path.join(self.a.out, "word.vec"), "wb") as f:
            pickle.dump(WordsField.vocab, f)
        with open(os.path.join(self.a.out, "pos.vec"), "wb") as f:
            pickle.dump(PosTagsField.vocab.stoi, f)
        with open(os.path.join(self.a.out, "entity.vec"), "wb") as f:
            pickle.dump(EntityLabelsField.vocab.stoi, f)
        with open(os.path.join(self.a.out, "label.vec"), "wb") as f:
            pickle.dump(LabelField.vocab.stoi, f)
        with open(os.path.join(self.a.out, "role.vec"), "wb") as f:
            pickle.dump(EventsField.vocab.stoi, f)

        log('init complete\n')

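        # index -> string mappings used to decode predicted label ids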
        self.a.word_i2s = WordsField.vocab.itos
        self.a.label_i2s = LabelField.vocab.itos
        self.a.role_i2s = EventsField.vocab.itos
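        # TensorBoard writer; event files are written under <out>/exp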
        writer = SummaryWriter(os.path.join(self.a.out, "exp"))
        self.a.writer = writer

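        # run the full training loop; the single-event "1/1" subsets are
        # evaluated as additional test sets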
        train(model=model,
              train_set=train_set,
              dev_set=dev_set,
              test_set=test_set,
              optimizer_constructor=optimizer_constructor,
              epochs=self.a.epochs,
              tester=tester,
              parser=self.a,
              other_testsets={
                  "dev 1/1": dev_set1,
                  "test 1/1": test_set1,
              })
        log('Done!')
Example 2
    def __new__(cls, *args, **kwargs):
        log('created %s with params %s %s' % (str(cls), str(args), str(kwargs)))

        # returning an instance of cls is enough: Python then calls __init__
        # automatically, so invoking it here as well would run it twice
        return super(Model, cls).__new__(cls)
Example 3
    # The top of this snippet was truncated in the source. The constants and
    # COO indices below are assumed stand-ins, sized so that
    # BATCH_SIZE * SEQ_LEN matches the 8 labels below.
    BATCH_SIZE, SEQ_LEN, D, ET, CLASSN = 2, 4, 16, 3, 2
    device = torch.device("cpu")

    # dense adjacency of shape (BATCH_SIZE, ET, SEQ_LEN, SEQ_LEN) built from a
    # sparse COO tensor with 24 nonzero entries of value 1, as in the original
    nnz = 24
    indices = torch.stack([torch.randint(0, n, (nnz,))
                           for n in (BATCH_SIZE, ET, SEQ_LEN, SEQ_LEN)])
    adj = torch.sparse.FloatTensor(
        indices, torch.ones(nnz),
        torch.Size([BATCH_SIZE, ET, SEQ_LEN, SEQ_LEN])).to_dense().to(device)
    inputs = torch.randn(BATCH_SIZE, SEQ_LEN, D).to(device)
    label = torch.LongTensor([0, 1, 0, 1, 0, 1, 0, 1]).to(device)

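    # edge-type-aware graph convolution followed by a per-token linear
    # classifier over CLASSN classes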
    cc = GraphConvolution(in_features=D,
                          out_features=D,
                          edge_types=ET,
                          device=device,
                          use_bn=True)
    oo = BottledOrthogonalLinear(in_features=D, out_features=CLASSN).to(device)

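    # jointly optimize both modules with Adadelta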
    optimizer = Adadelta(list(cc.parameters()) + list(oo.parameters()))

    aloss = 1e9
    df = 1e9
    # train until the absolute change in loss drops below 1e-7
    while df > 1e-7:
        optimizer.zero_grad()  # reset gradients left over from the last step
        output = oo(cc(inputs, adj)).view(BATCH_SIZE * SEQ_LEN, CLASSN)
        loss = F.cross_entropy(output, label)
        df = abs(aloss - loss.item())
        aloss = loss.item()
        loss.backward()
        optimizer.step()
        log(aloss)

    # dim=1 is the class dimension of the (BATCH_SIZE * SEQ_LEN, CLASSN) output
    log(F.softmax(output, dim=1))