Example #1
    def forward(self, _, heads, deps):
        '''heads.data: MyPackedSequence,       amr_l x rel_dim
           deps.data:  MyDoublePackedSequence, amr_l x amr_l x rel_dim
        '''
        heads_data = heads.data
        deps_data = deps.data

        head_bilinear_transformed = self.bilinear(
            heads_data)  # all_data x (n_rel x inputsize)

        head_bias_unpacked = myunpack(self.head_bias(heads_data),
                                      heads.lengths)  #[len x n_rel]

        size = deps_data.size()
        dep_bias = self.dep_bias(deps_data.view(-1, size[-1])).view(
            size[0], size[1], -1)

        dep_bias_unpacked, length_pairs = mydoubleunpack(
            MyDoublePackedSequence(MyPackedSequence(dep_bias, deps[0][1]),
                                   deps[1], dep_bias))  #[len x n_rel]

        bilinear_unpacked = myunpack(head_bilinear_transformed, heads.lengths)

        deps_unpacked, length_pairs = mydoubleunpack(deps)
        output, l = self.bilinearForParallel(
            zip(bilinear_unpacked, deps_unpacked, head_bias_unpacked,
                dep_bias_unpacked), length_pairs)
        myscore_packed = mypack(output, l)

        #  prob_packed = MyPackedSequence(myscore_packed.data,l)
        return myscore_packed
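These examples lean on custom packing helpers (MyPackedSequence, MyDoublePackedSequence, mypack, myunpack, mydoublepack, mydoubleunpack) whose definitions are not shown. Below is a minimal sketch of what they are assumed to look like, inferred only from how the fields are accessed above; the repo's actual definitions may differ.

from collections import namedtuple
import torch

# Assumed container shape (inferred, not the repo's actual definition).
# MyPackedSequence flattens a batch of variable-length node sequences into
# one (sum(lengths) x dim) tensor plus the per-graph lengths.
MyPackedSequence = namedtuple("MyPackedSequence", ["data", "lengths"])

# MyDoublePackedSequence would do the same for per-graph (len_i x len_i x dim)
# pairwise tensors: an inner MyPackedSequence, the per-graph length pairs,
# and the flat data tensor (matching the three constructor arguments above).

def myunpack_sketch(data, lengths):
    # Split the flat tensor back into one (len_i x dim) tensor per graph.
    return list(torch.split(data, lengths, dim=0))

def mypack_sketch(tensors, lengths):
    # Inverse of myunpack_sketch.
    return MyPackedSequence(torch.cat(tensors, dim=0), lengths)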
Example #2
 def root_score(self,mypackedhead):
     heads = myunpack(*mypackedhead)
     output = []
     for head in heads:
         score = self.root(head).squeeze(1)
         output.append(self.LogSoftmax(score))
     return output
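root_score returns one log-distribution over nodes per graph; later code (relProbAndConToGraph below) selects the root with roots_score.max(0). A hedged usage sketch, where model and packed_heads are placeholder names, not identifiers from the source:

# Hypothetical usage: one (amr_l,) log-score vector per graph.
root_log_probs = model.root_score(packed_heads)
root_ids = [scores.argmax().item() for scores in root_log_probs]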
Example #3
    def forward(self, input, index, src_enc):
        """
        # input: relBatch: packed_gold_amr_lengths x n_feature,  AMR_CAT, AMR_LE, AMR_NER, AMR_SENSE, index of nodes,
        #  mypacked_seq[packed_batch_gold_amr_len x tgt_feature_dim]
        # index:  rel_index_batch:  list(batch, real_gold_amr_len), but content is the index of recatogrized amr index, is a mapping
        # src_enc: DoublePackedSequence(packed: packed_g_amr_len x re_amr_len x dim), length= re_lens)
        """
        assert isinstance(input, MyPackedSequence), input
        # lengths: real_gold_amr_lens
        # after unpack, input is packed_gold_amr_lengths x n_features
        input, lengths = input
        if self.alpha_dropout and self.training:
            input = data_dropout(input, self.alpha_dropout)

        psd_target_pos_embed = self.psd_target_pos_lut(input[:, PSD_POS])
        #psd_sense_embed = self.psd_sense_lut(input[:,PSD_SENSE])
        #psd_lemma_embed = self.lemma_lut(input[:,PSD_LE])

        #psd_emb = torch.cat([psd_target_pos_embed, psd_sense_embed,psd_lemma_embed],1)
        #psd_emb = torch.cat([psd_target_pos_embed, psd_sense_embed],1)
        psd_emb = torch.cat([psd_target_pos_embed], 1)

        # head_emb_t :  MyPackedSequence(data: packed_real_g_amr_l x dim), g_amr_l
        # dep_emb_t :  MyDoublePackedSequence(PackedSequenceLength(packed_real_g_amr_l x real_g_amr_l x dim), length_pairs)
        # length_pairs :(g_amr_l, g_amr_l)
        head_emb_t, dep_emb_t, length_pairs = self.getEmb(
            index, src_enc)  #packed, mydoublepacked

        head_emb = torch.cat([psd_emb, head_emb_t.data], 1)

        dep_psd_emb_t = myunpack(*MyPackedSequence(psd_emb, lengths))
        dep_psd_emb = [
            emb.unsqueeze(0).expand(emb.size(0), emb.size(0), emb.size(-1))
            for emb in dep_psd_emb_t
        ]

        mydouble_psd_emb = mydoublepack(dep_psd_emb, length_pairs)

        dep_emb = torch.cat([mydouble_psd_emb.data, dep_emb_t.data], -1)

        # emb_unpacked = myunpack(emb,lengths)
        assert head_emb.size(
            -1) == self.inputSize, "wrong head  size {}".format(
                head_emb.size())
        # head_packed :  MyPackedSequence(data: packed_real_g_amr_l x rel_dim), g_amr_l
        head_packed = MyPackedSequence(self.head(head_emb),
                                       lengths)  #  total,rel_dim
        head_psd_packed = MyPackedSequence(psd_emb, lengths)  #  total,rel_dim

        size = dep_emb.size()
        assert dep_emb.size(-1) == self.inputSize, "wrong dep size {}".format(
            dep_emb.size())
        dep = self.dep(dep_emb.view(-1, size[-1])).view(size[0], size[1], -1)

        # dep_emb_t :  MyDoublePackedSequence(PackedSequenceLength(packed_real_g_amr_l x real_g_amr_l x rel_dim), length_pairs)
        dep_packed = MyDoublePackedSequence(
            MyPackedSequence(dep, mydouble_psd_emb[0][1]), mydouble_psd_emb[1],
            dep)

        return head_psd_packed, head_packed, dep_packed  #,MyPackedSequence(emb,lengths)
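The dependent-side embeddings above are built by broadcasting each graph's (n x d) node matrix to an (n x n x d) grid, so that position (h, d) carries the dependent embedding for the head/dependent pair. A toy illustration of that unsqueeze(0).expand(...) step:

import torch

# Each row of the expanded tensor is a view of the same (n x d) matrix,
# so no memory is copied until a later op materializes it.
n, d = 3, 4
emb = torch.randn(n, d)                      # per-node embeddings
pairwise = emb.unsqueeze(0).expand(n, n, d)  # (n, n, d)
assert torch.equal(pairwise[0], emb) and torch.equal(pairwise[2], emb)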
Example #4
    def posteriorIndictedEmb(embs, posterior):
        """
        embs: actually are bert encodings, padded_batch_src_len x bert_output_dim
        posterior : mypack(packed_gold_amr_len  x src_len), gold_amr_l
        """
        # after unpack, embs :[max_src_len x batch x bert_output_dim]
        embs, src_len = unpack(embs)

        if isinstance(posterior, MyPackedSequence):
            # after unpack, posterior: list(batch_size, real_gold_amr_len x src_len)
            posterior = myunpack(*posterior)
            # after transpose: batch_size x padded_src_len x dim
            embs = embs.transpose(0, 1)
            out = []
            lengths = []
            # real gold amr length for each example in the batch
            amr_len = [len(p) for p in posterior]
            for i, emb in enumerate(embs):
                # expanded_emb: real_gold_amr_len x src_len x dim
                expanded_emb = emb.unsqueeze(0).expand(
                    [amr_len[i]] + list(emb.size()))
                indicator = posterior[i].unsqueeze(
                    2)  # real_gold_amr_len x src_len x 1
                out.append(
                    torch.cat([expanded_emb, indicator],
                              2))  # real_gold_amr_len x src_len x (dim+1)
                # out.append(expanded_emb)  # real_gold_amr_len x src_len x (dim+1)
                lengths = lengths + [src_len[i]] * amr_len[
                    i]  # lengths: list(batch * real_gold_amr_len), each entry is that example's src_len
            data = torch.cat(
                out, dim=0)  # packed_real_gold_amr_len x src_len x (dim+1)

            return pack(data, lengths, batch_first=True), amr_len
        elif isinstance(posterior, list):
            # real alignments
            embs = embs.transpose(0, 1)
            src_l = embs.size(1)
            amr_len = [len(i) for i in posterior]
            out = []
            lengths = []
            for i, emb in enumerate(embs):
                amr_l = len(posterior[i])
                expanded_emb = emb.unsqueeze(0).expand(
                    [amr_l] + list(emb.size()))  # amr_len x src_len x dim
                indicator = emb.new_zeros((amr_l, src_l))
                scattered_indicator = indicator.scatter(
                    1, posterior[i].unsqueeze(1), 1.0)  # amr_len x src_len
                out.append(
                    torch.cat([expanded_emb,
                               scattered_indicator.unsqueeze(2)],
                              2))  # amr_len x src_len x (dim+1)
                #out.append(expanded_emb)  # amr_len x src_len x (dim+1)
                lengths = lengths + [src_len[i]
                                     ] * amr_l  # one src_len entry per AMR node
            data = torch.cat(out,
                             dim=0)  # packed_amr_len x src_len x (dim+1)

            return pack(data, lengths, batch_first=True), amr_len
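In the list branch, hard alignments become a one-hot indicator over source tokens via scatter. A toy version of that construction:

import torch

# One AMR node per row; a 1.0 marks the source token it aligns to.
amr_l, src_l = 3, 5
alignment = torch.tensor([2, 0, 4])       # aligned token index per AMR node
indicator = torch.zeros(amr_l, src_l)
indicator.scatter_(1, alignment.unsqueeze(1), 1.0)
# tensor([[0., 0., 1., 0., 0.],
#         [1., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 1.]])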
Example #5
 def getEmb(self,indexes,src_enc):
     head_emb,lengths = [],[]
     src_enc = myunpack(*src_enc)  # list(batch, src_l x dim)
     for i, index in enumerate(indexes):
         enc = src_enc[i]  #src_l x  dim
         head_emb.append(enc[index])  #var(amr_l  x dim)
         lengths.append(len(index))
     return mypack(head_emb,lengths)
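enc[index] above is ordinary advanced indexing: it gathers one encoder row per AMR node, and the same source position may be picked more than once. For example:

import torch

enc = torch.randn(6, 8)             # src_l x dim
index = torch.tensor([0, 3, 3, 5])  # one source position per AMR node
head = enc[index]                   # 4 x 8; rows 1 and 2 are identical
assert head.shape == (4, 8) and torch.equal(head[1], head[2])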
Example #6
 def getEmb(self, indexes, src_enc):
     """
     # index:  rel_index_batch:  list(batch, real_gold_amr_len), but content is the index of recatogrized amr index, is a mapping
     # src_enc: is weighted root_src_enc, MyPackedSequence(data: packed_re_amr_len x txt_enc_size, lengtgs: re_amr_lens)
     """
     head_emb, lengths = [], []
     src_enc = myunpack(
         *src_enc)  # list(batch, real_re_amr_len x src_enc_size)
     for i, index in enumerate(
             indexes):  # indexes: batch_size lists of length real_gold_amr_len
         enc = src_enc[i]  # real_re_amr_len x src_enc_size
         head_emb.append(
             enc[index]
         )  # index holds real_re_amr indices; gathers list(batch, real_gold_amr_len x dim)
         lengths.append(len(index))  # list(batch, real_gold_amr_len)
     return mypack(
         head_emb, lengths
     )  # MyPackedSequence(data: packed_gold_amr_len x dim, lengths: real_gold_amr_lens)
Example #7
    def posteriorIndictedEmb(self,embs,posterior):
        #real alignment is sent in as list of index
        #variational relaxed posterior is sent in as MyPackedSequence

        #out   (batch x amr_len) x src_len x (dim+1)
        embs,src_len = unpack(embs)

        if isinstance(posterior,MyPackedSequence):
            # print ("posterior is packed")
            posterior = myunpack(*posterior)
            embs = embs.transpose(0,1)
            out = []
            lengths = []
            amr_len = [len(p) for p in posterior]
            for i,emb in enumerate(embs):
                expanded_emb = emb.unsqueeze(0).expand([amr_len[i]]+list(emb.size())) # amr_len x src_len x dim
                indicator = posterior[i].unsqueeze(2)  # amr_len x src_len x 1
                out.append(torch.cat([expanded_emb,indicator],2))  # amr_len x src_len x (dim+1)
                lengths = lengths + [src_len[i]]*amr_len[i]
            data = torch.cat(out,dim=0)

            return pack(data,lengths,batch_first=True),amr_len
        elif isinstance(posterior,list):
            embs = embs.transpose(0,1)
            src_l = embs.size(1)
            amr_len = [len(i) for i in posterior]
            out = []
            lengths = []
            for i,emb in enumerate(embs):
                amr_l = len(posterior[i])
                expanded_emb = emb.unsqueeze(0).expand([amr_l]+list(emb.size())) # amr_len x src_len x dim
                indicator = emb.data.new(amr_l,src_l).zero_()
                indicator.scatter_(1, posterior[i].data.unsqueeze(1), 1.0) # amr_len x src_len x 1
                indicator = Variable(indicator.unsqueeze(2))
                out.append(torch.cat([expanded_emb,indicator],2))  # amr_len x src_len x (dim+1)
                lengths = lengths + [src_len[i]]*amr_l
            data = torch.cat(out,dim=0)

            return pack(data,lengths,batch_first=True),amr_len
Example #8
    def forward(self, input, index,src_enc):
        assert isinstance(input, MyPackedSequence),input
        input,lengths = input
        if self.alpha and self.training:
            input = data_dropout(input,self.alpha)
        cat_embed = self.cat_lut(input[:,AMR_CAT])
        lemma_embed = self.lemma_lut(input[:,AMR_LE])

        amr_emb = torch.cat([cat_embed,lemma_embed],1)
    #    print (input,lengths)

        head_emb_t,dep_emb_t,length_pairs = self.getEmb(index,src_enc)  #packed, mydoublepacked


        head_emb = torch.cat([amr_emb,head_emb_t.data],1)

        dep_amr_emb_t = myunpack(*MyPackedSequence(amr_emb,lengths))
        dep_amr_emb = [ emb.unsqueeze(0).expand(emb.size(0),emb.size(0),emb.size(-1))      for emb in dep_amr_emb_t]

        mydouble_amr_emb = mydoublepack(dep_amr_emb,length_pairs)

    #    print ("rel_encoder",mydouble_amr_emb.data.size(),dep_emb_t.data.size())
        dep_emb = torch.cat([mydouble_amr_emb.data,dep_emb_t.data],-1)

       # emb_unpacked = myunpack(emb,lengths)

        head_packed = MyPackedSequence(self.head(head_emb),lengths) #  total,rel_dim
        head_amr_packed = MyPackedSequence(amr_emb,lengths) #  total,rel_dim

   #     print ("dep_emb",dep_emb.size())
        size = dep_emb.size()
        dep = self.dep(dep_emb.view(-1,size[-1])).view(size[0],size[1],-1)

        dep_packed  = MyDoublePackedSequence(MyPackedSequence(dep,mydouble_amr_emb[0][1]),mydouble_amr_emb[1],dep)

        return  head_amr_packed,head_packed,dep_packed  #,MyPackedSequence(emb,lengths)
    def relProbAndConToGraph(self,
                             srl_batch,
                             srl_prob,
                             roots,
                             appended,
                             get_sense=False,
                             set_wiki=False):
        # batch of ids
        # max_indices: len, batch, n_feature
        # srcBatch: batch of sources
        # out: batch of AMRUniversal
        rel_dict = self.dicts["rel_dict"]

        def get_uni_var(concepts, id):
            assert id < len(concepts), (id, concepts)
            uni = concepts[id]
            if uni.le in ["i", "it", "you", "they", "he", "she"
                          ] and uni.cat == Rule_Concept:
                return uni, Var(uni.le)
            le = uni.le
            if uni.cat != Rule_String:
                uni.le = uni.le.strip("/").strip(":")
                if ":" in uni.le or "/" in uni.le:
                    uni.cat = Rule_String
            if uni.le == "":
                return uni, Var(le + str(id))
            return uni, Var(uni.le[0] + str(id))

        def create_connected_graph(role_scores, concepts, root_id, dependent,
                                   aligns):
            #role_scores: amr x amr x rel
            graph = nx.DiGraph()
            n = len(concepts)
            role_scores = role_scores.view(n, n, -1).data
            max_non_score, max_non_score_id = role_scores[:, :, 1:].max(-1)
            max_non_score_id = max_non_score_id + 1
            non_score = role_scores[:, :, 0]
            active_cost = non_score - max_non_score  # so the lowest-cost edge becomes active first
            candidates = []
            h_vs = []
            for h_id in range(n):
                h, h_v = get_uni_var(concepts, h_id)
                h_vs.append(h_v)
                graph.add_node(h_v,
                               value=h,
                               align=aligns[h_id],
                               gold=True,
                               dep=dependent[h_id])

            constant_links = {}
            normal_edged_links = {}
            for h_id in range(n):
                for d_id in range(n):
                    if h_id != d_id:
                        r = rel_dict.getLabel(max_non_score_id[h_id, d_id])
                        r_inver = r + "-of" if not r.endswith(
                            "-of") else r[:-3]
                        h, h_v = get_uni_var(concepts, h_id)
                        d, d_v = get_uni_var(concepts, d_id)
                        if (concepts[h_id].is_constant()
                                or concepts[d_id].is_constant()):
                            if concepts[h_id].is_constant(
                            ) and concepts[d_id].is_constant():
                                continue
                            elif concepts[h_id].is_constant():
                                constant_links.setdefault(h_v, []).append(
                                    (active_cost[h_id, d_id], d_v, r, r_inver))
                            else:
                                constant_links.setdefault(d_v, []).append(
                                    (active_cost[h_id, d_id], h_v, r_inver, r))
                        elif active_cost[h_id, d_id] < 0:
                            if r in [":name-of", ":ARG0"
                                     ] and concepts[d_id].le in ["person"]:
                                graph.add_edge(h_v, d_v, role=r)
                                graph.add_edge(d_v, h_v, role=r_inver)
                            else:
                                #       if concepts[h_id].le == "name" and r != ":name-of":
                                #          r = ":name-of"
                                normal_edged_links.setdefault(
                                    (h_v, r), []).append(
                                        (active_cost[h_id,
                                                     d_id], d_v, r_inver))
                        else:
                            candidates.append(
                                (active_cost[h_id,
                                             d_id], (h_v, d_v, r, r_inver)))

            max_edge_per_node = 1 if not self.training else 100
            for h_v, r in normal_edged_links:
                sorted_list = sorted(normal_edged_links[h_v, r],
                                     key=lambda j: j[0])
                for _, d_v, r_inver in sorted_list[:max_edge_per_node]:
                    #     if graph.has_edge(h_v, d_v):
                    #         continue
                    graph.add_edge(h_v, d_v, role=r)
                    graph.add_edge(d_v, h_v, role=r_inver)
                for cost, d_v, r_inver in sorted_list[
                        max_edge_per_node:]:  #remaining
                    candidates.append((cost, (h_v, d_v, r, r_inver)))

            for h_v in constant_links:
                _, d_v, r, r_inver = sorted(constant_links[h_v],
                                            key=lambda j: j[0])[0]
                graph.add_edge(h_v, d_v, role=r)
                graph.add_edge(d_v, h_v, role=r_inver)

            candidates = sorted(candidates, key=lambda j: j[0])

            for _, (h_v, d_v, r, r_inver) in candidates:
                if nx.is_strongly_connected(graph):
                    break
                if not nx.has_path(graph, h_v, d_v):
                    graph.add_edge(h_v, d_v, role=r, force_connect=True)
                    graph.add_edge(d_v, h_v, role=r_inver, force_connect=True)

            _, root_v = get_uni_var(concepts, root_id)
            h_v = BOS_WORD
            root_symbol = AMRUniversal(BOS_WORD, BOS_WORD, NULL_WORD)
            graph.add_node(h_v, value=root_symbol, align=-1, gold=True, dep=1)
            graph.add_edge(h_v, root_v, role=":top")
            graph.add_edge(root_v, h_v, role=":top-of")

            if get_sense:
                for n, d in graph.nodes(True):
                    if "value" not in d:
                        print(n, d, graph[n], constant_links, graph.nodes,
                              graph.edges)
                    le, cat, sense = d["value"].le, d["value"].cat, d[
                        "value"].sense
                    if cat == Rule_Frame and sense == "":
                        sense = self.fragment_to_node_converter.get_senses(le)
                        d["value"] = AMRUniversal(le, cat, sense)

            #  if not self.training:
            #      assert nx.is_strongly_connected(graph), ("before contraction", self.graph_to_quadruples(graph), graph_to_amr(graph))
            #      graph = contract_graph(graph)
            #      assert nx.is_strongly_connected(graph), ("after contraction", self.graph_to_quadruples(graph), graph_to_amr(graph))

            if set_wiki:
                node_list = [[n, d] for n, d in graph.nodes(True)]
                for n, d in node_list:
                    if d["value"].le == "name":
                        names = []
                        head = None
                        for nearb in graph[n]:
                            r = graph[n][nearb]["role"]
                            if ":op" in r and "-of" not in r and int(
                                    graph[n][nearb]["role"][3:]) not in names:
                                names.append([
                                    graph.node[nearb]["value"],
                                    int(graph[n][nearb]["role"][3:])
                                ])
                            if r == ":name-of":
                                wikied = False
                                for nearbb in graph[nearb]:
                                    r = graph[nearb][nearbb]["role"]
                                    if r == ":wiki":
                                        wikied = True
                                        break
                                if not wikied:
                                    head = nearb
                        if head:
                            names = tuple([
                                t[0] for t in sorted(names, key=lambda t: t[1])
                            ])
                            wiki = self.fragment_to_node_converter.get_wiki(
                                names)
                            #    print (wiki)
                            wiki_v = Var(wiki.le + n._name)
                            graph.add_node(wiki_v,
                                           value=wiki,
                                           align=d["align"],
                                           gold=True,
                                           dep=2)  #second order  dependency
                            graph.add_edge(head, wiki_v, role=":wiki")
                            graph.add_edge(wiki_v, head, role=":wiki-of")

            assert nx.is_strongly_connected(graph), (
                "after contraction", self.graph_to_quadruples(graph),
                graph_to_amr(graph))
            return graph, self.graph_to_quadruples(graph)

        graphs = []
        quadruple_batch = []
        score_batch = myunpack(*srl_prob)  #list of (h x d)

        dependent_mark_batch = appended[0]
        aligns_batch = appended[1]
        for i, (role_scores, concepts, roots_score, dependent_mark,
                aligns) in enumerate(
                    zip(score_batch, srl_batch, roots, dependent_mark_batch,
                        aligns_batch)):
            root_s, root_id = roots_score.max(0)
            assert roots_score.size(0) == len(concepts), (concepts,
                                                          roots_score)
            root_id = root_id.data.tolist()[0]
            assert root_id < len(concepts), (concepts, roots_score)

            g, quadruples = create_connected_graph(role_scores, concepts,
                                                   root_id, dependent_mark,
                                                   aligns)
            graphs.append(g)
            quadruple_batch.append(quadruples)

        return graphs, quadruple_batch
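The connectivity repair at the end of create_connected_graph is greedy: edges whose NONE score loses to the best relation (active_cost < 0) are added outright (capped per (head, role) outside training), and the leftover candidates are spent cheapest-first only while the digraph is not yet strongly connected. Because every edge is added together with its -of inverse, strong connectivity coincides with plain reachability. A condensed sketch of that final loop, assuming candidates holds (cost, (head, dep, role, inverse_role)) tuples:

import networkx as nx

def connect_sketch(graph, candidates):
    # Force in the cheapest remaining edges, both directions at once,
    # until every node can reach every other node.
    for _, (h_v, d_v, r, r_inver) in sorted(candidates, key=lambda c: c[0]):
        if nx.is_strongly_connected(graph):
            break
        if not nx.has_path(graph, h_v, d_v):
            graph.add_edge(h_v, d_v, role=r, force_connect=True)
            graph.add_edge(d_v, h_v, role=r_inver, force_connect=True)
    return graph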
Example #10
    def relProbAndConToGraph(self,srl_batch, sourceBatch, srl_prob,roots,appended,get_sense=False,set_wiki=False,normalizeMod=False):
        """
        For EDS, there is normalizeMod
        """
        # TODO: build the connected graph from the relation probabilities
        # batch of ids
        # max_indices: len, batch, n_feature
        # srcBatch: batch of sources
        # out: batch of AMRUniversal
        rel_dict = self.dicts["eds_rel_dict"]
        def get_uni_var(concepts,id):
            """
            for EDS the exact id is likewise unimportant; just use it as the variable value
            """
            return concepts[id],EDSVar(str(id))

        def create_connected_graph(role_scores,concepts,root_id,dependent,aligns,source):
            #role_scores: amr x amr x rel
            graph = nx.MultiDiGraph()
            n = len(concepts)
            role_scores = role_scores.view(n,n,-1)
            max_non_score, max_non_score_id= role_scores[:,:,1:].max(-1)
            max_non_score_id = max_non_score_id +1
            non_score = role_scores[:,:,0]
            active_cost = non_score - max_non_score  # so the lowest-cost edge becomes active first
            candidates = []
            # add all nodes
            for h_id in range(n):
                h,h_v = get_uni_var(concepts,h_id)
                graph.add_node(h_v, value=h, align=aligns[h_id],gold=True,dep = dependent[h_id])

            constant_links = {}
            normal_edged_links = {}
            # add all pairs of edges
            for h_id in range(n):
                for d_id in range(n):
                    if h_id != d_id:
                        r = rel_dict.getLabel(max_non_score_id[h_id,d_id].item())
                        # relations in rel_dict are forward argx, opx, sntx, top,
                        # plus all backward relations.
                        # mod normalization should already be done in rel_dict;
                        # we normalize here when connecting, so the same relation is not considered twice
                        r_inver = EDSGraph.get_inversed_edge(r)
                        h,h_v = get_uni_var(concepts,h_id)
                        d,d_v = get_uni_var(concepts,d_id)
                        # TODO: for mwe, compound ner
                        if active_cost[h_id,d_id] < 0:
                            normal_edged_links.setdefault((h_v,r),[]).append((active_cost[h_id,d_id],d_v,r_inver))
                        else:
                            candidates.append((active_cost[h_id,d_id],(h_v,d_v,r,r_inver)))

            max_edge_per_node = 1 if not self.training else 100
            for h_v,r in normal_edged_links:
                sorted_list = sorted(normal_edged_links[h_v,r],key = lambda j:j[0])
                for _,d_v,r_inver in sorted_list[:max_edge_per_node]:
               #     if graph.has_edge(h_v, d_v):
               #         continue
                    graph.add_edge(h_v, d_v, key=r, role=r)
                    graph.add_edge(d_v, h_v, key=r_inver, role=r_inver)
                for cost,d_v,r_inver in sorted_list[max_edge_per_node:]:  #remaining
                    candidates.append((cost,(h_v,d_v,r,r_inver)))

            candidates = sorted(candidates,key = lambda j:j[0])

            for _,(h_v,d_v,r,r_inver ) in candidates:
                if  nx.is_strongly_connected(graph):
                    break
                if not nx.has_path(graph,h_v,d_v):
                    graph.add_edge(h_v, d_v, key=r, role=r,force_connect=True)
                    graph.add_edge(d_v, h_v, key=r_inver, role=r_inver,force_connect=True)

            _,root_v  = get_uni_var(concepts,root_id)
            h_v = BOS_WORD
            root_symbol = EDSUniversal.TOP_EDSUniversal()
            graph.add_node(h_v, value=root_symbol, align=-1,gold=True,dep=1)
            graph.add_edge(h_v, root_v, key=":top", role=":top")
            graph.add_edge(root_v, h_v, key=":top-of", role=":top-of")

            for n,d in graph.nodes(True):
                le,pos,cat,sense,anchors = d["value"].le,d["value"].pos,d["value"].cat,d["value"].sense,d["value"].anchors
                # here align is the token index list
                if get_sense:
                    if sense == "" or sense == None:
                        sense = self.fragment_to_node_converter.get_senses(le, pos)

                anchors = []
                # convert token ids into character offset, with input source data
                tok_index = d["align"]
                if tok_index >= 0 and tok_index < len(source[ANCHOR_IND_SOURCE_BATCH]):
                    anchors.extend(source[ANCHOR_IND_SOURCE_BATCH][tok_index])

                d["value"] = EDSUniversal(pos,cat,sense,le,anchors)
                d["anchors"] = anchors

            if not nx.is_strongly_connected(graph):
                logger.warning("not connected after contraction: %s, %s, %s, %s, %s, %s, %s, %s",self.graph_to_quadruples(graph),graph_to_amr(graph), candidates, constant_links, graph.nodes(), graph.edges(), normal_edged_links, concepts)
            return graph,self.graph_to_quadruples(graph)

        graphs = []
        quadruple_batch = []
        score_batch = myunpack(*srl_prob) #list of (h x d)


        dependent_mark_batch = appended[0]
        aligns_batch = appended[1]
        for i,(role_scores,concepts,roots_score,dependent_mark,aligns,source) in enumerate(zip(score_batch,srl_batch,roots,dependent_mark_batch,aligns_batch,sourceBatch)):
            # here we assume there is a single root
            root_s,root_id = roots_score.max(0)
            assert roots_score.size(0) == len(concepts),(concepts,roots_score)
            # since PyTorch 0.4.0 (https://pytorch.org/docs/stable/tensors.html#torch.Tensor.tolist),
            # a scalar tensor converts to a plain Python number, hence .item()
            root_id = root_id.item()
            assert root_id < len(concepts),(concepts,roots_score)

            g,quadruples = create_connected_graph(role_scores,concepts,root_id,dependent_mark,aligns,source)
            graphs.append(g)
            quadruple_batch.append(quadruples)

        return graphs,quadruple_batch
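Unlike the nx.DiGraph of Example #8, the EDS variant uses nx.MultiDiGraph and passes the role as the edge key: predicting the same role twice for a pair updates a single edge, while different roles coexist as parallel edges. A small demonstration:

import networkx as nx

g = nx.MultiDiGraph()
g.add_edge("a", "b", key=":ARG0", role=":ARG0")
g.add_edge("a", "b", key=":ARG0", role=":ARG0")  # same key: updates in place
g.add_edge("a", "b", key=":ARG1", role=":ARG1")  # new key: parallel edge
assert g.number_of_edges("a", "b") == 2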
Example #11
    def relProbAndConToGraph(self,srl_batch, sourceBatch, srl_prob,roots,appended,get_sense=False,set_wiki=False,normalizeMod=False):
        # batch of ids
        # max_indices: len, batch, n_feature
        # srcBatch: batch of sources
        # out: batch of AMRUniversal
        amr_rel_dict = self.dicts["amr_rel_dict"]
        def get_uni_var(concepts,id):
            assert id < len(concepts),(id,concepts)
            uni = concepts[id]
            if  uni.le in [ "i" ,"it","you","they","he","she"] and uni.cat == Rule_Concept:
                return uni,Var(uni.le )
            le = uni.le
            if uni.cat != Rule_String:
                uni.le = uni.le.strip("/").strip(":")
                if ":" in uni.le or "/" in uni.le:
                    uni.cat = Rule_String
            if uni.le == "":
                return uni,Var(le+ str(id))
            return uni,Var(uni.le[0]+ str(id))

        def create_connected_graph(role_scores,concepts,root_id,dependent,aligns, normalizeMod=True):
            #role_scores: amr x amr x rel
            graph = nx.MultiDiGraph()
            n = len(concepts)
            role_scores = role_scores.view(n,n,-1)
            max_non_score, max_non_score_id= role_scores[:,:,1:].max(-1)
            max_non_score_id = max_non_score_id +1
            non_score = role_scores[:,:,0]
            active_cost = non_score - max_non_score  # so the lowest-cost edge becomes active first
            candidates = []
            # add all nodes
            for h_id in range(n):
                h,h_v = get_uni_var(concepts,h_id)
                # here align is a list to the token index in our tokenization
                graph.add_node(h_v, value=h, align=aligns[h_id],gold=True,dep = dependent[h_id])

            constant_links = {}
            normal_edged_links = {}
            # add all pairs of edges
            for h_id in range(n):
                for d_id in range(n):
                    if h_id != d_id:
                        r = amr_rel_dict.getLabel(max_non_score_id[h_id,d_id].item())
                        # relations in amr_rel_dict are forward argx, opx, sntx, top,
                        # plus all backward relations.
                        # mod normalization should already be done in amr_rel_dict
                        if normalizeMod and to_be_normalize_mod(r):
                            r = get_normalize_mod(r)
                        # we normalize here when connecting, so the same relation is not considered twice
                        r_inver = get_inversed_edge(r)
                        h,h_v = get_uni_var(concepts,h_id)
                        d,d_v = get_uni_var(concepts,d_id)
                        if  (concepts[h_id].is_constant() or concepts[d_id].is_constant() ):
                            if concepts[h_id].is_constant() and concepts[d_id].is_constant() :
                                continue
                            elif concepts[h_id].is_constant():
                                constant_links.setdefault(h_v,[]).append((active_cost[h_id,d_id],d_v,r,r_inver))
                            else:
                                constant_links.setdefault(d_v,[]).append((active_cost[h_id,d_id],h_v,r_inver,r))
                        elif active_cost[h_id,d_id] < 0:
                            if r in [":name-of" ,":ARG0"] and concepts[d_id].le in ["person"]:
                                # always adding two direction to support directed path for connectivity
                                graph.add_edge(h_v, d_v, key=r, role=r)
                                graph.add_edge(d_v, h_v, key=r_inver, role=r_inver)
                            else:
                         #       if concepts[h_id].le == "name" and r != ":name-of":
                          #          r = ":name-of"
                                normal_edged_links.setdefault((h_v,r),[]).append((active_cost[h_id,d_id],d_v,r_inver))
                        else:
                            candidates.append((active_cost[h_id,d_id],(h_v,d_v,r,r_inver)))

            max_edge_per_node = 1 if not self.training else 100
            for h_v,r in normal_edged_links:
                sorted_list = sorted(normal_edged_links[h_v,r],key = lambda j:j[0])
                for _,d_v,r_inver in sorted_list[:max_edge_per_node]:
               #     if graph.has_edge(h_v, d_v):
               #         continue
                    graph.add_edge(h_v, d_v, key=r, role=r)
                    graph.add_edge(d_v, h_v, key=r_inver, role=r_inver)
                for cost,d_v,r_inver in sorted_list[max_edge_per_node:]:  #remaining
                    candidates.append((cost,(h_v,d_v,r,r_inver)))


            for h_v in constant_links:
                _,d_v,r,r_inver = sorted(constant_links[h_v],key = lambda j:j[0])[0]
                graph.add_edge(h_v, d_v, key=r, role=r)
                graph.add_edge(d_v, h_v, key=r_inver, role=r_inver)

            candidates = sorted(candidates,key = lambda j:j[0])

            for _,(h_v,d_v,r,r_inver ) in candidates:
                if  nx.is_strongly_connected(graph):
                    break
                if not nx.has_path(graph,h_v,d_v):
                    graph.add_edge(h_v, d_v, key=r, role=r,force_connect=True)
                    graph.add_edge(d_v, h_v, key=r_inver, role=r_inver,force_connect=True)

            _,root_v  = get_uni_var(concepts,root_id)
            h_v = BOS_WORD
            root_symbol = AMRUniversal(BOS_WORD,BOS_WORD,NULL_WORD)
            graph.add_node(h_v, value=root_symbol, align=-1,gold=True,dep=1)
            graph.add_edge(h_v, root_v, key=":top", role=":top")
            graph.add_edge(root_v, h_v, key=":top-of", role=":top-of")

            if get_sense:
                for n,d in graph.nodes(True):
                    if "value" not in d:
                        print (n,d, graph[n],constant_links,graph.nodes,graph.edges)
                    le,cat,sense = d["value"].le,d["value"].cat,d["value"].sense
                    if cat == Rule_Frame and sense == "":
                        sense = self.fragment_to_node_converter.get_senses(le)
                        d["value"] = AMRUniversal(le,cat,sense)

            if not self.training:
                # contraction is skipped during training and only applied at inference
                if not nx.is_strongly_connected(graph):
                    logger.warning("not connected before contraction: %s, %s, %s, %s, %s, %s, %s, %s",self.graph_to_quadruples(graph),graph_to_amr(graph), candidates, constant_links, graph.nodes(), graph.edges(), normal_edged_links, concepts)
                graph = contract_graph(graph)

            if set_wiki:
                node_list = [[n,d] for n,d in graph.nodes(True)]
                for n,d in node_list:
                    if d["value"].le == "name":
                        names = []
                        head = None
                        for nearb in graph[n]:
                            for key, edge_data in graph[n][nearb].items():
                                r = edge_data["role"]
                                if ":op" in r and "-of" not in r and  int(edge_data["role"][3:]) not in names:
                                    names.append([graph.node[nearb]["value"], int(edge_data["role"][3:])])
                                if r == ":name-of":
                                    wikied = False
                                    for nearbb in graph[nearb]:
                                        for _, edge_data2 in graph[nearb][nearbb].items():
                                            r2 = edge_data2["role"]
                                            if r2 == ":wiki":
                                                wikied = True
                                                break
                                    if not wikied:
                                        head = nearb
                        if head:
                            names = tuple([t[0] for t in sorted(names,key = lambda t: t[1])])
                            wiki = self.fragment_to_node_converter.get_wiki(names)
                        #    print (wiki)
                            wiki_v = Var(wiki.le+n._name )
                            graph.add_node(wiki_v, value=wiki, align=d["align"],gold=True,dep=2) #second order  dependency
                            graph.add_edge(head, wiki_v, key=":wiki", role=":wiki")
                            graph.add_edge(wiki_v, head, key=":wiki-of",role=":wiki-of")

            if not nx.is_strongly_connected(graph):
                logger.warning("not connected after contraction: %s, %s, %s, %s, %s, %s, %s, %s",self.graph_to_quadruples(graph),graph_to_amr(graph), candidates, constant_links, graph.nodes(), graph.edges(), normal_edged_links, concepts)
            return graph,self.graph_to_quadruples(graph)

        graphs = []
        quadruple_batch = []
        score_batch = myunpack(*srl_prob) #list of (h x d)


        dependent_mark_batch = appended[0]
        # here aligns_batch records how each final expanded node aligns to the categorized node
        aligns_batch = appended[1]
        for i,(role_scores,concepts,roots_score,dependent_mark,aligns) in enumerate(zip(score_batch,srl_batch,roots,dependent_mark_batch,aligns_batch)):
            root_s,root_id = roots_score.max(0)
            assert roots_score.size(0) == len(concepts),(concepts,roots_score)
            # since PyTorch 0.4.0 (https://pytorch.org/docs/stable/tensors.html#torch.Tensor.tolist),
            # a scalar tensor converts to a plain Python number, hence .item()
            root_id = root_id.item()
            assert root_id < len(concepts),(concepts,roots_score)

            g,quadruples = create_connected_graph(role_scores,concepts,root_id,dependent_mark,aligns, normalizeMod=normalizeMod)
            graphs.append(g)
            quadruple_batch.append(quadruples)

        return graphs,quadruple_batch
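get_inversed_edge is used here without its definition. Example #8 inlines the same idea as r + "-of" if not r.endswith("-of") else r[:-3], so it presumably toggles the -of suffix that marks inverse roles; a sketch under that assumption:

# Assumed behavior, mirroring the inline expression in Example #8.
def get_inversed_edge_sketch(r):
    return r[:-3] if r.endswith("-of") else r + "-of"

assert get_inversed_edge_sketch(":ARG0") == ":ARG0-of"
assert get_inversed_edge_sketch(":ARG0-of") == ":ARG0"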