def forward(self, *t_list):
    out_list = []
    for t in t_list:
        t.set_n_initializer(dgl.init.zero_initializer)

        # apply type module
        if self.type_module is not None:
            type_mask = (t.ndata['t'] != ConstValues.NO_ELEMENT)
            t.ndata['t_embs'] = self.type_module(t.ndata['t'] * type_mask) * type_mask.view(-1, 1)

        # apply input module
        if self.input_module is not None:
            x_mask = (t.ndata['x'] != ConstValues.NO_ELEMENT)
            t.ndata['x_mask'] = x_mask
            t.ndata['x_embs'] = self.input_module(t.ndata['x'] * x_mask) * x_mask.view(-1, 1)

        # propagate
        dgl.prop_nodes_topo(t,
                            message_func=self.cell_module.message_func,
                            reduce_func=self.cell_module.reduce_func,
                            apply_node_func=self.cell_module.apply_node_func)

        # return the hidden states; roots are the nodes with no outgoing edges
        h = t.ndata['h']
        if self.only_root_state:
            root_ids = [i for i in range(t.number_of_nodes()) if t.out_degree(i) == 0]
            out_list.append(h[root_ids])
        else:
            out_list.append(h)

    if self.output_module is not None:
        return self.output_module(*out_list)
    else:
        return out_list
def forward(self, graph: dgl.DGLGraph) -> torch.Tensor:
    x = self._dropout(graph.ndata["x"])

    # init matrices for message propagation
    number_of_nodes = graph.number_of_nodes()
    graph.ndata["x_iou"] = self._W_iou(x)
    graph.ndata["x_f"] = self._W_f(x)
    graph.ndata["h"] = graph.ndata["x"].new_zeros((number_of_nodes, self._encoder_size))
    graph.ndata["c"] = graph.ndata["x"].new_zeros((number_of_nodes, self._encoder_size))
    graph.ndata["Uh_sum"] = graph.ndata["x"].new_zeros((number_of_nodes, 3 * self._encoder_size))
    graph.ndata["fc_sum"] = graph.ndata["x"].new_zeros((number_of_nodes, self._encoder_size))

    # propagate nodes
    dgl.prop_nodes_topo(
        graph,
        message_func=self.message_func,
        reduce_func=self.reduce_func,
        apply_node_func=self.apply_node_func,
    )

    # [n nodes; encoder size]
    h = graph.ndata.pop("h")
    # [n nodes; decoder size]
    out = self._tanh(self._norm(self._out_linear(h)))
    return out
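# The forward above relies on message/reduce/apply callbacks defined elsewhere on the
# module. A minimal sketch of a trio consistent with its 'x_iou'/'x_f'/'Uh_sum'/'fc_sum'
# fields; the projections self._U_iou and self._U_f are assumptions, as is the child-sum
# gating. Pre-filling Uh_sum/fc_sum with zeros matters because leaves receive no
# messages, so apply_node_func reads the zero accumulators directly.
import torch

def message_func(self, edges):
    # each child sends U_iou(h) plus its forget-gated cell state
    h_j = edges.src['h']
    f = torch.sigmoid(edges.dst['x_f'] + self._U_f(h_j))
    return {'Uh': self._U_iou(h_j), 'fc': f * edges.src['c']}

def reduce_func(self, nodes):
    # accumulate the children's contributions
    return {'Uh_sum': torch.sum(nodes.mailbox['Uh'], 1),
            'fc_sum': torch.sum(nodes.mailbox['fc'], 1)}

def apply_node_func(self, nodes):
    # standard LSTM gates from the input term plus the children's sums
    i, o, u = torch.chunk(nodes.data['x_iou'] + nodes.data['Uh_sum'], 3, 1)
    i, o, u = torch.sigmoid(i), torch.sigmoid(o), torch.tanh(u)
    c = i * u + nodes.data['fc_sum']
    h = o * torch.tanh(c)
    return {'h': h, 'c': c}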
def forward(self, batch, h, c):
    '''Compute tree-lstm prediction given a batch.

    :param batch: dgl.data.SSTBatch
    :param h: Tensor, initial hidden state
    :param c: Tensor, initial cell state
    :return: logits: Tensor, the prediction of each node
    '''
    g = batch.graph
    # convert to a heterogeneous graph object
    g = dgl.graph(g.edges())
    # feed embedding
    embeds = self.embedding(batch.wordid * batch.mask)
    g.ndata['iou'] = self.cell.W_iou(self.dropout(embeds)) * batch.mask.float().unsqueeze(-1)
    g.ndata['h'] = h
    g.ndata['c'] = c
    # propagate
    dgl.prop_nodes_topo(g,
                        message_func=self.cell.message_func,
                        reduce_func=self.cell.reduce_func,
                        apply_node_func=self.cell.apply_node_func)
    # compute logits
    h = self.dropout(g.ndata.pop('h'))
    logits = self.linear(h)
    return logits
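# None of these snippets define the cell they delegate to. Below is a minimal
# child-sum Tree-LSTM cell that fits the 'iou'/'h'/'c' field names used above; the
# class name and the exact parameterization (b_iou, the shared U_f) are assumptions,
# following the structure of DGL's Tree-LSTM example rather than any one codebase.
import torch as th
import torch.nn as nn

class ChildSumTreeLSTMCell(nn.Module):
    def __init__(self, x_size, h_size):
        super().__init__()
        self.W_iou = nn.Linear(x_size, 3 * h_size, bias=False)
        self.U_iou = nn.Linear(h_size, 3 * h_size, bias=False)
        self.b_iou = nn.Parameter(th.zeros(1, 3 * h_size))
        self.U_f = nn.Linear(h_size, h_size)

    def message_func(self, edges):
        # children pass their hidden and cell states up to the parent
        return {'h': edges.src['h'], 'c': edges.src['c']}

    def reduce_func(self, nodes):
        # sum the children's hidden states; gate each child's cell state separately
        h_tild = th.sum(nodes.mailbox['h'], 1)
        f = th.sigmoid(self.U_f(nodes.mailbox['h']))
        c = th.sum(f * nodes.mailbox['c'], 1)
        # keep the W_iou(x) term the forward already stored under 'iou'
        return {'iou': nodes.data['iou'] + self.U_iou(h_tild), 'c': c}

    def apply_node_func(self, nodes):
        i, o, u = th.chunk(nodes.data['iou'] + self.b_iou, 3, 1)
        i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
        c = i * u + nodes.data['c']
        h = o * th.tanh(c)
        return {'h': h, 'c': c}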
def forward(self, batch, g, h, c): """Compute tree-lstm prediction given a batch. Parameters ---------- batch : dgl.data.SSTBatch The data batch. g : dgl.DGLGraph Tree for computation. h : Tensor Initial hidden state. c : Tensor Initial cell state. Returns ------- logits : Tensor The prediction of each node. """ # feed embedding embeds = self.embedding(batch['x'].val.squeeze(-1)) g.ndata['iou'] = self.cell.W_iou(self.dropout(embeds)) * batch['mask'].float().unsqueeze(-1) g.ndata['h'] = h g.ndata['c'] = c # propagate dgl.prop_nodes_topo(g, self.cell.message_func, self.cell.reduce_func, apply_node_func=self.cell.apply_node_func) # compute logits h = self.dropout(g.ndata.pop('h')) logits = self.linear(h) return logits
def forward(self, batch, g, h, c): """Compute tree-lstm prediction given a batch. Parameters ---------- batch : dgl.data.SSTBatch The data batch. h : Tensor Initial hidden state. c : Tensor Initial cell state. Returns ------- out """ g.ndata['iou'] = self.cell.W_iou(batch.X) g.ndata['h'] = h g.ndata['c'] = c dgl.prop_nodes_topo(g, message_func=self.cell.message_func, reduce_func=self.cell.reduce_func, apply_node_func=self.cell.apply_node_func) h = g.ndata.pop('h') head_ids = th.nonzero(batch.isroot, as_tuple=False).flatten() head_h = th.index_select(h, 0, head_ids) out = head_h return out
def forward(self, batch, enc_hidden, list_root_index, list_num_node):
    g = batch.graph
    g.register_message_func(self.cell.message_func)
    g.register_reduce_func(self.cell.reduce_func)
    g.register_apply_node_func(self.cell.apply_node_func)

    g.ndata['h'] = enc_hidden[0]
    g.ndata['c'] = enc_hidden[1]
    wemb = self.wemb(batch.wordid * batch.mask)
    g.ndata['iou'] = self.cell.W_iou(wemb) * batch.mask.float().unsqueeze(-1)

    dgl.prop_nodes_topo(g)

    all_node_h_in_batch = g.ndata.pop('h')
    all_node_c_in_batch = g.ndata.pop('c')

    if self.opt.tree_lstm_output_type == "root_node":
        # root indices are local to each tree, so shift them by the number of
        # nodes preceding the tree in the batched graph (see the sketch below)
        root_node_h_in_batch, root_node_c_in_batch = [], []
        add_up_num_node = 0
        for _i in range(len(list_root_index)):
            if _i > 0:
                add_up_num_node += list_num_node[_i - 1]
            idx_to_query = list_root_index[_i] + add_up_num_node
            root_node_h_in_batch.append(all_node_h_in_batch[idx_to_query])
            root_node_c_in_batch.append(all_node_c_in_batch[idx_to_query])
        root_node_h_in_batch = torch.cat(root_node_h_in_batch).reshape(1, len(root_node_h_in_batch), -1)
        root_node_c_in_batch = torch.cat(root_node_c_in_batch).reshape(1, len(root_node_c_in_batch), -1)
        return root_node_h_in_batch, root_node_c_in_batch
    elif self.opt.tree_lstm_output_type == "no_reduce":
        return all_node_h_in_batch, all_node_c_in_batch
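# The root-gathering loop above can be vectorized with cumulative offsets. A hedged
# equivalent, assuming list_root_index and list_num_node are plain Python lists as
# the loop suggests:
import torch

offsets = torch.tensor([0] + list_num_node[:-1]).cumsum(0)
root_ids = torch.tensor(list_root_index) + offsets
root_node_h_in_batch = all_node_h_in_batch[root_ids].unsqueeze(0)  # [1, n trees, h]
root_node_c_in_batch = all_node_c_in_batch[root_ids].unsqueeze(0)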
def forward(self, g, x, mask):
    """Compute tree-lstm prediction given a batch.

    Parameters
    ----------
    g : dgl.DGLGraph
        Tree for computation.
    x : Tensor
        Input node features.
    mask : Tensor
        Mask marking the nodes that carry an input.

    Returns
    -------
    out : Tensor
        The prediction of each node.
    """
    g.register_message_func(self.cell.message_func)
    g.register_reduce_func(self.cell.reduce_func)
    g.register_apply_node_func(self.cell.apply_node_func)
    # feed embedding
    # TODO: x * mask does not make sense. use x[mask]
    embeds = self.input_module(x * mask)
    g.ndata['iou_input'] = self.W_iou(embeds) * mask.float().unsqueeze(-1)
    g.ndata['f_input'] = self.W_f(embeds) * mask.float().unsqueeze(-1)
    # propagate
    dgl.prop_nodes_topo(g)
    # compute output
    h = g.ndata['h']
    out = self.output_module(h)
    return out
def forward(self, batch, h, c): """Compute tree-lstm prediction given a batch. Parameters ---------- batch : dgl.data.SSTBatch The data batch. h : Tensor Initial hidden state. c : Tensor Initial cell state. Returns ------- logits : Tensor The prediction of each node. """ g = batch.graph # feed embedding embeds = self.embedding(batch.wordid * batch.mask) wiou = self.cell.W_iou(self.dropout(embeds)) g.ndata['iou'] = wiou * batch.mask.expand_dims(-1).astype(wiou.dtype) g.ndata['h'] = h g.ndata['c'] = c # propagate dgl.prop_nodes_topo(g, message_func=self.cell.message_func, reduce_func=self.cell.reduce_func, apply_node_func=self.cell.apply_node_func) # compute logits h = self.dropout(g.ndata.pop('h')) logits = self.linear(h) return logits
def forward(self, batch, h, c): """Compute tree-lstm prediction given a batch. Parameters ---------- batch : dgl.data.SSTBatch The data batch. h : Tensor Initial hidden state. c : Tensor Initial cell state. Returns ------- logits : Tensor The prediction of each node. """ g = batch.graph g.register_message_func(self.cell.message_func) g.register_reduce_func(self.cell.reduce_func) g.register_apply_node_func(self.cell.apply_node_func) # feed embedding embeds = self.embedding(batch.wordid * batch.mask) g.ndata['iou'] = self.cell.W_iou( self.dropout(embeds)) * batch.mask.float().unsqueeze(-1) g.ndata['h'] = h g.ndata['c'] = c # propagate dgl.prop_nodes_topo(g) # compute logits h = self.dropout(g.ndata.pop('h')) logits = self.linear(h) return logits
def forward(self, src, enc_tree, dec_tree, return_code=False, **kwargs):
    self._init_graph(enc_tree, dec_tree)

    # Source encoding
    src_mask = torch.ne(src, 0).float()
    encoder_outputs = self.encode_source(src, src_mask, meanpool=True)
    encoder_states = encoder_outputs["encoder_states"]

    # Tree encoding
    enc_x = enc_tree.ndata["x"].cuda()
    x_embeds = self.label_embed_layer(enc_x)
    enc_tree.ndata['iou'] = self.enc_cell.W_iou(self.dropout(x_embeds))
    enc_tree.ndata['h'] = torch.zeros((enc_tree.number_of_nodes(), self.hidden_size)).cuda()
    enc_tree.ndata['c'] = torch.zeros((enc_tree.number_of_nodes(), self.hidden_size)).cuda()
    enc_tree.ndata['mask'] = enc_tree.ndata['mask'].float().cuda()
    dgl.prop_nodes_topo(enc_tree)

    # Obtain root representation
    root_mask = enc_tree.ndata["mask"].float().cuda()
    # root_idx = torch.arange(root_mask.shape[0])[root_mask > 0].cuda()
    root_h = self.dropout(enc_tree.ndata.pop("h")) * root_mask.unsqueeze(-1)
    orig_h = root_h.clone()[root_mask > 0]
    partial_h = orig_h
    if self._without_source:
        partial_h += encoder_states

    # Discretization
    if self._code_bits > 0:
        if return_code:
            codes = self.semhash(partial_h, return_code=True)
            ret = {"codes": codes}
            return ret
        else:
            partial_h = self.semhash(partial_h)
    if not self._without_source:
        partial_h += encoder_states
    root_h[root_mask > 0] = partial_h

    # Tree decoding
    dec_x = dec_tree.ndata["x"].cuda()
    dec_embeds = self.label_embed_layer(dec_x)
    dec_tree.ndata['iou'] = self.dec_cell.W_iou(self.dropout(dec_embeds))
    dec_tree.ndata['h'] = root_h
    dec_tree.ndata['c'] = torch.zeros((dec_tree.number_of_nodes(), self.hidden_size)).cuda()
    dec_tree.ndata['mask'] = dec_tree.ndata['mask'].float().cuda()
    dgl.prop_nodes_topo(dec_tree)

    # Compute logits
    all_h = self.dropout(dec_tree.ndata.pop("h"))
    logits = self.logit_nn(all_h)
    logp = F.log_softmax(logits, 1)

    # Compute loss
    y_labels = dec_tree.ndata["y"].cuda()
    monitor = {}
    loss = F.nll_loss(logp, y_labels, reduction="mean")
    acc = (logits.argmax(1) == y_labels).float().mean()
    monitor["loss"] = loss
    monitor["label_accuracy"] = acc
    return monitor
def forward(self, batch): """Compute tree-lstm prediction given a batch. Parameters ---------- batch : dgl.data.SSTBatch The data batch. h : Tensor Initial hidden state. c : Tensor Initial cell state. Returns ------- logits : Tensor The prediction of each node. """ #----------utils function--------------- def InitS(tree): tree.ndata['s'] = tree.ndata['e'].mean(dim=0).repeat(tree.number_of_nodes(), 1) return tree def updateS(tree, state): assert state.dim() == 1 tree.ndata['s'] = state.repeat(tree.number_of_nodes(), 1) return tree def extractS(batchTree): # [dmodel] --> [[dmodel]] --> [tree, dmodel] --> [tree, 1, dmodel] s_list = [tree.ndata.pop('s')[0].unsqueeze(0) for tree in dgl.unbatch(batchTree)] return th.cat(s_list, dim=0).unsqueeze(1) def extractH(batchTree): # [nodes, dmodel] --> [nodes, dmodel]--> [max_nodes, dmodel]--> [tree*_max_nodes, dmodel] --> [tree, max_nodes, dmodel] h_list = [tree.ndata.pop('h') for tree in dgl.unbatch(batchTree)] max_nodes = max([h.size(0) for h in h_list]) h_list = [th.cat([h, th.zeros([max_nodes-h.size(0), h.size(1)]).to(self.device)], dim=0).unsqueeze(0) for h in h_list] return th.cat(h_list, dim=0) #----------------------------------------- g = batch.graph # feed embedding embeds = self.embedding(batch.wordid * batch.mask) g.ndata['c'] = th.zeros((g.number_of_nodes(), 2, self.dmodel)).to(self.device) g.ndata['e'] = embeds*batch.mask.float().unsqueeze(-1) g.ndata['h'] = embeds*batch.mask.float().unsqueeze(-1) g = dgl.batch([InitS(gg) for gg in dgl.unbatch(g)]) # propagate for i in range(self.T_step): g.register_message_func(self.cell.message_func) g.register_reduce_func(self.cell.reduce_func) g.register_apply_node_func(self.cell.apply_node_func) dgl.prop_nodes_topo(g) States = self.cell.updateGlobalVec(extractS(g), extractH(g) ) g = dgl.batch([updateS(tree, state) for (tree, state) in zip(dgl.unbatch(g), States)]) # compute logits h = self.dropout(g.ndata.pop('h')) logits = self.linear(h) return logits
def propagate(self, g, cell, X, h, c):
    g.ndata['iou'] = cell.W_iou(X)
    g.ndata['h'] = h
    g.ndata['c'] = c
    dgl.prop_nodes_topo(g,
                        message_func=cell.message_func,
                        reduce_func=cell.reduce_func,
                        apply_node_func=cell.apply_node_func)
    return g.ndata.pop('h')
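# A tiny driver for the pattern this helper wraps, reusing the hedged
# ChildSumTreeLSTMCell sketch above; the tree, sizes, and inputs are made up
# for illustration.
import dgl
import torch as th

g = dgl.graph((th.tensor([1, 2]), th.tensor([0, 0])))  # edges child -> parent; node 0 is the root
x_size, h_size = 8, 16
cell = ChildSumTreeLSTMCell(x_size, h_size)

g.ndata['iou'] = cell.W_iou(th.randn(g.number_of_nodes(), x_size))
g.ndata['h'] = th.zeros(g.number_of_nodes(), h_size)
g.ndata['c'] = th.zeros(g.number_of_nodes(), h_size)
dgl.prop_nodes_topo(g,
                    message_func=cell.message_func,
                    reduce_func=cell.reduce_func,
                    apply_node_func=cell.apply_node_func)
print(g.ndata['h'].shape)  # torch.Size([3, 16])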
def run(cell, graph, iou, h, c):
    g = graph
    g.register_message_func(cell.message_func)
    g.register_reduce_func(cell.reduce_func)
    g.register_apply_node_func(cell.apply_node_func)
    # feed embedding
    g.ndata["iou"] = iou
    g.ndata["h"] = h
    g.ndata["c"] = c
    # propagate
    dgl.prop_nodes_topo(g)
    return g.ndata.pop("h")
def propagate(self, g, cell, X, h, c):
    g.register_message_func(cell.message_func)
    g.register_reduce_func(cell.reduce_func)
    g.register_apply_node_func(cell.apply_node_func)
    # g.ndata['iou'] = cell.d(cell.W_iou(X))
    g.ndata['iou'] = cell.W_iou(X)
    g.ndata['h'] = h
    g.ndata['c'] = c
    # propagate
    dgl.prop_nodes_topo(g)
    return g.ndata.pop('h')
def forward(self, batch: BatchedTree):
    batches = [deepcopy(batch) for _ in range(self.num_stacks)]
    for stack in range(self.num_stacks):
        cur_batch = batches[stack]
        if stack > 0:
            # feed the previous stack's hidden states in as this stack's input
            prev_batch = batches[stack - 1]
            cur_batch.batch_dgl_graph.ndata['x'] = prev_batch.batch_dgl_graph.ndata['h']
        cur_batch.batch_dgl_graph.register_message_func(self.cell.message_func)
        cur_batch.batch_dgl_graph.register_reduce_func(self.cell.reduce_func)
        cur_batch.batch_dgl_graph.register_apply_node_func(self.cell.apply_node_func)
        # use cur_batch here (not the original batch) so stacked layers actually
        # consume the previous layer's output
        cur_batch.batch_dgl_graph.ndata['ruo'] = self.cell.W_ruo(self.dropout(cur_batch.batch_dgl_graph.ndata['x']))
        dgl.prop_nodes_topo(cur_batch.batch_dgl_graph)
    return batches
def forward(self, rnn_inputs: th.Tensor, graph: dgl.DGLGraph):
    g = graph
    g.register_message_func(self.tree_lstm.message_func)
    g.register_reduce_func(self.tree_lstm.reduce_func)
    g.register_apply_node_func(self.tree_lstm.apply_node_func)
    g.ndata['iou'] = self.tree_lstm.W_iou(rnn_inputs)
    g.ndata['h'] = self.linear(rnn_inputs)
    # th.autograd.Variable is deprecated; a plain zero tensor behaves the same
    g.ndata['c'] = th.zeros((graph.number_of_nodes(), self.output_dim),
                            dtype=th.float32, device=self.device)
    dgl.prop_nodes_topo(g)
    return g.ndata.pop('h')
def forward(self, graph: dgl.DGLGraph) -> torch.Tensor: """Apply transformer encoder :param graph: batched dgl graph :return: encoded nodes [number of nodes, hidden size] """ graph.ndata['h'] = graph.ndata['x'].new_zeros( (graph.number_of_nodes(), self.h_enc)) dgl.prop_nodes_topo(graph, message_func=[dgl.function.copy_u('h', 'h')], reduce_func=self.reduce_func, apply_node_func=self.apply_node_func) return graph.ndata.pop('h')
def forward(self, batch):
    g = batch.graph
    # set the cell's functions as the graph's default message/reduce functions
    g.register_message_func(self.cell.message_func)
    g.register_reduce_func(self.cell.reduce_func)
    embeds = self.embedding(batch.wordid * batch.mask)
    # dropout here improves BestRoot_acc_test
    g.ndata['p'] = self.dropout(embeds)
    # trigger message passing over node frontiers generated in topological order
    dgl.prop_nodes_topo(g)
    # pop() gets and removes the node states from the graph, saving memory
    p = self.dropout(g.ndata.pop('p'))
    # compute logits
    logits = self.W_label(p)
    return logits
def forward(self, batch, h, c): """Compute tree-lstm prediction given a batch. Parameters ---------- batch : dgl.data.SSTBatch The data batch. h : Tensor Initial hidden state. c : Tensor Initial cell state. Returns ------- logits : Tensor The prediction of each node. """ g = batch.graph g.register_message_func(self.cell.message_func) g.register_reduce_func(self.cell.reduce_func) g.register_apply_node_func(self.cell.apply_node_func) # feed embedding # embeds = self.embedding(batch.wordid) # g.ndata['iou'] = self.cell.W_iou(self.dropout(embeds)) g.ndata['iou'] = self.cell.W_iou(batch.wordid) g.ndata['h'] = h g.ndata['c'] = c # propagate dgl.prop_nodes_topo(g) # compute logits h = self.dropout(g.ndata.pop('h')) # logits = self.linear(h) # attention part if self.attention: root_node_h_in_batch = [] root_idx = [] result_idx = 0 for idx in g.batch_num_nodes: root_node_h_in_batch.append(h[result_idx]) root_idx.append(result_idx) result_idx = result_idx + idx root_node_h_in_batch = th.cat(root_node_h_in_batch).reshape(len(root_idx), -1) node_num = th.tensor(batch.graph.batch_num_nodes) node_num = node_num.to(root_node_h_in_batch.device) h = self.tree_attn(h, root_node_h_in_batch, node_num) h = F.relu(self.wh(h)) logits = self.linear(h) return logits
def forward(self, batch, h, c):
    g = batch.graph
    g.register_message_func(self.cell.message_func)
    g.register_reduce_func(self.cell.reduce_func)
    g.register_apply_node_func(self.cell.apply_node_func)
    # feed embedding
    embeds = self.embedding(batch.wordid * batch.mask)
    g.ndata['iou'] = self.cell.W_iou(self.dropout(embeds)) * batch.mask.float().unsqueeze(-1)
    g.ndata['h'] = h
    g.ndata['c'] = c
    # propagate
    dgl.prop_nodes_topo(g)
    # compute logits
    h = self.dropout(g.ndata.pop('h'))
    logits = self.linear(h)
    return logits
def forward(self, graph: dgl.BatchedDGLGraph, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
    number_of_nodes = graph.number_of_nodes()
    graph.ndata['x_iou'] = self.W_iou(graph.ndata['x']) + self.b_iou
    graph.ndata['x_f'] = self.W_f(graph.ndata['x']) + self.b_f
    graph.ndata['h'] = torch.zeros((number_of_nodes, self.h_size), device=device)
    graph.ndata['c'] = torch.zeros((number_of_nodes, self.h_size), device=device)
    graph.ndata['Uh_sum'] = torch.zeros((number_of_nodes, 3 * self.h_size), device=device)
    graph.ndata['fc_sum'] = torch.zeros((number_of_nodes, self.h_size), device=device)

    graph.register_message_func(self.message_func)
    graph.register_apply_node_func(self.apply_node_func)
    dgl.prop_nodes_topo(graph, reduce_func=[fn.sum('Uh', 'Uh_sum'), fn.sum('fc', 'fc_sum')])

    h = graph.ndata.pop('h')
    c = graph.ndata.pop('c')
    return h, c
def forward(self, graph: dgl.BatchedDGLGraph, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
    # register functions for message passing
    graph.register_reduce_func(self.reduce_func)
    graph.register_apply_node_func(self.apply_node_func)

    features = graph.ndata['x']
    nodes_in_batch = graph.number_of_nodes()
    graph.ndata['node_iou'] = self.W_iou(features) + self.b_iou
    graph.ndata['node_f'] = self.W_f(features) + self.b_f
    graph.ndata['h'] = torch.zeros(nodes_in_batch, self.h_size).to(device)
    graph.ndata['c'] = torch.zeros(nodes_in_batch, self.h_size).to(device)
    graph.ndata['Uh_tilda'] = torch.zeros(nodes_in_batch, 3 * self.h_size).to(device)

    # propagate
    dgl.prop_nodes_topo(graph, message_func=[fn.copy_u('h', 'h'), fn.copy_u('c', 'c')])

    # get encoded output
    h = graph.ndata.pop('h')
    c = graph.ndata.pop('c')
    return h, c
def forward(self, batch, h, c):
    g = batch.graph
    g.register_message_func(self.cell.message_func)
    g.register_reduce_func(self.cell.reduce_func)
    g.register_apply_node_func(self.cell.apply_node_func)
    # (num_nodes, 256)
    embeds = self.embedding(batch.wordid * batch.mask)
    g.ndata['iou'] = self.cell.W_iou(self.dropout(embeds)) * batch.mask.float().unsqueeze(-1)
    g.ndata['h'] = h
    g.ndata['c'] = c
    dgl.prop_nodes_topo(g)
    # note: x and y below are unused debugging leftovers
    x = unbatch(g)
    y = x[0].nodes[0].data['h']
    h = self.dropout(g.ndata.pop('h'))
    logits = self.linear(h)
    return logits
def test_prop_nodes_topo():
    # bi-directional chain
    g = dgl.DGLGraph(nx.path_graph(5))

    # tree
    tree = dgl.DGLGraph()
    tree.add_nodes(5)
    tree.add_edge(1, 0)
    tree.add_edge(2, 0)
    tree.add_edge(3, 2)
    tree.add_edge(4, 2)
    tree.register_message_func(mfunc)
    tree.register_reduce_func(rfunc)
    # init node feature data
    tree.ndata['x'] = mx.nd.zeros(shape=(5, 2))
    # set all leaf nodes to be ones
    tree.nodes[[1, 3, 4]].data['x'] = mx.nd.ones(shape=(3, 2))
    dgl.prop_nodes_topo(tree)
    # the root node gets the sum
    assert np.allclose(tree.nodes[0].data['x'].asnumpy(), np.array([[3., 3.]]))
def test_prop_nodes_topo(idtype):
    # bi-directional chain
    g = dgl.graph(nx.path_graph(5), idtype=idtype, device=F.ctx())
    assert U.check_fail(dgl.prop_nodes_topo, g)  # has loop

    # tree
    tree = dgl.DGLGraph()
    tree.add_nodes(5)
    tree.add_edge(1, 0)
    tree.add_edge(2, 0)
    tree.add_edge(3, 2)
    tree.add_edge(4, 2)
    tree = dgl.graph(tree.edges())
    # init node feature data
    tree.ndata['x'] = F.zeros((5, 2))
    # set all leaf nodes to be ones
    tree.nodes[[1, 3, 4]].data['x'] = F.ones((3, 2))
    dgl.prop_nodes_topo(tree,
                        message_func=mfunc,
                        reduce_func=rfunc,
                        apply_node_func=None)
    # the root node gets the sum
    assert F.allclose(tree.nodes[0].data['x'], F.tensor([[3., 3.]]))
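# These tests rely on mfunc/rfunc defined elsewhere in the test module. Judging by
# the asserted root value (each parent ends up with the sum of its children's 'x'),
# compatible definitions would copy 'x' along edges and sum at the destination.
# This is a sketch, not necessarily the test suite's actual code:
import dgl.function as fn

mfunc = fn.copy_u('x', 'm')  # send the source node's 'x' as message 'm'
rfunc = fn.sum('m', 'x')     # overwrite the destination's 'x' with the summed messages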
def forward(self, g, h, c):
    g.register_message_func(self.cell.message_func)
    g.register_reduce_func(self.cell.reduce_func)
    g.register_apply_node_func(self.cell.apply_node_func)

    device = next(self.parameters()).device
    self.device = device
    self.cell.to(self.device)

    # init of data
    g.ndata['iou'] = self.cell.W_iou(g.ndata['x']).float().to(self.device)
    g.ndata['wf_x'] = self.cell.W_f(g.ndata['x']).float().to(self.device)
    g.ndata['iou_mid'] = torch.zeros(g.number_of_nodes(), 3 * self.h_size).to(self.device)
    g.ndata['h'] = h.to(self.device)
    g.ndata['c'] = c.to(self.device)

    dgl.prop_nodes_topo(g)
    return g
def forward(self, batch, h, c): """Compute tree-lstm prediction given a batch. Parameters ---------- batch : dgl.data.SSTBatch The data batch. h : Tensor Initial hidden state. c : Tensor Initial cell state. Returns ------- out """ g = batch.graph g.register_message_func(self.cell.message_func) g.register_reduce_func(self.cell.reduce_func) g.register_apply_node_func(self.cell.apply_node_func) g.ndata['iou'] = self.cell.W_iou(batch.X) g.ndata['h'] = h g.ndata['c'] = c # propagate dgl.prop_nodes_topo(g) # compute logits h = g.ndata.pop('h') # indexes of root nodes head_ids = batch.isroot.nonzero().flatten() # h of root nodes head_h = th.index_select(h, 0, head_ids) lims_ids = head_ids.tolist() + [g.number_of_nodes()] # average of h of non root node by tree inner_h = th.cat([ th.mean(h[s + 1:e - 1, :], dim=0).view(1, -1) for s, e in zip(lims_ids[:-1], lims_ids[1:]) ]) out = th.cat([head_h, inner_h], dim=1) #out = head_h return out
def forward(self, graph: dgl.DGLGraph) -> Tuple[torch.Tensor, torch.Tensor]:
    x = self.dropout(graph.ndata['x'])
    for layer in range(self.n_layers):
        graph.ndata['x'] = x
        graph = self.cell[layer].init_matrices(graph)
        dgl.prop_nodes_topo(
            graph,
            message_func=self.cell[layer].get_message_func(),
            reduce_func=self.cell[layer].get_reduce_func(),
            apply_node_func=self.cell[layer].get_apply_node_func()
        )
        if self.residual:
            x = self.norm[layer](graph.ndata.pop('h') + x)
        else:
            x = graph.ndata.pop('h')
    c = graph.ndata.pop('c')
    return x, c
def test_prop_nodes_topo():
    # bi-directional chain
    g = dgl.DGLGraph(nx.path_graph(5))
    assert U.check_fail(dgl.prop_nodes_topo, g)  # has loop

    # tree
    tree = dgl.DGLGraph()
    tree.add_nodes(5)
    tree.add_edge(1, 0)
    tree.add_edge(2, 0)
    tree.add_edge(3, 2)
    tree.add_edge(4, 2)
    tree.register_message_func(mfunc)
    tree.register_reduce_func(rfunc)
    # init node feature data
    tree.ndata['x'] = th.zeros((5, 2))
    # set all leaf nodes to be ones
    tree.nodes[[1, 3, 4]].data['x'] = th.ones((3, 2))
    dgl.prop_nodes_topo(tree)
    # the root node gets the sum
    assert U.allclose(tree.nodes[0].data['x'], th.tensor([[3., 3.]]))
def forward(self, batch):
    g = batch.graph
    n = g.number_of_nodes()
    h = torch.zeros((n, self.nhid))
    c = torch.zeros((n, self.nhid))
    g.register_message_func(self.cell.message_func)
    g.register_reduce_func(self.cell.reduce_func)
    g.register_apply_node_func(self.cell.apply_node_func)
    embeds = self.embedding(batch.wordid * batch.mask)
    g.ndata['iou'] = self.cell.W_iou(self.dropout(embeds)) * batch.mask.float().unsqueeze(-1)
    g.ndata['h'] = h
    g.ndata['c'] = c
    dgl.prop_nodes_topo(g)
    # collect each tree's root state (node 0 of every unbatched tree)
    un_g = dgl.unbatch(g)
    root_list = list()
    for gh in un_g:
        root_list.append(gh.nodes[0].data['h'])
    output = torch.cat(root_list, 0)
    # the pop frees the node states; the dropout result is discarded
    self.dropout(g.ndata.pop('h'))
    return output