Example #1
    def forward(self, x, adj, length=None):
        batch_size, node_num, feature_dim = x.shape
        h = to_gpu(Variable(torch.from_numpy(x), requires_grad=False)).float()

        length_mask = None
        class_mask = None
        if length is not None:
            lengths_var = to_gpu(
                Variable(torch.from_numpy(length),
                         requires_grad=False)).long()
            # batch_size * node_num
            length_mask = sequence_mask(lengths_var, node_num)

            class_mask = length_mask.unsqueeze(2).expand(
                batch_size, node_num, 2)
            class_mask = class_mask.float()
        # adj: batch * node_num * node_num

        adj = to_gpu(Variable(torch.from_numpy(adj),
                              requires_grad=False)).float()
        h = self.gc_layer(h, adj, mask=length_mask)
        # h: batch * node_num * hidden
        if self.class_ln:
            h = self.ln_inp(h)
        h = F.dropout(h, self.drop_out_rate, training=self.training)
        output = self.classifer(h)
        # batch_size * node_num * self._num_class
        output = masked_softmax(output, mask=class_mask)
        return output
    def getNeighborMask(self, num_mentions, dim):
        batch_size, cand_num = num_mentions.shape
        # batch * cand_num
        margin_col = to_gpu(torch.zeros(1, cand_num))

        right_mask = to_gpu(torch.from_numpy(num_mentions)).float()
        left_mask = torch.cat([margin_col, right_mask[:-1, :]], dim=0)

        # (batch * cand_num) * dim
        right_mask_expand = right_mask.view(-1).unsqueeze(1).expand(
            batch_size * cand_num, dim)
        left_mask_expand = left_mask.view(-1).unsqueeze(1).expand(
            batch_size * cand_num, dim)
        return left_mask_expand, right_mask_expand
    def getCandidateEmbedding(self, candidates, candidates_sense=None):
        candidates = to_gpu(
            Variable(torch.from_numpy(candidates),
                     volatile=not self.training)).long()
        cand_entity_emb = self.run_embed(candidates, 1)
        cand_sense_emb = None
        cand_mu_emb = None
        if candidates_sense is not None and self._has_sense:
            candidates_sense = to_gpu(
                Variable(torch.from_numpy(candidates_sense),
                         volatile=not self.training)).long()
            cand_sense_emb = self.run_embed(candidates_sense, 2)
            cand_mu_emb = self.run_embed(candidates_sense, 3)
        return cand_entity_emb, cand_sense_emb, cand_mu_emb
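Example #1 leans on a few small helpers that are not shown here (to_gpu, sequence_mask, masked_softmax). Below are minimal sketches of plausible implementations, written against a recent PyTorch API for readability; the original code targets the legacy Variable/volatile API, and the project's real helpers may well differ.

import torch
import torch.nn.functional as F

def to_gpu(x):
    # Move a tensor (or module) to the GPU when CUDA is available; otherwise a no-op.
    return x.cuda() if torch.cuda.is_available() else x

def sequence_mask(lengths, max_len):
    # lengths: (batch,) long tensor of valid lengths.
    # Returns a (batch, max_len) float mask: 1 where position < length, else 0.
    positions = torch.arange(max_len, device=lengths.device).unsqueeze(0)  # 1 * max_len
    return (positions < lengths.unsqueeze(1)).float()

def masked_softmax(logits, mask=None):
    # Softmax over the last dimension; padded positions (mask == 0) get ~zero probability.
    if mask is not None:
        logits = logits + (mask.float() + 1e-45).log()
    return F.softmax(logits, dim=-1)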
Example #4
    def getCandidateSimilarity(self,
                               embeddings,
                               candidate_embeddings,
                               default_sims=None):
        cand_entity_emb, cand_sense_emb, cand_mu_emb = candidate_embeddings
        entity_emb, sense_emb, mu_emb = embeddings

        if default_sims is None:
            batch_size, _ = entity_emb.size()
            default_sims = to_gpu(
                Variable(torch.FloatTensor([DEFAULT_SIM] *
                                           batch_size).unsqueeze(1),
                         requires_grad=False))

        cand_entity_emb_expand = cand_entity_emb.unsqueeze(1)
        sim1 = torch.bmm(cand_entity_emb_expand,
                         entity_emb.unsqueeze(2)).squeeze(2)

        sim2 = DEFAULT_SIM
        sim3 = DEFAULT_SIM
        if cand_sense_emb is not None and cand_mu_emb is not None:
            cand_sense_emb_expand = cand_sense_emb.unsqueeze(1)
            cand_mu_emb_expand = cand_mu_emb.unsqueeze(1)
            sim2 = torch.bmm(cand_sense_emb_expand,
                             sense_emb.unsqueeze(2)).squeeze(2)
            sim3 = torch.bmm(cand_mu_emb_expand,
                             mu_emb.unsqueeze(2)).squeeze(2)

        return sim1, sim2, sim3
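The unsqueeze/bmm/squeeze pattern above is just a batched dot product between each candidate embedding and the matching context embedding. A toy check of the equivalence (shapes are made up):

import torch

a = torch.randn(4, 8)   # e.g. cand_entity_emb
b = torch.randn(4, 8)   # e.g. entity_emb
sim_bmm = torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).squeeze(2)  # (4, 1)
sim_dot = (a * b).sum(dim=1, keepdim=True)                      # same values
assert torch.allclose(sim_bmm, sim_dot)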
Example #5
    def getNeighEmb(self, mstr_emb, cand_num, neighbor_window, left_mask,
                    right_mask):
        margin_col = to_gpu(
            Variable(torch.zeros(cand_num, self._dim), requires_grad=False))
        # left_neighs: (batch_size*cand_num) * window * dim
        tmp_left_neigh_list = []
        tmp_left_neigh_list.append(
            self.leftMvNeigh(mstr_emb, cand_num, margin_col, left_mask))
        for i in range(neighbor_window - 1):
            tmp_left_neigh_list.append(
                self.leftMvNeigh(tmp_left_neigh_list[i], cand_num, margin_col,
                                 left_mask))
        for i, neigh in enumerate(tmp_left_neigh_list):
            tmp_left_neigh_list[i] = tmp_left_neigh_list[i].unsqueeze(1)
        left_neighs = torch.cat(tmp_left_neigh_list, dim=1)

        tmp_right_neigh_list = []
        tmp_right_neigh_list.append(
            self.rightMvNeigh(mstr_emb, cand_num, margin_col, right_mask))
        for i in range(neighbor_window - 1):
            tmp_right_neigh_list.append(
                self.rightMvNeigh(tmp_right_neigh_list[i], cand_num,
                                  margin_col, right_mask))
        for i, neigh in enumerate(tmp_right_neigh_list):
            tmp_right_neigh_list[i] = tmp_right_neigh_list[i].unsqueeze(1)
        right_neighs = torch.cat(tmp_right_neigh_list, dim=1)
        # neigh_emb: (batch_size*cand_num) * 2window * dim
        neigh_emb = torch.cat((left_neighs, right_neighs), dim=1)
        # neigh_emb: (batch_size*cand_num) * dim
        neigh_emb = torch.mean(neigh_emb, dim=1)
        return neigh_emb
    def getNeighCandidates(self, emb, window, num_mentions):
        batch_size, cand_num = num_mentions.shape
        _, dim = emb.size()
        left_mask, right_mask = self.getNeighborMask(num_mentions, dim)
        margin_col = to_gpu(torch.zeros(cand_num, dim))
        left_list = []
        # (batch * cand) * dim
        left_list.append(
            self.leftNeighbor(emb, cand_num, margin_col, left_mask))
        for i in range(window - 1):
            left_list.append(
                self.leftNeighbor(left_list[i], cand_num, margin_col,
                                  left_mask))
        for i in range(window):
            left_list[i] = self.getExpandNeighCandidates(
                left_list[i], batch_size, cand_num, dim)
        # (batch * cand) * (window*cand) * dim
        left_cands = torch.cat(left_list, dim=1)

        right_list = []
        right_list.append(
            self.rightNeighbor(emb, cand_num, margin_col, right_mask))
        for i in range(window - 1):
            right_list.append(
                self.rightNeighbor(right_list[i], cand_num, margin_col,
                                   right_mask))
        for i in range(window):
            right_list[i] = self.getExpandNeighCandidates(
                right_list[i], batch_size, cand_num, dim)
        # (batch * cand) * (window*cand) * dim
        right_cands = torch.cat(right_list, dim=1)

        # (batch * cand) * (cand_num*window*2) * dim
        neigh_cands = torch.cat((left_cands, right_cands), dim=1)
        return neigh_cands
    def getTokenEmbedding(self, tokens, candidate_embeddings=None):
        tokens = to_gpu(
            Variable(torch.from_numpy(tokens),
                     volatile=not self.training)).long()

        if candidate_embeddings is not None:
            cand_entity_emb, cand_sense_emb, cand_mu_emb = candidate_embeddings
            entity_emb = self.getEmbFeatures(tokens, q_emb=cand_entity_emb)
        else:
            entity_emb = self.getEmbFeatures(tokens)
        sense_emb = entity_emb
        mu_emb = entity_emb

        return entity_emb, sense_emb, mu_emb
    def getGraphSample(self,
                       e,
                       num_mentions,
                       entity_vocab,
                       id2wiki_vocab,
                       only_one=False):
        ent_label_vocab = dict([(entity_vocab[id], id2wiki_vocab[id])
                                for id in entity_vocab if id in id2wiki_vocab])
        ent_label_vocab[0] = 'PAD'
        ent_label_vocab[1] = 'UNK'

        batch_size, cand_num = e.shape
        # graph, (batch * cand) * (cand_num*window*2+1)
        adj = self._adj.data
        # neighbors, (batch * cand) * (cand_num*window*2)
        e_var = to_gpu(
            Variable(torch.from_numpy(e).view(-1).unsqueeze(1),
                     requires_grad=False).float())
        neighbors = self.getNeighCandidates(e_var, self._neighbor_cand_window,
                                            num_mentions).data.squeeze()
        c_idx = -1
        docs = []
        doc_edges = []
        is_doc_end = False
        for i in range(batch_size):
            for j in range(cand_num):
                c_idx += 1
                if e[i][j] in [0, 1]: continue
                label = ent_label_vocab[e[i][j]]
                # edges
                edges = adj[c_idx]
                nodes = neighbors[c_idx]
                tmp_len = len(edges) - 1
                for k in range(tmp_len):
                    if edges[k] > 0 and nodes[k] not in [0, 1]:
                        n_label = ent_label_vocab[nodes[k]]
                        doc_edges.append([label, n_label, edges[k]])
                # doc
                if num_mentions[i][j] == 0:
                    is_doc_end = True
            if is_doc_end:
                is_doc_end = False
                doc_line = "Graph: \n" + "\n".join([
                    "{}<-{}->{}".format(edge[0], edge[2], edge[1])
                    for edge in doc_edges
                ]) + '\n'
                docs.append(doc_line)
                if only_one: return docs
                del doc_edges[:]
        return docs
    def getNeighborMentionEmbeddings(self, ment_emb, neighbor_window,
                                     num_mentions):
        batch_size, cand_num = num_mentions.shape
        _, dim = ment_emb.size()
        left_mask, right_mask = self.getNeighborMask(num_mentions, dim)
        margin_col = to_gpu(torch.zeros(cand_num, dim))

        neibor_ment_entity_emb = self.getNeighborMentionEmbeddingsForCandidate(
            ment_emb, margin_col, cand_num, neighbor_window, left_mask,
            right_mask)
        neibor_ment_entity_emb = Variable(neibor_ment_entity_emb,
                                          requires_grad=False)
        neibor_ment_sense_emb = neibor_ment_entity_emb
        neibor_ment_mu_emb = neibor_ment_entity_emb
        return neibor_ment_entity_emb, neibor_ment_sense_emb, neibor_ment_mu_emb
Example #10
    def forward(self,
                contexts1,
                base_feature,
                candidates,
                mention_tokens,
                contexts2=None,
                candidates_sense=None,
                num_mentions=None,
                length=None):
        batch_size, cand_num, _ = base_feature.shape
        features = []
        # to gpu
        base_feature = to_gpu(
            Variable(torch.from_numpy(base_feature[:, :, -1]),
                     requires_grad=False)).float()

        return base_feature.squeeze()
    def buildGraph(self, cand_emb, window, num_mentions, thred=0.0):
        batch_size, cand_num = num_mentions.shape
        # (batch * cand) * (cand_num*window*2) * dim
        neigh_cands = self.getNeighCandidates(cand_emb, window, num_mentions)

        # (batch * cand) * (cand_num*window*2) * dim
        cand_emb_expand = cand_emb.unsqueeze(1).expand(batch_size * cand_num,
                                                       2 * window * cand_num,
                                                       self._dim)
        # (batch * cand) * (cand_num*window*2)
        adj = torch.clamp(
            F.cosine_similarity(cand_emb_expand, neigh_cands, dim=2), thred, 1)
        if thred > 0.0:
            adj[adj <= thred] = 0.0
        # add self connection
        margin_col = to_gpu(torch.ones(batch_size * cand_num, 1))
        # size: (batch * cand) * (cand_num*window*2+1)
        adj = torch.cat((adj * self._rho, margin_col), dim=1)
        # normalize
        adj = Variable(F.normalize(adj, p=1, dim=1), requires_grad=False)
        return adj
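buildGraph turns each candidate's clamped cosine similarities to its neighbouring candidates into row-stochastic edge weights: scale by self._rho, append a self-loop of weight 1, then L1-normalize the row, so each propagation step becomes a weighted average over the local graph. A toy run of just the normalization step (the numbers are made up):

import torch
import torch.nn.functional as F

rho = 0.5
sims = torch.tensor([[0.8, 0.0, 0.4]])            # clamped similarities to three neighbours
self_loop = torch.ones(sims.size(0), 1)           # self connection
adj = torch.cat((sims * rho, self_loop), dim=1)   # [[0.4, 0.0, 0.2, 1.0]]
adj = F.normalize(adj, p=1, dim=1)                # rows sum to 1 -> [[0.25, 0.0, 0.125, 0.625]]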
Example #12
    def getNeighborMentionEmbeddings(self, mention_embeddings, neighbor_window,
                                     num_mentions):
        batch_size, cand_num = num_mentions.shape
        entity_emb, sense_emb, mu_emb = mention_embeddings
        _, dim = entity_emb.size()
        left_mask, right_mask = self.getNeighborMask(num_mentions, dim)
        margin_col = to_gpu(
            Variable(torch.zeros(cand_num, dim), requires_grad=False))

        neibor_ment_entity_emb = self.getNeighborMentionEmbeddingsForCandidate(
            entity_emb, margin_col, cand_num, neighbor_window, left_mask,
            right_mask)
        neibor_ment_sense_emb = None
        neibor_ment_mu_emb = None
        if sense_emb is not None:
            neibor_ment_sense_emb = self.getNeighborMentionEmbeddingsForCandidate(
                sense_emb, margin_col, cand_num, neighbor_window, left_mask,
                right_mask)
        if mu_emb is not None:
            neibor_ment_mu_emb = self.getNeighborMentionEmbeddingsForCandidate(
                mu_emb, margin_col, cand_num, neighbor_window, left_mask,
                right_mask)
        return neibor_ment_entity_emb, neibor_ment_sense_emb, neibor_ment_mu_emb
Example #13
    def forward(self,
                contexts1,
                base_feature,
                candidates,
                m_strs,
                contexts2=None,
                candidates_sense=None,
                num_mentions=None,
                length=None):
        batch_size, cand_num, _ = base_feature.shape
        # to gpu
        base_feature = to_gpu(Variable(torch.from_numpy(base_feature))).float()
        contexts1 = to_gpu(Variable(torch.from_numpy(contexts1))).long()
        candidates = to_gpu(Variable(torch.from_numpy(candidates))).long()
        m_strs = to_gpu(Variable(torch.from_numpy(m_strs))).long()

        # candidate mask (this model assumes `length` is provided)
        length_mask = None
        if length is not None:
            lengths_var = to_gpu(
                Variable(torch.from_numpy(length),
                         requires_grad=False)).long()
            # batch_size * cand_num
            length_mask = sequence_mask(lengths_var, cand_num).float()
        # mention context mask
        has_neighbors = False
        if self._neighbor_window > 0 and num_mentions is not None:
            # batch * cand
            margin_col = to_gpu(
                Variable(torch.zeros(1, cand_num), requires_grad=False))
            right_neigh_mask = to_gpu(
                Variable(torch.from_numpy(num_mentions),
                         requires_grad=False)).float()
            left_neigh_mask = torch.cat([margin_col, right_neigh_mask[:-1, :]],
                                        dim=0)
            right_neigh_mask_expand = right_neigh_mask.view(-1).unsqueeze(
                1).expand(batch_size * cand_num, self._dim)
            left_neigh_mask_expand = left_neigh_mask.view(-1).unsqueeze(
                1).expand(batch_size * cand_num, self._dim)
            has_neighbors = True

        has_context2 = False
        if contexts2 is not None and self._use_contexts2:
            contexts2 = to_gpu(Variable(torch.from_numpy(contexts2))).long()
            has_context2 = True

        has_sense = False
        if candidates_sense is not None and self._has_sense:
            candidates_sense = to_gpu(
                Variable(torch.from_numpy(candidates_sense))).long()
            has_sense = True

        # get emb, (batch * cand) * dim
        cand_entity_emb = self.run_embed(candidates, 1)
        f1_entity_emb = self.getEmbFeatures(contexts1, q_emb=cand_entity_emb)

        if has_sense:
            cand_sense_emb = self.run_embed(candidates_sense, 2)
            cand_mu_emb = self.run_embed(candidates_sense, 3)
            f1_sense_emb = self.getEmbFeatures(contexts1, q_emb=cand_sense_emb)
            f1_mu_emb = self.getEmbFeatures(contexts1, q_emb=cand_mu_emb)

        if has_context2:
            f2_entity_emb = self.getEmbFeatures(contexts2,
                                                q_emb=cand_entity_emb)
            if has_sense:
                f2_sense_emb = self.getEmbFeatures(contexts2,
                                                   q_emb=cand_sense_emb)
                f2_mu_emb = self.getEmbFeatures(contexts2, q_emb=cand_mu_emb)

        # get contextual similarity, (batch * cand) * contextual_sim
        cand_entity_emb_expand = cand_entity_emb.unsqueeze(1)
        if has_sense:
            cand_sense_emb_expand = cand_sense_emb.unsqueeze(1)
            cand_mu_emb_expand = cand_mu_emb.unsqueeze(1)

        # get mention string similarity
        ms_entity_emb = self.getEmbFeatures(m_strs, q_emb=cand_entity_emb)
        if has_sense:
            ms_sense_emb = self.getEmbFeatures(m_strs, q_emb=cand_sense_emb)
            ms_mu_emb = self.getEmbFeatures(m_strs, q_emb=cand_mu_emb)

        m_sim1 = torch.bmm(cand_entity_emb_expand,
                           ms_entity_emb.unsqueeze(2)).squeeze(2)
        if has_sense:
            m_sim2 = torch.bmm(cand_sense_emb_expand,
                               ms_sense_emb.unsqueeze(2)).squeeze(2)
            m_sim3 = torch.bmm(cand_mu_emb_expand,
                               ms_mu_emb.unsqueeze(2)).squeeze(2)

        if has_neighbors:
            # (batch * cand_num) * dim
            neigh_entity_emb = self.getNeighEmb(ms_entity_emb, cand_num,
                                                self._neighbor_window,
                                                left_neigh_mask_expand,
                                                right_neigh_mask_expand)
            n_sim1 = torch.bmm(cand_entity_emb_expand,
                               neigh_entity_emb.unsqueeze(2)).squeeze(2)
            if has_sense:
                neigh_sense_emb = self.getNeighEmb(ms_sense_emb, cand_num,
                                                   self._neighbor_window,
                                                   left_neigh_mask_expand,
                                                   right_neigh_mask_expand)
                n_sim2 = torch.bmm(cand_sense_emb_expand,
                                   neigh_sense_emb.unsqueeze(2)).squeeze(2)
                neigh_mu_emb = self.getNeighEmb(ms_mu_emb, cand_num,
                                                self._neighbor_window,
                                                left_neigh_mask_expand,
                                                right_neigh_mask_expand)
                n_sim3 = torch.bmm(cand_mu_emb_expand,
                                   neigh_mu_emb.unsqueeze(2)).squeeze(2)

        # entity: context1
        sim1 = torch.bmm(cand_entity_emb_expand,
                         f1_entity_emb.unsqueeze(2)).squeeze(2)
        if has_sense:
            # sense : context1
            sim2 = torch.bmm(cand_sense_emb_expand,
                             f1_sense_emb.unsqueeze(2)).squeeze(2)
            # mu : context1
            sim3 = torch.bmm(cand_mu_emb_expand,
                             f1_mu_emb.unsqueeze(2)).squeeze(2)

        # entity: context2
        if has_context2:
            sim4 = torch.bmm(cand_entity_emb_expand,
                             f2_entity_emb.unsqueeze(2)).squeeze(2)
            if has_sense:
                # sense : context2
                sim5 = torch.bmm(cand_sense_emb_expand,
                                 f2_sense_emb.unsqueeze(2)).squeeze(2)
                # mu : context2
                sim6 = torch.bmm(cand_mu_emb_expand,
                                 f2_mu_emb.unsqueeze(2)).squeeze(2)

        # feature vec : batch * cand * feature_dim
        # feature dim: base_dim + 2*dim + 2 + 1(if has entity) +
        # (2+word_dim)(if has contexts) + 1(if has context2 and has entity)
        base_feature = base_feature.view(batch_size * cand_num, -1)
        h = torch.cat(
            (base_feature, cand_entity_emb, f1_entity_emb, sim1, m_sim1),
            dim=1)
        if has_sense:
            h = torch.cat((h, sim2, sim3, m_sim2, m_sim3), dim=1)

        if has_context2:
            h = torch.cat((h, sim4, f2_entity_emb), dim=1)
            if has_sense:
                h = torch.cat((h, sim5, sim6), dim=1)
        if has_neighbors:
            h = torch.cat((h, n_sim1), dim=1)
            if has_sense:
                h = torch.cat((h, n_sim2, n_sim3), dim=1)

        h = self.mlp_classifier(h, length=length_mask.view(-1))
        # reshape, batch_size * cand_num
        h = h.view(batch_size, -1)

        output = masked_softmax(h, mask=length_mask)
        return output
Example #14
    def forward(self,
                contexts1,
                base_feature,
                candidates,
                mention_tokens,
                contexts2=None,
                candidates_sense=None,
                num_mentions=None,
                length=None):
        batch_size, cand_num, _ = base_feature.shape
        features = []
        # to gpu
        base_feature = to_gpu(
            Variable(torch.from_numpy(base_feature),
                     requires_grad=False)).float()
        base_feature = base_feature.view(batch_size * cand_num, -1)
        features.append(base_feature)

        # candidate mask
        length_mask = None
        if length is not None:
            lengths_var = to_gpu(
                Variable(torch.from_numpy(length),
                         requires_grad=False)).long()
            # batch_size * cand_num
            length_mask = sequence_mask(lengths_var, cand_num).float()

        # get emb, (batch * cand) * dim
        candidate_embeddings = self.getCandidateEmbedding(
            candidates, candidates_sense)
        cand_emb1, cand_emb2, cand_emb3 = candidate_embeddings

        # get context emb
        context1_emb = self.getTokenEmbedding(
            contexts1,
            candidate_embeddings=candidate_embeddings
            if self._use_att else None)

        # get contextual similarity, (batch * cand) * contextual_sim
        con1_sims = self.getCandidateSimilarity(context1_emb,
                                                candidate_embeddings)
        features.extend(con1_sims)

        con2_sims = DEFAULT_SIM, DEFAULT_SIM, DEFAULT_SIM
        con2_emb_cand1 = None
        if self._use_contexts2 and contexts2 is not None:
            context2_emb = self.getTokenEmbedding(
                contexts2,
                candidate_embeddings=candidate_embeddings
                if self._use_att else None)
            con2_emb_cand1, con2_emb_cand2, con2_emb_cand3 = context2_emb
            # get contextual similarity, (batch * cand) * contextual_sim
            con2_sims = self.getCandidateSimilarity(context2_emb,
                                                    candidate_embeddings)
        features.extend(con2_sims)

        # get mention string similarity (TODO: no attention)
        ment_embs = self.getTokenEmbedding(
            mention_tokens, candidate_embeddings=candidate_embeddings)
        mention_sims = self.getCandidateSimilarity(ment_embs,
                                                   candidate_embeddings)
        features.extend(mention_sims)

        # neighbor mention string similarity
        neigh_ment_sims = DEFAULT_SIM, DEFAULT_SIM, DEFAULT_SIM
        if self._neighbor_window > 0 and num_mentions is not None:
            # (batch * cand_num) * dim
            neigh_ment_embs = self.getNeighborMentionEmbeddings(
                ment_embs, self._neighbor_window, num_mentions)
            neigh_ment_sims = self.getCandidateSimilarity(
                neigh_ment_embs, candidate_embeddings)
        features.extend(neigh_ment_sims)

        # neighbor candidates
        # (batch * cand) * 1 * (cand_num*window*2+1)
        self._adj = self.buildGraph(cand_emb1,
                                    self._neighbor_window,
                                    num_mentions,
                                    thred=self._thred).unsqueeze(1)
        # feature vec : (batch * cand) * feature_dim
        h = torch.cat(features, dim=1)
        if self._use_embedding_feature:
            con1_emb_cand1, con1_emb_cand2, con1_emb_cand3 = context1_emb
            h = torch.cat((h, cand_emb1, con1_emb_cand1), dim=1)
            if con2_emb_cand1 is not None:
                h = torch.cat((h, con2_emb_cand1), dim=1)
        if self._gc_ln:
            h = self.ln_inp(h)
        for i in range(self._num_layers):
            w = getattr(self, 'w{}'.format(i))
            b = getattr(self, 'b{}'.format(i))
            dim = getattr(self, 'd{}'.format(i))
            h = h.matmul(w)
            # (batch_size * cand_num) * (2*window*cand_num+1) * f_dim
            h = self.getExpandFeature(h, self._neighbor_window, num_mentions)
            h = torch.bmm(self._adj, h).squeeze(1)
            if b is not None: h = h + b
            # h: (batch_size * cand_num) * feature_dim
            if length_mask is not None:
                mask = length_mask.view(-1).unsqueeze(1).expand(
                    batch_size * cand_num, dim)
                h = h * mask
            h = F.relu(h)
        h = F.dropout(h, self._dropout_rate, training=self.training)

        h = h.matmul(self.gc_classifier_w)
        # (batch_size * cand_num) * (2*window*cand_num+1) * f_dim
        h = self.getExpandFeature(h, self._neighbor_window, num_mentions)
        h = torch.bmm(self._adj, h).squeeze(1)
        if self.gc_classifier_b is not None: h = h + self.gc_classifier_b

        # reshape, batch_size * cand_num
        h = h.squeeze().view(batch_size, -1)
        output = masked_softmax(h, mask=length_mask)
        return output
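The layer loop in the forward above is a graph-convolution step: project the per-candidate features with w{i}, expand them into each candidate's local neighbourhood (getExpandFeature), aggregate with the row-normalized adjacency from buildGraph, add the bias, mask padded candidates, and apply ReLU. A stripped-down sketch of one such step, with invented shapes and a random neighbour index standing in for getExpandFeature:

import torch
import torch.nn.functional as F

nodes, neighbors, in_dim, out_dim = 6, 5, 16, 8                  # hypothetical sizes
h = torch.randn(nodes, in_dim)                                   # one row per candidate
W = torch.randn(in_dim, out_dim)
b = torch.zeros(out_dim)
adj = F.normalize(torch.rand(nodes, 1, neighbors), p=1, dim=2)   # row-normalized local graphs

proj = h.matmul(W)                                   # project: (nodes, out_dim)
idx = torch.randint(0, nodes, (nodes, neighbors))    # fake neighbour indices
neigh = proj[idx]                                    # gather: (nodes, neighbors, out_dim)
out = torch.bmm(adj, neigh).squeeze(1) + b           # weighted average over neighbours
out = F.relu(out)                                    # (nodes, out_dim)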
Example #15
def train_loop(FLAGS, model, trainer, training_data_iter, eval_iterators,
               logger, vocabulary):
    # Accumulate useful statistics.
    A = Accumulator(maxlen=FLAGS.deque_length)

    # Train.
    logger.Log("Training.")

    # New Training Loop
    progress_bar = SimpleProgressBar(msg="Training",
                                     bar_length=60,
                                     enabled=FLAGS.show_progress_bar)
    progress_bar.step(i=0, total=FLAGS.statistics_interval_steps)

    log_entry = pb.NcelEntry()
    for _ in range(trainer.step, FLAGS.training_steps):
        if (trainer.step -
                trainer.best_dev_step) > FLAGS.early_stopping_steps_to_wait:
            logger.Log('No improvement after ' +
                       str(FLAGS.early_stopping_steps_to_wait) +
                       ' steps. Stopping training.')
            break

        # set model in training mode
        model.train()

        log_entry.Clear()
        log_entry.step = trainer.step
        should_log = False

        start = time.time()
        doc_batch = next(training_data_iter)
        batch = get_batch(doc_batch,
                          FLAGS.local_context_window,
                          use_lr_context=FLAGS.use_lr_context,
                          split_by_sent=FLAGS.split_by_sent)
        base, context1, context2, m_strs, cids, cids_sense, num_candidates, num_mentions, y = batch
        # check training data
        # inspectBatch(batch, vocabulary, doc_batch)

        total_candidates = num_candidates.sum()

        # Reset cached gradients.
        trainer.optimizer_zero_grad()

        # Run model. output: batch_size * cand_num
        output = model(context1,
                       base,
                       cids,
                       m_strs,
                       contexts2=context2,
                       candidates_sense=cids_sense,
                       num_mentions=num_mentions,
                       length=num_candidates)

        target = torch.from_numpy(y).long()
        # Calculate accuracy.
        total_mentions, actual_mentions, actual_correct = \
            ComputeAccuracy(output.data, target, doc_batch)

        # Calculate loss.
        loss = nn.CrossEntropyLoss()(output,
                                     to_gpu(
                                         Variable(target,
                                                  requires_grad=False)))
        # loss = nn.MultiLabelMarginLoss()(output, to_gpu(Variable(target, volatile=False)))
        # Backward pass.
        loss.backward()
        # Hard Gradient Clipping
        nn.utils.clip_grad_norm([
            param for name, param in model.named_parameters() if name not in [
                "word_embed.embed.weight", "entity_embed.embed.weight",
                "sense_embed.embed.weight", "mu_embed.embed.weight"
            ]
        ], FLAGS.clipping_max_value)

        # Gradient descent step.
        trainer.optimizer_step()

        end = time.time()

        total_time = end - start

        doc_accs = [
            correct / float(actual_mentions[i])
            for i, correct in enumerate(actual_correct)
        ]

        A.add('mention_prec',
              sum(actual_correct) / float(sum(actual_mentions)))
        A.add('doc_prec', sum(doc_accs) / float(len(doc_accs)))
        A.add('total_candidates', total_candidates)
        A.add('total_time', total_time)

        if trainer.step % FLAGS.statistics_interval_steps == 0:
            A.add('total_cost', loss.data[0])
            stats(model, trainer, A, log_entry)
            should_log = True
            progress_bar.finish()

        if trainer.step > 0 and trainer.step % FLAGS.eval_interval_steps == 0:
            should_log = True
            # note: at most two eval sets, since training tracks only the best result
            eval_metrics = []
            for index, eval_set in enumerate(eval_iterators):
                eval_metrics.append(
                    evaluate(FLAGS,
                             model,
                             eval_set,
                             log_entry,
                             logger,
                             show_sample=FLAGS.show_sample,
                             vocabulary=vocabulary,
                             eval_index=index))
            trainer.new_accuracy(eval_metrics)
            progress_bar.reset()

        if trainer.step > FLAGS.ckpt_step and trainer.step % FLAGS.ckpt_interval_steps == 0:
            should_log = True
            trainer.checkpoint()

        if should_log:
            logger.LogEntry(log_entry)

        progress_bar.step(i=(trainer.step % FLAGS.statistics_interval_steps) +
                          1,
                          total=FLAGS.statistics_interval_steps)
    finalStats(trainer, logger)
    def forward(self,
                contexts1,
                base_feature,
                candidates,
                mention_tokens,
                contexts2=None,
                candidates_sense=None,
                num_mentions=None,
                length=None):
        batch_size, cand_num, _ = base_feature.shape
        features = []
        # to gpu
        base_feature = to_gpu(
            Variable(torch.from_numpy(base_feature),
                     requires_grad=False)).float()
        base_feature = base_feature.view(batch_size * cand_num, -1)
        features.append(base_feature)

        # candidate mask
        length_mask = None
        if length is not None:
            lengths_var = to_gpu(
                Variable(torch.from_numpy(length),
                         requires_grad=False)).long()
            # batch_size * cand_num
            length_mask = sequence_mask(lengths_var, cand_num).float()

        # get emb, (batch * cand) * dim
        candidate_embeddings = self.getCandidateEmbedding(
            candidates, candidates_sense)
        cand_emb1, cand_emb2, cand_emb3 = candidate_embeddings

        # get context emb
        context1_emb = self.getTokenEmbedding(
            contexts1,
            candidate_embeddings=candidate_embeddings
            if self._use_att else None)

        # get contextual similarity, (batch * cand) * contextual_sim
        con1_sims = self.getCandidateSimilarity(context1_emb,
                                                candidate_embeddings)
        features.extend(con1_sims)

        con2_sims = DEFAULT_SIM, DEFAULT_SIM, DEFAULT_SIM
        con2_emb_cand1 = None
        if self._use_contexts2 and contexts2 is not None:
            context2_emb = self.getTokenEmbedding(
                contexts2,
                candidate_embeddings=candidate_embeddings
                if self._use_att else None)
            con2_emb_cand1, con2_emb_cand2, con2_emb_cand3 = context2_emb
            # get contextual similarity, (batch * cand) * contextual_sim
            con2_sims = self.getCandidateSimilarity(context2_emb,
                                                    candidate_embeddings)
        features.extend(con2_sims)

        # get mention string similarity,
        # ment_embs = self.getTokenEmbedding(mention_tokens, candidate_embeddings=candidate_embeddings)
        ment_embs = self.getTokenEmbedding(mention_tokens)
        mention_sims = self.getCandidateSimilarity(ment_embs,
                                                   candidate_embeddings)
        features.extend(mention_sims)

        # neighbor mention string similarity
        neigh_ment_sims = DEFAULT_SIM, DEFAULT_SIM, DEFAULT_SIM
        if self._neighbor_ment_window > 0 and num_mentions is not None:
            ment_entity_embs, _, _ = ment_embs
            # (batch * cand_num) * dim
            neigh_ment_embs = self.getNeighborMentionEmbeddings(
                ment_entity_embs.data, self._neighbor_ment_window,
                num_mentions)
            neigh_ment_sims = self.getCandidateSimilarity(
                neigh_ment_embs, candidate_embeddings)
        features.extend(neigh_ment_sims)

        # neighbor candidates
        # (batch * cand) * (cand_num*window*2+1)
        self._adj = self.buildGraph(cand_emb1.data,
                                    self._neighbor_cand_window,
                                    num_mentions,
                                    thred=self._thred).unsqueeze(1)
        # feature vec : (batch * cand) * feature_dim
        f_vec = torch.cat(features, dim=1)
        if self._use_embedding_feature:
            con1_emb_cand1, con1_emb_cand2, con1_emb_cand3 = context1_emb
            f_vec = torch.cat((f_vec, cand_emb1, con1_emb_cand1), dim=1)
            if con2_emb_cand1 is not None:
                f_vec = torch.cat((f_vec, con2_emb_cand1), dim=1)

        # mlp classify
        gc_input = self.mlp_classifier(
            f_vec, length=length_mask.view(-1)) * self._temperature

        if self._res_num > 0:
            # (batch_size * cand_num) * dim
            for i in range(self._res_num):
                l = getattr(self, 'l{}'.format(i))
                h = l(gc_input, self._adj, num_mentions, mask=length_mask)
                # skip connection
                sk_layer = getattr(self, 'sk{}'.format(i))
                if sk_layer is not None:
                    h = h + sk_layer(gc_input)
                else:
                    h = h + gc_input
                gc_input = h
        if self.classifier is not None:
            gc_input = self.classifier(gc_input)
        # reshape, batch_size * cand_num
        h = gc_input.squeeze().view(batch_size, -1)
        output = masked_softmax(h, mask=length_mask)
        return output
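The residual loop above (the l{i} graph layers plus the optional sk{i} skip projections) follows the standard skip-connection pattern: the block output is added to the (possibly projected) block input. A minimal sketch of that pattern, with layer and skip as stand-ins for the real modules:

import torch.nn as nn

class ResidualBlock(nn.Module):
    def __init__(self, layer, skip=None):
        super(ResidualBlock, self).__init__()
        self.layer = layer
        self.skip = skip   # projection when input/output dims differ, else identity

    def forward(self, x, *args, **kwargs):
        h = self.layer(x, *args, **kwargs)
        return h + (self.skip(x) if self.skip is not None else x)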
Example #17
    def forward(self,
                contexts1,
                base_feature,
                candidates,
                contexts2=None,
                candidates_entity=None,
                length=None):
        batch_size, cand_num, _ = base_feature.shape
        # to gpu
        base_feature = to_gpu(Variable(torch.from_numpy(base_feature))).float()
        contexts1 = to_gpu(Variable(torch.from_numpy(contexts1))).long()
        candidates = to_gpu(Variable(torch.from_numpy(candidates))).long()

        has_context2 = False
        if contexts2 is not None and self._use_contexts2:
            contexts2 = to_gpu(Variable(torch.from_numpy(contexts2))).long()
            has_context2 = True

        has_entity = False
        if candidates_entity is not None and self._use_entity:
            candidates_entity = to_gpu(
                Variable(torch.from_numpy(candidates_entity))).long()
            has_entity = True

        # get emb, (batch * cand) * dim
        cand_emb = self.run_embed(candidates, 1)
        cand_mu_emb = self.run_embed(candidates, 2)

        f1_sense_emb = self.getEmbFeatures(contexts1, q_emb=cand_emb)
        f1_mu_emb = self.getEmbFeatures(contexts1, q_emb=cand_mu_emb)
        if has_entity:
            cand_entity_emb = self.run_embed(candidates_entity, 3)
            f1_entity_emb = self.getEmbFeatures(contexts1,
                                                q_emb=cand_entity_emb)

        if has_context2:
            f2_sense_emb = self.getEmbFeatures(contexts2, q_emb=cand_emb)
            f2_mu_emb = self.getEmbFeatures(contexts2, q_emb=cand_mu_emb)
            if has_entity:
                f2_entity_emb = self.getEmbFeatures(contexts2,
                                                    q_emb=cand_entity_emb)

        # get contextual similarity, (batch * cand) * contextual_sim
        cand_emb_expand = cand_emb.unsqueeze(1)
        cand_mu_emb_expand = cand_mu_emb.unsqueeze(1)
        # sense : context1
        sim1 = torch.bmm(cand_emb_expand, f1_sense_emb.unsqueeze(2)).squeeze(2)
        # mu : context1
        sim2 = torch.bmm(cand_mu_emb_expand, f1_mu_emb.unsqueeze(2)).squeeze(2)
        # entity: context1
        if has_entity:
            cand_entity_emb_expand = cand_entity_emb.unsqueeze(1)
            sim3 = torch.bmm(cand_entity_emb_expand,
                             f1_entity_emb.unsqueeze(2)).squeeze(2)

        # sense : context2
        if has_context2:
            sim4 = torch.bmm(cand_emb_expand,
                             f2_sense_emb.unsqueeze(2)).squeeze(2)
            # mu : context2
            sim5 = torch.bmm(cand_mu_emb_expand,
                             f2_mu_emb.unsqueeze(2)).squeeze(2)
            # entity: context2
            if has_entity:
                sim6 = torch.bmm(cand_entity_emb_expand,
                                 f2_entity_emb.unsqueeze(2)).squeeze(2)

        # feature vec : batch * cand * feature_dim
        # feature dim: base_dim + sense_dim + word_dim + 2 + 1(if has entity) +
        # (2+word_dim)(if has contexts) + 1(if has context2 and has entity)
        base_feature = base_feature.view(batch_size * cand_num, -1)
        h = torch.cat((base_feature, cand_emb, f1_sense_emb, sim1, sim2),
                      dim=1)
        if has_entity:
            h = torch.cat((h, sim3), dim=1)
        if has_context2:
            h = torch.cat((h, sim4, sim5, f2_sense_emb), dim=1)
            if has_entity:
                h = torch.cat((h, sim6), dim=1)

        h = self.mlp_classifier(h)
        # reshape, batch_size * cand_num
        h = h.view(batch_size, -1)

        if length is not None:
            lengths_var = to_gpu(
                Variable(torch.from_numpy(length),
                         requires_grad=False)).long()
            # batch_size * cand_num
            length_mask = sequence_mask(lengths_var, cand_num).float()

        output = masked_softmax(h, mask=length_mask)
        return output