Code example #1
 def root_score(self, mypackedhead):
     # Unpack the packed head representations into per-graph tensors.
     heads = myunpack(*mypackedhead)
     output = []
     for head in heads:
         # Score each node as a root candidate, normalized over that graph's nodes.
         score = self.root(head).squeeze(1)
         output.append(self.LogSoftmax(score))
     return output
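These examples lean on project-specific packing helpers (MyPackedSequence, mypack, myunpack) that this page does not show. As orientation, here is a minimal sketch of what such helpers could look like, assuming a packed sequence is just a flat tensor plus per-example lengths (a hypothetical reconstruction, not the project's actual code):

import torch
from collections import namedtuple

# Hypothetical: a flat (sum(lengths) x dim) tensor plus each example's length.
MyPackedSequence = namedtuple("MyPackedSequence", ["data", "lengths"])

def mypack(tensors, lengths):
    # Concatenate per-example (len_i x dim) tensors into one flat tensor.
    return MyPackedSequence(torch.cat(tensors, dim=0), lengths)

def myunpack(data, lengths):
    # Split the flat tensor back into per-example (len_i x dim) tensors.
    out, offset = [], 0
    for l in lengths:
        out.append(data[offset:offset + l])
        offset += l
    return out

Under this reading, myunpack(*mypackedhead) simply splits the packed root features so each graph can be log-softmaxed over its own nodes.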
Code example #2
 def getEmb(self, indexes, src_enc):
     # Gather, for each graph, the encoder states its AMR nodes align to.
     head_emb, lengths = [], []
     src_enc = myunpack(*src_enc)  # list of (src_l x dim) per sentence
     for i, index in enumerate(indexes):
         enc = src_enc[i]  # src_l x dim
         head_emb.append(enc[index])  # amr_l x dim, rows selected by alignment
         lengths.append(len(index))
     return mypack(head_emb, lengths)
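The gather in getEmb is plain integer indexing: for each sentence, it selects the encoder rows that the graph's nodes align to. A toy, self-contained demonstration of the pattern (hypothetical shapes and data):

import torch

src_l, dim = 7, 4
enc = torch.randn(src_l, dim)        # encoder states of one sentence
index = torch.tensor([0, 3, 3, 6])   # one source position per AMR node
head_emb = enc[index]                # (amr_l x dim): rows gathered by alignment
assert head_emb.size() == (4, dim)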
Code example #3
    def forward(self, _, heads, deps):
        '''heads.data: mypacked,       amr_l x rel_dim
           deps.data:  mydoublepacked, amr_l x amr_l x rel_dim
        '''
        heads_data = heads.data
        deps_data = deps.data

        # Head side of the biaffine score: all_data x (n_rel x input_size)
        head_bilinear_transformed = self.bilinear(heads_data)

        # Per-graph head bias terms: list of (len x n_rel)
        head_bias_unpacked = myunpack(self.head_bias(heads_data), heads.lengths)

        size = deps_data.size()
        dep_bias = self.dep_bias(deps_data.view(-1, size[-1])).view(size[0], size[1], -1)

        # Per-graph dependent bias terms: list of (len x n_rel)
        dep_bias_unpacked, length_pairs = mydoubleunpack(
            MyDoublePackedSequence(MyPackedSequence(dep_bias, deps[0][1]), deps[1], dep_bias))

        bilinear_unpacked = myunpack(head_bilinear_transformed, heads.lengths)

        deps_unpacked, length_pairs = mydoubleunpack(deps)
        output, l = self.bilinearForParallel(
            zip(bilinear_unpacked, deps_unpacked, head_bias_unpacked, dep_bias_unpacked),
            length_pairs)
        myscore_packed = mypack(output, l)

        return myscore_packed
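The packing machinery obscures a standard biaffine edge scorer: for every (head, dependent) pair the score is roughly head W dep plus separate head and dependent bias terms, one score per relation. A dense single-graph sketch of that computation, ignoring the packing (all names and shapes are hypothetical):

import torch

amr_l, rel_dim, n_rel = 5, 8, 10
heads = torch.randn(amr_l, rel_dim)          # one row per node, as head
deps = torch.randn(amr_l, amr_l, rel_dim)    # one row per (head, dep) pair
W = torch.randn(n_rel, rel_dim, rel_dim)

hW = torch.einsum("hi,rij->hrj", heads, W)       # transform the head side
scores = torch.einsum("hrj,hdj->hdr", hW, deps)  # contract with the dep side
print(scores.size())                             # torch.Size([5, 5, 10])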
Code example #4
    def posteriorIndictedEmb(self, embs, posterior):
        # A real alignment is sent in as a list of index tensors;
        # a variationally relaxed posterior is sent in as a MyPackedSequence.

        # out: (batch x amr_len) x src_len x (dim+1)
        embs, src_len = unpack(embs)

        if isinstance(posterior, MyPackedSequence):
            posterior = myunpack(*posterior)
            embs = embs.transpose(0, 1)
            out = []
            lengths = []
            amr_len = [len(p) for p in posterior]
            for i, emb in enumerate(embs):
                # Tile the sentence encoding once per AMR node.
                expanded_emb = emb.unsqueeze(0).expand([amr_len[i]] + list(emb.size()))  # amr_len x src_len x dim
                indicator = posterior[i].unsqueeze(2)  # amr_len x src_len x 1, soft alignment weights
                out.append(torch.cat([expanded_emb, indicator], 2))  # amr_len x src_len x (dim+1)
                lengths = lengths + [src_len[i]] * amr_len[i]
            data = torch.cat(out, dim=0)

            return pack(data, lengths, batch_first=True), amr_len

        elif isinstance(posterior, list):
            embs = embs.transpose(0, 1)
            src_l = embs.size(1)
            amr_len = [len(p) for p in posterior]
            out = []
            lengths = []
            for i, emb in enumerate(embs):
                amr_l = len(posterior[i])
                expanded_emb = emb.unsqueeze(0).expand([amr_l] + list(emb.size()))  # amr_len x src_len x dim
                # Hard alignment: one-hot indicator built with scatter_.
                indicator = emb.data.new(amr_l, src_l).zero_()
                indicator.scatter_(1, posterior[i].data.unsqueeze(1), 1.0)  # amr_len x src_len
                indicator = Variable(indicator.unsqueeze(2))  # amr_len x src_len x 1
                out.append(torch.cat([expanded_emb, indicator], 2))  # amr_len x src_len x (dim+1)
                lengths = lengths + [src_len[i]] * amr_l
            data = torch.cat(out, dim=0)

            return pack(data, lengths, batch_first=True), amr_len
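The list branch marks each node's aligned source position with a hard one-hot indicator built by scatter_, while the MyPackedSequence branch appends the soft posterior weights directly. The scatter_ trick in isolation (toy data):

import torch

amr_l, src_l = 3, 5
align = torch.tensor([4, 0, 2])                 # source position of each AMR node
indicator = torch.zeros(amr_l, src_l)
indicator.scatter_(1, align.unsqueeze(1), 1.0)  # row i gets a 1 at column align[i]
print(indicator)
# tensor([[0., 0., 0., 0., 1.],
#         [1., 0., 0., 0., 0.],
#         [0., 0., 1., 0., 0.]])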
Code example #5
    def forward(self, input, index, src_enc):
        assert isinstance(input, MyPackedSequence), input
        input, lengths = input
        if self.alpha and self.training:
            input = data_dropout(input, self.alpha)
        cat_embed = self.cat_lut(input[:, AMR_CAT])
        lemma_embed = self.lemma_lut(input[:, AMR_LE])

        amr_emb = torch.cat([cat_embed, lemma_embed], 1)

        head_emb_t, dep_emb_t, length_pairs = self.getEmb(index, src_enc)  # packed, mydoublepacked

        # Head representation: AMR node embedding concatenated with its source encoding.
        head_emb = torch.cat([amr_emb, head_emb_t.data], 1)

        # Tile each graph's node embeddings into an (amr_l x amr_l x dim) grid of pairs.
        dep_amr_emb_t = myunpack(*MyPackedSequence(amr_emb, lengths))
        dep_amr_emb = [emb.unsqueeze(0).expand(emb.size(0), emb.size(0), emb.size(-1))
                       for emb in dep_amr_emb_t]

        mydouble_amr_emb = mydoublepack(dep_amr_emb, length_pairs)

        dep_emb = torch.cat([mydouble_amr_emb.data, dep_emb_t.data], -1)

        head_packed = MyPackedSequence(self.head(head_emb), lengths)  # total x rel_dim
        head_amr_packed = MyPackedSequence(amr_emb, lengths)

        size = dep_emb.size()
        dep = self.dep(dep_emb.view(-1, size[-1])).view(size[0], size[1], -1)

        dep_packed = MyDoublePackedSequence(MyPackedSequence(dep, mydouble_amr_emb[0][1]),
                                            mydouble_amr_emb[1], dep)

        return head_amr_packed, head_packed, dep_packed
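The unsqueeze/expand step is what turns per-node embeddings into pairwise inputs for the dependent side: expand creates an (amr_l x amr_l x dim) view without copying memory, so every head row sees the embeddings of all possible dependents. In isolation (hypothetical toy data):

import torch

amr_l, dim = 3, 4
emb = torch.randn(amr_l, dim)
pairs = emb.unsqueeze(0).expand(amr_l, amr_l, dim)  # pairs[h, d] == emb[d]
assert torch.equal(pairs[0], emb) and torch.equal(pairs[2], emb)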
Code example #6
File: AMRProcessors.py  Project: Peacefulyang/AMR_GNN
    def relProbAndConToGraph(self,
                             srl_batch,
                             srl_prob,
                             roots,
                             appended,
                             get_sense=False,
                             set_wiki=False):
        # srl_batch: batch of concept sequences
        # srl_prob:  packed per-graph relation scores
        # roots:     root score per concept, per graph
        # returns:   batch of AMR graphs plus their quadruples
        rel_dict = self.dicts["rel_dict"]

        def get_uni_var(concepts, id):
            assert id < len(concepts), (id, concepts)
            uni = concepts[id]
            if uni.le in ["i", "it", "you", "they", "he", "she"
                          ] and uni.cat == Rule_Concept:
                return uni, Var(uni.le)
            le = uni.le
            if uni.cat != Rule_String:
                uni.le = uni.le.strip("/").strip(":")
                if ":" in uni.le or "/" in uni.le:
                    uni.cat = Rule_String
            if uni.le == "":
                return uni, Var(le + str(id))
            return uni, Var(uni.le[0] + str(id))

        def create_connected_graph(role_scores, concepts, root_id, dependent,
                                   aligns):
            #role_scores: amr x amr x rel
            graph = nx.DiGraph()
            n = len(concepts)
            role_scores = role_scores.view(n, n, -1).data
            max_non_score, max_non_score_id = role_scores[:, :, 1:].max(-1)
            max_non_score_id = max_non_score_id + 1
            non_score = role_scores[:, :, 0]
            active_cost = non_score - max_non_score  # lower cost = stronger edge, attached first
            candidates = []
            h_vs = []
            for h_id in range(n):
                h, h_v = get_uni_var(concepts, h_id)
                h_vs.append(h_v)
                graph.add_node(h_v,
                               value=h,
                               align=aligns[h_id],
                               gold=True,
                               dep=dependent[h_id])

            constant_links = {}
            normal_edged_links = {}
            for h_id in range(n):
                for d_id in range(n):
                    if h_id != d_id:
                        r = rel_dict.getLabel(max_non_score_id[h_id, d_id])
                        r_inver = r[:-3] if r.endswith("-of") else r + "-of"
                        h, h_v = get_uni_var(concepts, h_id)
                        d, d_v = get_uni_var(concepts, d_id)
                        if concepts[h_id].is_constant() or concepts[d_id].is_constant():
                            if concepts[h_id].is_constant() and concepts[d_id].is_constant():
                                continue
                            elif concepts[h_id].is_constant():
                                constant_links.setdefault(h_v, []).append(
                                    (active_cost[h_id, d_id], d_v, r, r_inver))
                            else:
                                constant_links.setdefault(d_v, []).append(
                                    (active_cost[h_id, d_id], h_v, r_inver, r))
                        elif active_cost[h_id, d_id] < 0:
                            if r in [":name-of", ":ARG0"] and concepts[d_id].le in ["person"]:
                                graph.add_edge(h_v, d_v, role=r)
                                graph.add_edge(d_v, h_v, role=r_inver)
                            else:
                                normal_edged_links.setdefault((h_v, r), []).append(
                                    (active_cost[h_id, d_id], d_v, r_inver))
                        else:
                            candidates.append(
                                (active_cost[h_id, d_id], (h_v, d_v, r, r_inver)))

            max_edge_per_node = 1 if not self.training else 100
            for h_v, r in normal_edged_links:
                sorted_list = sorted(normal_edged_links[h_v, r], key=lambda j: j[0])
                for _, d_v, r_inver in sorted_list[:max_edge_per_node]:
                    graph.add_edge(h_v, d_v, role=r)
                    graph.add_edge(d_v, h_v, role=r_inver)
                for cost, d_v, r_inver in sorted_list[max_edge_per_node:]:  # remaining become candidates
                    candidates.append((cost, (h_v, d_v, r, r_inver)))

            for h_v in constant_links:
                # Attach each constant to its single best-scoring neighbour.
                _, d_v, r, r_inver = sorted(constant_links[h_v], key=lambda j: j[0])[0]
                graph.add_edge(h_v, d_v, role=r)
                graph.add_edge(d_v, h_v, role=r_inver)

            candidates = sorted(candidates, key=lambda j: j[0])

            for _, (h_v, d_v, r, r_inver) in candidates:
                if nx.is_strongly_connected(graph):
                    break
                if not nx.has_path(graph, h_v, d_v):
                    graph.add_edge(h_v, d_v, role=r, force_connect=True)
                    graph.add_edge(d_v, h_v, role=r_inver, force_connect=True)

            _, root_v = get_uni_var(concepts, root_id)
            h_v = BOS_WORD
            root_symbol = AMRUniversal(BOS_WORD, BOS_WORD, NULL_WORD)
            graph.add_node(h_v, value=root_symbol, align=-1, gold=True, dep=1)
            graph.add_edge(h_v, root_v, role=":top")
            graph.add_edge(root_v, h_v, role=":top-of")

            if get_sense:
                for n, d in graph.nodes(True):
                    if "value" not in d:
                        print(n, d, graph[n], constant_links, graph.nodes, graph.edges)
                    le, cat, sense = d["value"].le, d["value"].cat, d["value"].sense
                    if cat == Rule_Frame and sense == "":
                        sense = self.fragment_to_node_converter.get_senses(le)
                        d["value"] = AMRUniversal(le, cat, sense)

            if not self.training:
                assert nx.is_strongly_connected(graph), (
                    "before contraction", self.graph_to_quadruples(graph), graph_to_amr(graph))
                graph = contract_graph(graph)

            if set_wiki:
                node_list = [[n, d] for n, d in graph.nodes(True)]  # avoid shadowing builtin list
                for n, d in node_list:
                    if d["value"].le == "name":
                        names = []
                        head = None
                        for nearb in graph[n]:
                            r = graph[n][nearb]["role"]
                            if (":op" in r and "-of" not in r
                                    and int(r[3:]) not in [t[1] for t in names]):  # skip duplicate :opN
                                names.append([graph.nodes[nearb]["value"], int(r[3:])])
                            if r == ":name-of":
                                wikied = False
                                for nearbb in graph[nearb]:
                                    if graph[nearb][nearbb]["role"] == ":wiki":
                                        wikied = True
                                        break
                                if not wikied:
                                    head = nearb
                        if head:
                            names = tuple(t[0] for t in sorted(names, key=lambda t: t[1]))
                            wiki = self.fragment_to_node_converter.get_wiki(names)
                            wiki_v = Var(wiki.le + n._name)
                            graph.add_node(wiki_v, value=wiki, align=d["align"],
                                           gold=True, dep=2)  # second-order dependency
                            graph.add_edge(head, wiki_v, role=":wiki")
                            graph.add_edge(wiki_v, head, role=":wiki-of")

            assert nx.is_strongly_connected(graph), (
                "after contraction", self.graph_to_quadruples(graph), graph_to_amr(graph))
            return graph, self.graph_to_quadruples(graph)

        graphs = []
        quadruple_batch = []
        score_batch = myunpack(*srl_prob)  # list of per-graph (head x dep) score tensors

        dependent_mark_batch = appended[0]
        aligns_batch = appended[1]
        for i, (role_scores, concepts, roots_score, dependent_mark, aligns) in enumerate(
                zip(score_batch, srl_batch, roots, dependent_mark_batch, aligns_batch)):
            root_s, root_id = roots_score.max(0)
            assert roots_score.size(0) == len(concepts), (concepts, roots_score)
            root_id = root_id.data.tolist()[0]
            assert root_id < len(concepts), (concepts, roots_score)

            g, quadruples = create_connected_graph(role_scores, concepts,
                                                   root_id, dependent_mark,
                                                   aligns)
            graphs.append(g)
            quadruple_batch.append(quadruples)

        return graphs, quadruple_batch
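The connectivity repair inside create_connected_graph is a simple greedy loop: sort the leftover candidate edges by cost and keep adding the cheapest mutual pair until the graph is strongly connected. The same idea on a toy graph (hypothetical data, real networkx calls):

import networkx as nx

graph = nx.DiGraph()
graph.add_edge("a", "b", role=":ARG0"); graph.add_edge("b", "a", role=":ARG0-of")
graph.add_node("c")  # not yet reachable
candidates = [(0.5, ("a", "c", ":ARG1", ":ARG1-of")),
              (0.9, ("b", "c", ":ARG2", ":ARG2-of"))]

for _, (h_v, d_v, r, r_inver) in sorted(candidates, key=lambda j: j[0]):
    if nx.is_strongly_connected(graph):
        break
    if not nx.has_path(graph, h_v, d_v):
        graph.add_edge(h_v, d_v, role=r, force_connect=True)
        graph.add_edge(d_v, h_v, role=r_inver, force_connect=True)

assert nx.is_strongly_connected(graph)  # only the cheapest candidate was needed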