Пример #1
0
    def __predict(self, data, include_timing=False, eval_raw=False):
        """
        Uses the trained model to make predictions of individual batches (i.e. documents).

        ``data`` is iterated twice (once to collect document names, once to
        predict), so it must be re-iterable, e.g. a list. Every document is a
        list of mention dicts; the first mention of each document is read to
        obtain its ``doc_name``, so an empty document raises ``IndexError``.

        :param data: documents to predict on, each a list of mention dicts
            providing ``doc_name``, ``context``, ``snd_ctx``, ``snd_ment`` and
            ``selected_cands`` (plus ``raw`` when ``eval_raw`` is set).
        :param include_timing: when True, also return the per-document
            wall-clock time (seconds) spent in this step.
        :param eval_raw: when True, each prediction is a dict with the mention,
            predicted entity, candidate list and (for non-"NIL" predictions)
            the confidence and stringified scores; otherwise each prediction
            is ``{"pred": (entity, 0.0)}``.
        :return: predictions and time taken for the ED step. Predictions is a
            dict mapping document name to a list of per-mention predictions;
            the timing list is only returned when ``include_timing`` is True.
        """

        def to_long(rows):
            # Build an integer tensor on the configured device.
            return torch.LongTensor(rows).to(self.device)

        def to_float(rows):
            # Build a float tensor on the configured device.
            return torch.FloatTensor(rows).to(self.device)

        def pick_entity(mention, idx):
            # Use the predicted candidate if it is a real (unmasked) one,
            # otherwise fall back to the first candidate, otherwise "NIL".
            selected = mention["selected_cands"]
            if selected["mask"][idx] == 1:
                return selected["named_cands"][idx]
            if selected["mask"][0] == 1:
                return selected["named_cands"][0]
            return "NIL"

        predictions = {items[0]["doc_name"]: [] for items in data}
        self.model.eval()

        timing = []

        # Inference only: disabling autograd avoids building (and keeping
        # alive) a computation graph for every document.
        with torch.no_grad():
            for batch in data:  # each document is a minibatch

                start = time.time()

                word_unk = self.embeddings["word_voca"].unk_id
                snd_unk = self.embeddings["snd_voca"].unk_id

                # A mention without any context tokens gets a single unknown token.
                token_ids = [
                    m["context"][0] + m["context"][1]
                    if len(m["context"][0]) + len(m["context"][1]) > 0
                    else [word_unk]
                    for m in batch
                ]
                s_ltoken_ids = [m["snd_ctx"][0] for m in batch]
                s_rtoken_ids = [m["snd_ctx"][1] for m in batch]
                s_mtoken_ids = [m["snd_ment"] for m in batch]

                entity_ids = to_long(
                    [m["selected_cands"]["cands"] for m in batch])
                p_e_m = to_float(
                    [m["selected_cands"]["p_e_m"] for m in batch])
                entity_mask = to_float(
                    [m["selected_cands"]["mask"] for m in batch])
                true_pos = to_long(
                    [m["selected_cands"]["true_pos"] for m in batch])

                token_ids, token_mask = utils.make_equal_len(
                    token_ids, word_unk)
                s_ltoken_ids, s_ltoken_mask = utils.make_equal_len(
                    s_ltoken_ids, snd_unk, to_right=False)
                s_rtoken_ids, s_rtoken_mask = utils.make_equal_len(
                    s_rtoken_ids, snd_unk)
                # The right context is reversed after it has been padded.
                s_rtoken_ids = [row[::-1] for row in s_rtoken_ids]
                s_rtoken_mask = [row[::-1] for row in s_rtoken_mask]
                s_mtoken_ids, s_mtoken_mask = utils.make_equal_len(
                    s_mtoken_ids, snd_unk)

                token_ids = to_long(token_ids)
                token_mask = to_float(token_mask)

                # The secondary context is handed to the model through
                # attributes rather than through forward() arguments.
                self.model.s_ltoken_ids = to_long(s_ltoken_ids)
                self.model.s_ltoken_mask = to_float(s_ltoken_mask)
                self.model.s_rtoken_ids = to_long(s_rtoken_ids)
                self.model.s_rtoken_mask = to_float(s_rtoken_mask)
                self.model.s_mtoken_ids = to_long(s_mtoken_ids)
                self.model.s_mtoken_mask = to_float(s_mtoken_mask)

                scores, _ = self.model.forward(
                    token_ids,
                    token_mask,
                    entity_ids,
                    entity_mask,
                    p_e_m,
                    self.embeddings,
                    gold=true_pos.view(-1, 1),
                )
                # The confidence helper receives the tensor argmax together
                # with the numpy scores; the numpy argmax is used for indexing.
                pred_ids_tensor = torch.argmax(scores, dim=1)
                scores = scores.detach().cpu().numpy()

                confidence_scores = self.__compute_confidence(
                    scores, pred_ids_tensor)
                pred_ids = np.argmax(scores, axis=1)

                if not eval_raw:
                    for idx, m in zip(pred_ids, batch):
                        predictions[m["doc_name"]].append(
                            {"pred": (pick_entity(m, idx), 0.0)})
                else:
                    for idx, m, s, cs in zip(pred_ids, batch, scores,
                                             confidence_scores):
                        entity = pick_entity(m, idx)
                        record = {
                            "mention": m["raw"]["mention"],
                            "prediction": entity,
                            "candidates": m["selected_cands"]["named_cands"],
                        }
                        if entity != "NIL":
                            record["conf_ed"] = cs
                            record["scores"] = [str(x) for x in s]
                        else:
                            # "NIL" predictions carry no confidence or scores.
                            record["scores"] = []
                        predictions[m["doc_name"]].append(record)

                timing.append(time.time() - start)

        if include_timing:
            return predictions, timing
        return predictions
Пример #2
0
    def train(self, org_train_dataset, org_dev_datasets):
        """
        Fits the ED model on the training documents and evaluates it on the dev sets.

        Every training document is used as one minibatch. After every
        ``eval_after_n_epochs`` epochs each dev set is scored; the F1 obtained
        on "aida_testA" decides whether the model is saved, whether the
        learning rate is lowered from 1e-4 to 1e-5, and whether training stops
        early (after ``n_not_inc`` consecutive evaluations below the best F1).

        :param org_train_dataset: raw training data, converted with
            ``get_data_items``.
        :param org_dev_datasets: dict mapping a dev-set name to its raw data.
        :return: -
        """

        def as_long(rows):
            # Integer tensor on the configured device.
            return Variable(torch.LongTensor(rows).to(self.device))

        def as_float(rows):
            # Float tensor on the configured device.
            return Variable(torch.FloatTensor(rows).to(self.device))

        train_dataset = self.get_data_items(
            org_train_dataset, "train", predict=False)
        dev_datasets = [
            (dev_name, self.get_data_items(raw, dev_name, predict=True))
            for dev_name, raw in org_dev_datasets.items()
        ]

        print("Creating optimizer")
        trainable = [p for p in self.model.parameters() if p.requires_grad]
        optimizer = optim.Adam(trainable, lr=self.config["learning_rate"])

        best_f1 = -1
        not_better_count = 0
        eval_after_n_epochs = self.config["eval_after_n_epochs"]

        for epoch in range(self.config["n_epochs"]):
            shuffle(train_dataset)

            total_loss = 0
            # each document is a minibatch
            for doc_idx, mentions in enumerate(train_dataset):
                self.model.train()
                optimizer.zero_grad()

                # convert data items to pytorch inputs
                token_ids = []
                for m in mentions:
                    left, right = m["context"][0], m["context"][1]
                    if len(left) + len(right) > 0:
                        token_ids.append(left + right)
                    else:
                        # no context at all: use a single unknown token
                        token_ids.append(
                            [self.embeddings["word_voca"].unk_id])
                s_ltoken_ids = [m["snd_ctx"][0] for m in mentions]
                s_rtoken_ids = [m["snd_ctx"][1] for m in mentions]
                s_mtoken_ids = [m["snd_ment"] for m in mentions]

                entity_ids = as_long(
                    [m["selected_cands"]["cands"] for m in mentions])
                true_pos = as_long(
                    [m["selected_cands"]["true_pos"] for m in mentions])
                p_e_m = as_float(
                    [m["selected_cands"]["p_e_m"] for m in mentions])
                entity_mask = as_float(
                    [m["selected_cands"]["mask"] for m in mentions])

                token_ids, token_mask = utils.make_equal_len(
                    token_ids, self.embeddings["word_voca"].unk_id)
                s_ltoken_ids, s_ltoken_mask = utils.make_equal_len(
                    s_ltoken_ids,
                    self.embeddings["snd_voca"].unk_id,
                    to_right=False)
                s_rtoken_ids, s_rtoken_mask = utils.make_equal_len(
                    s_rtoken_ids, self.embeddings["snd_voca"].unk_id)
                # the right context is reversed after padding
                s_rtoken_ids = [row[::-1] for row in s_rtoken_ids]
                s_rtoken_mask = [row[::-1] for row in s_rtoken_mask]
                s_mtoken_ids, s_mtoken_mask = utils.make_equal_len(
                    s_mtoken_ids, self.embeddings["snd_voca"].unk_id)

                token_ids = as_long(token_ids)
                token_mask = as_float(token_mask)

                # the secondary context reaches the model through attributes
                # instead of forward() arguments
                self.model.s_ltoken_ids = as_long(s_ltoken_ids)
                self.model.s_ltoken_mask = as_float(s_ltoken_mask)
                self.model.s_rtoken_ids = as_long(s_rtoken_ids)
                self.model.s_rtoken_mask = as_float(s_rtoken_mask)
                self.model.s_mtoken_ids = as_long(s_mtoken_ids)
                self.model.s_mtoken_mask = as_float(s_mtoken_mask)

                scores, ent_scores = self.model.forward(
                    token_ids,
                    token_mask,
                    entity_ids,
                    entity_mask,
                    p_e_m,
                    self.embeddings,
                    gold=true_pos.view(-1, 1),
                )
                loss = self.model.loss(scores, true_pos)
                loss.backward()
                optimizer.step()
                self.model.regularize(max_norm=100)

                loss_value = loss.cpu().data.numpy()
                total_loss += loss_value
                progress = "%0.2f%%" % (doc_idx / len(train_dataset) * 100)
                print("epoch", epoch, progress, loss_value, end="\r")

            print("epoch", epoch, "total loss", total_loss,
                  total_loss / len(train_dataset))

            # Only evaluate on the scheduled epochs.
            if (epoch + 1) % eval_after_n_epochs != 0:
                continue

            dev_f1 = 0
            for dev_name, dev_items in dev_datasets:
                dev_predictions = self.__predict(dev_items)
                f1, recall, precision, _ = self.__eval(
                    org_dev_datasets[dev_name], dev_predictions)
                summary = "Micro F1: {}, Recall: {}, Precision: {}".format(
                    f1, recall, precision)
                print(dev_name, utils.tokgreen(summary))

                # Only this dev set drives model selection.
                if dev_name == "aida_testA":
                    dev_f1 = f1

            # Once the dev F1 is high enough, lower the learning rate and
            # evaluate every second epoch from then on.
            if (self.config["learning_rate"] == 1e-4
                    and dev_f1 >= self.config["dev_f1_change_lr"]):
                eval_after_n_epochs = 2
                best_f1 = dev_f1
                not_better_count = 0

                self.config["learning_rate"] = 1e-5
                print("change learning rate to",
                      self.config["learning_rate"])
                for group in optimizer.param_groups:
                    group["lr"] = self.config["learning_rate"]

            if dev_f1 < best_f1:
                not_better_count += 1
                print("Not improving", not_better_count)
            else:
                not_better_count = 0
                best_f1 = dev_f1
                print("save model to", self.config["model_path"])
                self.__save(self.config["model_path"])

            # Early stopping.
            if not_better_count == self.config["n_not_inc"]:
                break