def forward(self, g: Graph, out: str = None, greedy=True):
    """Encode graph ``g`` and decode a linearised plan one step at a time.

    This is a generator.  What it yields depends on ``out``:

    * ``out`` truthy (teacher forcing): for every step that offers more
      than one option, yields ``dy.pickneglogsoftmax(pred_vec, gold_index)``
      for the gold option taken from the tokens returned by
      ``self.fix_out(out)``.  Steps with a single forced option yield
      nothing.
    * ``out`` falsy (generation): yields the list of tokens emitted at each
      step, followed by a final ``["]"]``.

    Args:
        g: Graph exposing ``nodes`` (iterable of node names), ``edges``
            (mapping ``(n1, n2) -> iterable of relation labels``) and
            ``as_rdf()`` (only used for debug output).
        out: Target plan; also switches entity/relation word dropout on.
        greedy: In generation mode, take the arg-max option when true,
            otherwise sample an option via ``arg_sample`` from the softmax.

    Raises:
        ValueError: If an option has an unknown type or edge direction.
        Exception: Any error raised while taking a step is re-raised after
            the decoding state has been printed for debugging.
    """
    out_tokens = self.fix_out(out)

    # ---- Encoding ----
    # Word dropout is applied only when a target is supplied.
    entity_dropout = self.entity_dropout if out else 0
    relation_dropout = self.relation_dropout if out else 0

    nodes = {
        n: self.vocab.lookup(self.word_dropout(n, entity_dropout))
        for n in g.nodes
    }
    unique_edges = set(chain.from_iterable(g.edges.values()))
    edges = {
        e: self.vocab.lookup(self.word_dropout(e, relation_dropout))
        for e in unique_edges
    }

    # Per node: (reps of incoming relations, reps of outgoing relations).
    node_connections = {node: ([], []) for node in g.nodes}
    for (n1, n2), es in g.edges.items():
        for e in es:
            edge_rep = edges[e]
            node_connections[n2][0].append(edge_rep)
            node_connections[n1][1].append(edge_rep)

    def summed_or_placeholder(reps):
        # A node with no relations on one side uses the ``no_ent`` vector.
        return dy.esum(reps) if reps else self.no_ent

    node_reps = {
        n: self.entity_encoder * dy.concatenate([
            summed_or_placeholder(node_connections[n][0]),
            nodes[n],
            summed_or_placeholder(node_connections[n][1]),
        ])
        for n in g.nodes
    }

    # One key per (node pair, relation) triple.  Representations are built
    # from this same list so that keys and vectors always line up, even
    # when a node pair carries more than one relation.
    rdf_keys = [
        ((n1, n2), e)
        for (n1, n2), es in g.edges.items()
        for e in es
    ]
    edge_reps = [
        self.entity_encoder
        * dy.concatenate([node_reps[n1], edges[e], node_reps[n2]])
        for (n1, n2), e in rdf_keys
    ]

    # In decoding time we will remove 1 RDF at a time until none is left.
    rdfs = dict(zip(rdf_keys, edge_reps))

    # ---- Decoding ----
    # Initialize the decoder with the average triple representation.
    decoder = self.decoder.initial_state().add_input(dy.average(edge_reps))

    nodes_stack = []

    def choose(item):
        """Apply the chosen option and return the tokens it emits."""
        # ``add_input`` returns the new RNN state, so the result must be
        # kept; otherwise the decoder would never advance.
        nonlocal decoder

        if out_tokens:
            out_tokens.pop(0)

        kind = item[0]
        if kind == "pop":
            nodes_stack.pop()
            res = [item[1]]
        elif kind == "node":
            nodes_stack.append(item[1])
            res = [item[1]]
        elif kind == "edge":
            _, d, e, n = item
            prev_node = nodes_stack[-1]
            nodes_stack.append(n)
            res = [d, e, "[", n]
            # The traversed triple is consumed and cannot be chosen again.
            if d == ">":
                del rdfs[(prev_node, n), e]
            elif d == "<":
                del rdfs[(n, prev_node), e]
            else:
                raise ValueError("direction can only be > or <. got " + d)
        else:
            raise ValueError(
                "type can only be: pop, node, edge. got " + item[0])

        # Feed every emitted token to the decoder, preferring the encoded
        # node / relation vectors over a plain vocabulary lookup.
        for w in res:
            if w in node_reps:
                vec = node_reps[w]
            elif w in edges:
                vec = edges[w]
            else:
                vec = self.vocab.lookup(w)
            decoder = decoder.add_input(vec)

        return res

    is_pop = False
    while rdfs:
        # Possible vocab
        if not nodes_stack:
            # Nothing open: pick any node that still has a triple left.
            is_pop = False
            vocab = {
                ("node", n): node_reps[n]
                for n in set(chain.from_iterable(ns for ns, _ in rdfs))
            }
        else:
            last_node = nodes_stack[-1]
            f_edges = {
                ("edge", ">", e, n2): rep
                for ((n1, n2), e), rep in rdfs.items()
                if n1 == last_node
            }
            b_edges = {
                ("edge", "<", e, n1): rep
                for ((n1, n2), e), rep in rdfs.items()
                if n2 == last_node
            }
            vocab = {**f_edges, **b_edges}
            # Popping is not offered on the first step after the stack was
            # empty; "." closes the last open node, "]" any deeper one.
            if is_pop:
                pop_char = "." if len(nodes_stack) == 1 else "]"
                vocab[("pop", pop_char)] = self.vocab.lookup(pop_char)
            is_pop = True

        vocab_list = list(vocab.items())
        # Token form of each option, comparable with the target tokens.
        vocab_index = ["_".join(i[1:]) for i, _ in vocab_list]

        try:
            if len(vocab_list) == 1:
                # Forced move: nothing to score.
                if out:
                    assert out_tokens[0] == vocab_index[0]
                choice = choose(vocab_list[0][0])
                if not out:
                    yield choice
                continue

            vocab_matrix = dy.transpose(
                dy.concatenate_cols([rep for _, rep in vocab_list]))
            pred_vec = vocab_matrix * decoder.output()
            if out:
                best_i = vocab_index.index(out_tokens[0])
                choose(vocab_list[best_i][0])
                yield dy.pickneglogsoftmax(pred_vec, best_i)
            else:
                if greedy:
                    best_i = int(np.argmax(pred_vec.npvalue()))
                else:
                    best_i = arg_sample(
                        list(dy.softmax(pred_vec).npvalue()))
                yield choose(vocab_list[best_i][0])
        except Exception:
            # Dump the decoding state to make the failure reproducible,
            # then re-raise with the original traceback.
            print()
            print("is_pop", is_pop)
            print("out", out)
            print("out tokens", out_tokens)
            print("vocab_index", vocab_index)
            print("original_rdf", g.as_rdf())
            print("rdf", list(rdfs.keys()))
            print()
            raise

    if not out:
        yield ["]"]
def convert_graph(self, g: Graph):
    """Return a new ``Graph`` built from the converted triples of ``g``.

    Each ``(subject, relation, object)`` triple from ``g.as_rdf()`` has its
    subject and object passed through ``concat_entity`` and its relation
    through ``self.convert_relation``.
    """
    converted = []
    for subj, rel, obj in g.as_rdf():
        triple = (
            concat_entity(subj),
            self.convert_relation(rel),
            concat_entity(obj),
        )
        converted.append(triple)
    return Graph(converted)