Example #1
0
    def forward(self, g: Graph, out: str = None, greedy=True):
        """Encode graph ``g`` and decode a linearization of its triples.

        This is a generator. Every node, every distinct relation label, and
        every ``((n1, n2), relation)`` triple of ``g`` is encoded into a
        vector. Decoding then repeatedly picks one legal action -- start at a
        node, follow an unconsumed triple forwards (``>``) or backwards
        (``<``), or pop back up the node stack (``]``, or ``.`` when leaving
        the last open node) -- until every triple has been consumed.

        Args:
            g: Graph with ``nodes`` and ``edges``; ``edges`` maps
                ``(n1, n2)`` pairs to iterables of relation labels.
            out: Optional reference output. When truthy, decoding follows the
                tokens returned by ``self.fix_out(out)`` (teacher forcing) and
                the entity/relation dropout rates are passed to
                ``self.word_dropout``; otherwise a rate of 0 is used.
            greedy: Used only when ``out`` is falsy: pick the highest-scoring
                action if True, otherwise sample one with ``arg_sample`` from
                the softmax of the scores.

        Yields:
            When ``out`` is truthy: a ``dy.pickneglogsoftmax`` loss expression
            for every decision that had more than one legal action.
            Otherwise: the list of tokens emitted by each decision, followed
            by a final ``["]"]``.

        Raises:
            ValueError: If an action has an unknown type or edge direction, or
                a reference token is not among the legal actions. Any
                exception raised while choosing an action is re-raised after
                printing the decoding state for debugging.
        """
        out_tokens = self.fix_out(out)

        # ---- Encoding ----
        # Dropout is only applied when a reference output is given.
        nodes = {
            n: self.vocab.lookup(
                self.word_dropout(n, self.entity_dropout if out else 0))
            for n in g.nodes
        }
        unique_edges = set(chain.from_iterable(g.edges.values()))
        edges = {
            e: self.vocab.lookup(
                self.word_dropout(e, self.relation_dropout if out else 0))
            for e in unique_edges
        }

        # For each node: relation vectors of its incoming (index 0) and
        # outgoing (index 1) edges.
        node_connections = {node: ([], []) for node in g.nodes}
        for (n1, n2), es in g.edges.items():
            for e in es:
                edge_rep = edges[e]
                node_connections[n2][0].append(edge_rep)
                node_connections[n1][1].append(edge_rep)

        def ne_rep(reps):
            # Sum of neighbouring relation vectors, or self.no_ent when the
            # node has no edge in that direction.
            return dy.esum(reps) if len(reps) > 0 else self.no_ent

        node_reps = {
            n: self.entity_encoder * dy.concatenate([
                ne_rep(node_connections[n][0]), nodes[n],
                ne_rep(node_connections[n][1])
            ])
            for n in g.nodes
        }

        # One vector per ((n1, n2), relation) triple. rdfs maps each triple
        # to its own vector; decoding removes triples from it as they are
        # consumed. Both are filled in the same pass so every triple keys
        # exactly the vector computed for it (indexing by pair position
        # would mis-align them when a pair carries several relations).
        edge_reps = []
        rdfs = {}
        for (n1, n2), es in g.edges.items():
            for e in es:
                rep = self.entity_encoder * dy.concatenate(
                    [node_reps[n1], edges[e], node_reps[n2]])
                edge_reps.append(rep)
                rdfs[(n1, n2), e] = rep

        if not rdfs:
            # Nothing to decode; also avoids dy.average on an empty list.
            if not out:
                yield ["]"]
            return

        # ---- Decoding ----
        decoder = self.decoder.initial_state().add_input(
            dy.average(edge_reps))  # Initialize with something

        nodes_stack = []

        def choose(item):
            """Apply decoding action ``item`` and return the tokens it emits.

            Drops the head of ``out_tokens`` (if any), updates ``nodes_stack``
            and ``rdfs``, and advances the decoder by one input per emitted
            token.
            """
            nonlocal decoder

            if out_tokens:
                out_tokens.pop(0)

            if item[0] == "pop":
                nodes_stack.pop()
                res = [item[1]]
            elif item[0] == "node":
                nodes_stack.append(item[1])
                res = [item[1]]
            elif item[0] == "edge":
                _, d, e, n = item
                prev_node = nodes_stack[-1]
                nodes_stack.append(n)
                res = [d, e, "[", n]
                if d == ">":
                    del rdfs[(prev_node, n), e]
                elif d == "<":
                    del rdfs[(n, prev_node), e]
                else:
                    raise ValueError("direction can only be > or <. got " + d)
            else:
                raise ValueError("type can only be: pop, node, edge. got " +
                                 item[0])

            for w in res:
                if w in node_reps:
                    vec = node_reps[w]
                elif w in edges:
                    vec = edges[w]
                else:
                    vec = self.vocab.lookup(w)
                # add_input returns the next state (see initialization above);
                # keep it, otherwise the decoder never advances.
                decoder = decoder.add_input(vec)

            return res

        is_pop = False
        while len(rdfs) > 0:
            # Build the set of legal actions.
            if len(nodes_stack) == 0:
                # No open node: start at any node that still has an
                # unconsumed triple (deterministic first-seen order).
                is_pop = False
                vocab = {("node", n): node_reps[n]
                         for n in dict.fromkeys(
                             chain.from_iterable(ns for ns, _ in rdfs))}
            else:
                last_node = nodes_stack[-1]
                # Follow a triple forwards (current node is n1) or
                # backwards (current node is n2).
                f_edges = {("edge", ">", e, n2): rep
                           for ((n1, n2), e), rep in rdfs.items()
                           if n1 == last_node}
                b_edges = {("edge", "<", e, n1): rep
                           for ((n1, n2), e), rep in rdfs.items()
                           if n2 == last_node}
                vocab = {**f_edges, **b_edges}

                # Popping is not offered on the first decision after starting
                # at a fresh node; "." closes the last open node, "]" any other.
                if is_pop:
                    pop_char = "." if len(nodes_stack) == 1 else "]"
                    vocab[("pop", pop_char)] = self.vocab.lookup(pop_char)
                is_pop = True

            vocab_list = list(vocab.items())
            # String form of each action, compared against the reference
            # tokens, e.g. ("edge", ">", rel, node) -> ">_rel_node".
            vocab_index = ["_".join(i[1:]) for i, _ in vocab_list]

            try:
                if len(vocab_list) == 1:
                    # Forced move: no scoring and no loss.
                    if out:
                        assert out_tokens[0] == vocab_index[0]

                    choice = choose(vocab_list[0][0])
                    if not out:
                        yield choice

                    continue

                # Score each legal action by the dot product of its vector
                # with the current decoder output.
                vocab_matrix = dy.transpose(
                    dy.concatenate_cols([rep for _, rep in vocab_list]))
                pred_vec = vocab_matrix * decoder.output()
                if out:
                    best_i = vocab_index.index(out_tokens[0])
                    choose(vocab_list[best_i][0])
                    yield dy.pickneglogsoftmax(pred_vec, best_i)
                else:
                    if greedy:
                        best_i = int(np.argmax(pred_vec.npvalue()))
                    else:
                        best_i = arg_sample(
                            list(dy.softmax(pred_vec).npvalue()))

                    yield choose(vocab_list[best_i][0])
            except Exception:
                # Dump the decoding state to help diagnose reference/graph
                # mismatches, then propagate the original error.
                print()
                print("is_pop", is_pop)
                print("out", out)
                print("out tokens", out_tokens)
                print("vocab_index", vocab_index)
                print("original_rdf", g.as_rdf())
                print("rdf", list(rdfs.keys()))
                print()
                raise

        if not out:
            yield ["]"]
Example #2
0
 def convert_graph(self, g: Graph):
     """Build a new ``Graph`` from the triples of ``g`` with converted labels.

     Each ``(subject, relation, object)`` triple returned by ``g.as_rdf()``
     becomes ``(concat_entity(subject), self.convert_relation(relation),
     concat_entity(object))``, keeping the original triple order.
     """
     converted = []
     for subject, relation, obj in g.as_rdf():
         converted.append((
             concat_entity(subject),
             self.convert_relation(relation),
             concat_entity(obj),
         ))
     return Graph(converted)