Esempio n. 1
0
    def forward(self, *t_list):
        """Propagate through every tree in ``t_list`` and collect node states.

        For each graph ``t``: optional type and input embeddings are computed
        from ``t.ndata['t']`` / ``t.ndata['x']`` with entries equal to
        ``ConstValues.NO_ELEMENT`` masked out, the cell module is run in
        topological order, and the resulting ``t.ndata['h']`` is collected --
        for every node, or only for nodes with out-degree 0 when
        ``self.only_root_state`` is set.

        :param t_list: one or more DGL graphs.
        :return: ``self.output_module(*out_list)`` if an output module is set,
            otherwise the list of collected hidden-state tensors (one per graph).
        """
        out_list = []
        for t in t_list:
            t.set_n_initializer(dgl.init.zero_initializer)

            # apply type module: masked entries are zeroed before being fed to
            # the module, and the module's output rows for them are zeroed too
            if self.type_module is not None:
                type_mask = (t.ndata['t'] != ConstValues.NO_ELEMENT)
                t.ndata['t_embs'] = self.type_module(t.ndata['t'] * type_mask) * type_mask.view(-1, 1)

            # apply input module (same masking scheme; the mask is also stored)
            if self.input_module is not None:
                x_mask = (t.ndata['x'] != ConstValues.NO_ELEMENT)
                t.ndata['x_mask'] = x_mask
                t.ndata['x_embs'] = self.input_module(t.ndata['x'] * x_mask) * x_mask.view(-1, 1)

            # propagate
            dgl.prop_nodes_topo(t,
                                message_func=self.cell_module.message_func,
                                reduce_func=self.cell_module.reduce_func,
                                apply_node_func=self.cell_module.apply_node_func)

            h = t.ndata['h']

            if self.only_root_state:
                # Roots are the nodes without outgoing edges. One vectorised
                # out_degrees() call plus a boolean mask selects the same rows,
                # in the same ascending-id order, as querying out_degree(i)
                # node by node -- without N separate graph calls.
                is_root = (t.out_degrees() == 0).to(h.device)
                out_list.append(h[is_root])
            else:
                out_list.append(h)

        if self.output_module is not None:
            return self.output_module(*out_list)
        return out_list
Esempio n. 2
0
    def forward(self, graph: dgl.DGLGraph) -> torch.Tensor:
        """Encode every node of ``graph`` and project it to the decoder size.

        Node inputs ``graph.ndata["x"]`` go through dropout and the
        ``_W_iou`` / ``_W_f`` projections. The recurrent states and the
        ``Uh_sum`` / ``fc_sum`` accumulators start at zero. After topological
        propagation, the final hidden state ``h`` is popped from the graph and
        mapped through ``_out_linear``, ``_norm`` and ``_tanh``.

        :param graph: batched DGL graph with node feature ``x``.
        :return: [n nodes; decoder size]
        """
        raw_x = graph.ndata["x"]
        x = self._dropout(raw_x)

        n_nodes = graph.number_of_nodes()
        size = self._encoder_size
        graph.ndata["x_iou"] = self._W_iou(x)
        graph.ndata["x_f"] = self._W_f(x)
        # zero states/accumulators sharing dtype and device with the raw ``x``
        for key, width in (("h", size), ("c", size),
                           ("Uh_sum", 3 * size), ("fc_sum", size)):
            graph.ndata[key] = raw_x.new_zeros((n_nodes, width))

        dgl.prop_nodes_topo(graph,
                            message_func=self.message_func,
                            reduce_func=self.reduce_func,
                            apply_node_func=self.apply_node_func)

        encoded = graph.ndata.pop("h")  # [n nodes; encoder size]
        return self._tanh(self._norm(self._out_linear(encoded)))
Esempio n. 3
0
 def forward(self, batch, h, c):
     '''
     Compute tree-lstm prediction given a batch.
     :param batch: dgl.data.SSTBatch exposing ``graph``, ``wordid`` and ``mask``
     :param h: Tensor initial hidden state
     :param c: Tensor initial cell state
     :return: logits : Tensor
         The prediction of each node.
     '''
     g = batch.graph
     # Rebuild the graph with ``dgl.graph`` (heterograph API). Without
     # ``num_nodes`` the node count is inferred as the largest node id in the
     # edge list + 1. Isolated trailing nodes (e.g. a one-node tree at the end
     # of the batch) would then be dropped, and ``h``/``c`` would no longer
     # match the node count.
     g = dgl.graph(g.edges(), num_nodes=g.number_of_nodes())
     # feed embedding; ``iou`` is multiplied by the node mask
     embeds = self.embedding(batch.wordid * batch.mask)
     g.ndata['iou'] = self.cell.W_iou(
         self.dropout(embeds)) * batch.mask.float().unsqueeze(-1)
     g.ndata['h'] = h
     g.ndata['c'] = c
     # propagate
     dgl.prop_nodes_topo(g,
                         message_func=self.cell.message_func,
                         reduce_func=self.cell.reduce_func,
                         apply_node_func=self.cell.apply_node_func)
     # compute logits
     h = self.dropout(g.ndata.pop('h'))
     logits = self.linear(h)
     return logits
Esempio n. 4
0
 def forward(self, batch, g, h, c):
     """Predict a label distribution for every node of ``g``.

     Parameters
     ----------
     batch : dgl.data.SSTBatch
         The data batch; ``batch['x'].val`` holds the token ids and
         ``batch['mask']`` the node mask.
     g : dgl.DGLGraph
         Tree for computation; its ``iou``, ``h`` and ``c`` node features
         are overwritten and ``h`` is popped at the end.
     h : Tensor
         Initial hidden state.
     c : Tensor
         Initial cell state.

     Returns
     -------
     logits : Tensor
         ``self.linear`` applied to the (dropped-out) final hidden state of
         each node.
     """
     cell = self.cell
     token_ids = batch['x'].val.squeeze(-1)
     inputs = self.dropout(self.embedding(token_ids))
     node_mask = batch['mask'].float().unsqueeze(-1)
     g.ndata['iou'] = cell.W_iou(inputs) * node_mask
     g.ndata['h'] = h
     g.ndata['c'] = c

     dgl.prop_nodes_topo(g,
                         message_func=cell.message_func,
                         reduce_func=cell.reduce_func,
                         apply_node_func=cell.apply_node_func)

     final_h = self.dropout(g.ndata.pop('h'))
     return self.linear(final_h)
Esempio n. 5
0
    def forward(self, batch, g, h, c):
        """Encode each tree and return the final hidden state of its root.

        Parameters
        ----------
        batch : dgl.data.SSTBatch
            The data batch; ``batch.X`` holds the node inputs and
            ``batch.isroot`` flags the root nodes.
        g : dgl.DGLGraph
            Graph the propagation runs on; ``h`` is popped from it at the end.
        h : Tensor
            Initial hidden state.
        c : Tensor
            Initial cell state.

        Returns
        -------
        out : Tensor
            Final hidden states of the nodes where ``batch.isroot`` is
            non-zero, in ascending node-id order.
        """
        cell = self.cell
        g.ndata['iou'] = cell.W_iou(batch.X)
        g.ndata['h'] = h
        g.ndata['c'] = c

        dgl.prop_nodes_topo(
            g,
            message_func=cell.message_func,
            reduce_func=cell.reduce_func,
            apply_node_func=cell.apply_node_func,
        )

        node_h = g.ndata.pop('h')
        root_ids = batch.isroot.nonzero(as_tuple=False).flatten()
        return node_h.index_select(0, root_ids)
Esempio n. 6
0
    def forward(self, batch, enc_hidden, list_root_index, list_num_node):
        """Run the Tree-LSTM over a batched graph and return node states.

        :param batch: batch exposing ``graph``, ``wordid`` and ``mask``.
        :param enc_hidden: pair ``(h, c)`` stored as the initial hidden and
            cell node features.
        :param list_root_index: per-tree root index, local to that tree.
        :param list_num_node: per-tree node counts, in the same order; used to
            turn local root indices into indices of the batched graph.
        :return: for ``self.opt.tree_lstm_output_type == "root_node"``, the
            root ``(h, c)`` rows, each reshaped to ``(1, num_trees, -1)``;
            for ``"no_reduce"``, ``(h, c)`` of every node.
        :raises ValueError: for any other ``tree_lstm_output_type``.
        """
        g = batch.graph
        g.ndata['h'] = enc_hidden[0]
        g.ndata['c'] = enc_hidden[1]
        wemb = self.wemb(batch.wordid * batch.mask)
        g.ndata['iou'] = self.cell.W_iou(wemb) * batch.mask.float().unsqueeze(-1)

        # Pass the cell's UDFs explicitly instead of registering them on the
        # graph: DGLGraph.register_*_func no longer exists in DGL >= 0.5,
        # while these keyword arguments are accepted by old and new releases.
        dgl.prop_nodes_topo(g,
                            message_func=self.cell.message_func,
                            reduce_func=self.cell.reduce_func,
                            apply_node_func=self.cell.apply_node_func)

        all_node_h_in_batch = g.ndata.pop('h')
        all_node_c_in_batch = g.ndata.pop('c')

        output_type = self.opt.tree_lstm_output_type
        if output_type == "root_node":
            # Convert each tree's local root index into an index of the
            # batched graph by adding the node counts of all earlier trees.
            root_ids = []
            offset = 0
            for i, root_index in enumerate(list_root_index):
                if i > 0:
                    offset += list_num_node[i - 1]
                root_ids.append(root_index + offset)

            num_roots = len(root_ids)
            root_node_h_in_batch = torch.cat(
                [all_node_h_in_batch[i] for i in root_ids]).reshape(1, num_roots, -1)
            root_node_c_in_batch = torch.cat(
                [all_node_c_in_batch[i] for i in root_ids]).reshape(1, num_roots, -1)
            return root_node_h_in_batch, root_node_c_in_batch
        if output_type == "no_reduce":
            return all_node_h_in_batch, all_node_c_in_batch
        # Previously fell through and returned None, which surfaced later as
        # an unrelated unpacking error in the caller.
        raise ValueError(
            "unknown tree_lstm_output_type: {!r}".format(output_type))
Esempio n. 7
0
 def forward(self, g, x, mask):
     """Compute tree-lstm output for every node of ``g``.

     Parameters
     ----------
     g : dgl.DGLGraph
         Tree (or batch of trees) to propagate over.
     x : Tensor
         Node inputs; ``x * mask`` is passed to ``self.input_module``.
     mask : Tensor
         Per-node mask; the projected ``iou_input`` / ``f_input`` features
         are multiplied by it.

     Returns
     -------
     out : Tensor
         ``self.output_module`` applied to the hidden state ``h`` of every
         node after propagation.
     """
     # feed embedding
     # TODO: x * mask does not make sense. use x[mask]
     embeds = self.input_module(x * mask)
     node_mask = mask.float().unsqueeze(-1)
     g.ndata['iou_input'] = self.W_iou(embeds) * node_mask
     g.ndata['f_input'] = self.W_f(embeds) * node_mask
     # propagate; the cell's UDFs are passed explicitly because
     # DGLGraph.register_*_func no longer exists in DGL >= 0.5.
     # NOTE(review): ``h`` is not initialised in this method -- presumably a
     # node initializer or the cell provides it; confirm.
     dgl.prop_nodes_topo(g,
                         message_func=self.cell.message_func,
                         reduce_func=self.cell.reduce_func,
                         apply_node_func=self.cell.apply_node_func)
     # compute output
     h = g.ndata['h']
     out = self.output_module(h)
     return out
Esempio n. 8
0
 def forward(self, batch, h, c):
     """Predict a label distribution for every node of ``batch.graph``.

     Parameters
     ----------
     batch : dgl.data.SSTBatch
         The data batch; ``batch.wordid`` holds token ids and ``batch.mask``
         the node mask.
     h : Tensor
         Initial hidden state.
     c : Tensor
         Initial cell state.

     Returns
     -------
     logits : Tensor
         ``self.linear`` applied to the (dropped-out) final hidden state of
         each node.
     """
     g = batch.graph
     cell = self.cell
     token_vecs = self.dropout(self.embedding(batch.wordid * batch.mask))
     wiou = cell.W_iou(token_vecs)
     # mask is cast to the projection's dtype before multiplying
     node_mask = batch.mask.expand_dims(-1).astype(wiou.dtype)
     g.ndata['iou'] = wiou * node_mask
     g.ndata['h'] = h
     g.ndata['c'] = c
     dgl.prop_nodes_topo(g,
                         message_func=cell.message_func,
                         reduce_func=cell.reduce_func,
                         apply_node_func=cell.apply_node_func)
     final_h = self.dropout(g.ndata.pop('h'))
     return self.linear(final_h)
Esempio n. 9
0
    def forward(self, batch, h, c):
        """Compute tree-lstm prediction given a batch.

        Parameters
        ----------
        batch : dgl.data.SSTBatch
            The data batch; ``batch.graph`` is propagated over and
            ``batch.wordid`` / ``batch.mask`` give token ids and node mask.
        h : Tensor
            Initial hidden state.
        c : Tensor
            Initial cell state.

        Returns
        -------
        logits : Tensor
            The prediction of each node.
        """
        g = batch.graph
        # feed embedding; ``iou`` is multiplied by the node mask
        embeds = self.embedding(batch.wordid * batch.mask)
        g.ndata['iou'] = self.cell.W_iou(
            self.dropout(embeds)) * batch.mask.float().unsqueeze(-1)
        g.ndata['h'] = h
        g.ndata['c'] = c
        # propagate; the cell's UDFs are passed explicitly because
        # DGLGraph.register_*_func no longer exists in DGL >= 0.5, while these
        # keyword arguments are accepted by old and new releases
        dgl.prop_nodes_topo(g,
                            message_func=self.cell.message_func,
                            reduce_func=self.cell.reduce_func,
                            apply_node_func=self.cell.apply_node_func)
        # compute logits
        h = self.dropout(g.ndata.pop('h'))
        logits = self.linear(h)
        return logits
Esempio n. 10
0
    def forward(self, src, enc_tree, dec_tree, return_code=False, **kwargs):
        """Encode the source and ``enc_tree``, then decode ``dec_tree`` and score it.

        :param src: source token ids; ``src != 0`` is passed as the source mask.
        :param enc_tree: DGL graph to encode (node features ``x`` and ``mask``
            are read).
        :param dec_tree: DGL graph to decode (node features ``x``, ``mask`` and
            target labels ``y`` are read). It must have as many nodes as
            ``enc_tree``, because the masked encoder hidden states become its
            initial ``h``.
        :param return_code: if true and ``self._code_bits > 0``, return only
            ``{"codes": ...}`` produced by ``self.semhash``.
        :param kwargs: ignored.
        :return: ``{"codes": codes}`` in the case above, otherwise a dict with
            ``"loss"`` (mean NLL of ``y``) and ``"label_accuracy"``.
        """
        self._init_graph(enc_tree, dec_tree)
        # Source encoding
        src_mask = torch.ne(src, 0).float()
        encoder_outputs = self.encode_source(src, src_mask, meanpool=True)
        encoder_states = encoder_outputs["encoder_states"]
        # Tree encoding
        enc_x = enc_tree.ndata["x"].cuda()
        x_embeds = self.label_embed_layer(enc_x)
        enc_tree.ndata['iou'] = self.enc_cell.W_iou(self.dropout(x_embeds))
        enc_tree.ndata['h'] = torch.zeros(
            (enc_tree.number_of_nodes(), self.hidden_size)).cuda()
        enc_tree.ndata['c'] = torch.zeros(
            (enc_tree.number_of_nodes(), self.hidden_size)).cuda()
        enc_tree.ndata['mask'] = enc_tree.ndata['mask'].float().cuda()
        # No UDFs are passed here, so this relies on functions registered on
        # the graph (presumably by ``self._init_graph``) -- TODO confirm.
        dgl.prop_nodes_topo(enc_tree)
        # Obtain root representation: encoder hidden states zeroed where
        # ``mask`` is 0; the rows with mask > 0 are then extracted
        root_mask = enc_tree.ndata["mask"].float().cuda()
        root_h = self.dropout(
            enc_tree.ndata.pop("h")) * root_mask.unsqueeze(-1)
        orig_h = root_h.clone()[root_mask > 0]
        partial_h = orig_h
        # NOTE(review): encoder states are added *before* discretization when
        # ``_without_source`` is true. Otherwise they are added only after
        # discretization, i.e. never when ``_code_bits == 0``. The flag name
        # reads inverted -- confirm intent.
        if self._without_source:
            partial_h += encoder_states

        # Discretization
        if self._code_bits > 0:
            if return_code:
                codes = self.semhash(partial_h, return_code=True)
                ret = {"codes": codes}
                return ret
            else:
                partial_h = self.semhash(partial_h)
            if not self._without_source:
                partial_h += encoder_states

        root_h[root_mask > 0] = partial_h
        # Tree decoding, seeded with the encoder-side root states
        dec_x = dec_tree.ndata["x"].cuda()
        dec_embeds = self.label_embed_layer(dec_x)
        dec_tree.ndata['iou'] = self.dec_cell.W_iou(self.dropout(dec_embeds))
        dec_tree.ndata['h'] = root_h
        # Sized by dec_tree, the graph this cell state is stored on.
        dec_tree.ndata['c'] = torch.zeros(
            (dec_tree.number_of_nodes(), self.hidden_size)).cuda()
        dec_tree.ndata['mask'] = dec_tree.ndata['mask'].float().cuda()
        dgl.prop_nodes_topo(dec_tree)
        # Compute logits
        all_h = self.dropout(dec_tree.ndata.pop("h"))
        logits = self.logit_nn(all_h)
        logp = F.log_softmax(logits, 1)
        # Compute loss
        y_labels = dec_tree.ndata["y"].cuda()
        monitor = {}
        loss = F.nll_loss(logp, y_labels, reduction="mean")
        acc = (logits.argmax(1) == y_labels).float().mean()
        monitor["loss"] = loss
        monitor["label_accuracy"] = acc
        return monitor
    def forward(self, batch):
        """Compute tree-lstm prediction given a batch, refined over ``self.T_step`` rounds.

        Every tree keeps a state vector ``s`` (copied onto each of its nodes),
        initialised to the mean over its nodes of the masked embeddings ``e``.
        Each round does two things. First, it propagates over the batched
        graph in topological order. Then ``self.cell.updateGlobalVec`` maps
        the current per-tree states and the zero-padded node hidden states to
        new per-tree states.

        Parameters
        ----------
        batch : dgl.data.SSTBatch
            The data batch; ``batch.graph`` is the batched tree graph and
            ``batch.wordid`` / ``batch.mask`` give token ids and node mask.

        Returns
        -------
        logits : Tensor
            The prediction of each node.
        """
        cell = self.cell

        # ---------------- per-tree state helpers ----------------
        def init_s(tree):
            # mean of the tree's ``e`` over its nodes, copied to every node
            tree.ndata['s'] = tree.ndata['e'].mean(dim=0).repeat(tree.number_of_nodes(), 1)
            return tree

        def update_s(tree, state):
            assert state.dim() == 1
            tree.ndata['s'] = state.repeat(tree.number_of_nodes(), 1)
            return tree

        def extract_s(batch_tree):
            # one [dmodel] state per tree --> [trees, 1, dmodel]
            s_list = [tree.ndata.pop('s')[0].unsqueeze(0) for tree in dgl.unbatch(batch_tree)]
            return th.cat(s_list, dim=0).unsqueeze(1)

        def extract_h(batch_tree):
            # [nodes, dmodel] per tree, zero-padded to max_nodes --> [trees, max_nodes, dmodel]
            h_list = [tree.ndata.pop('h') for tree in dgl.unbatch(batch_tree)]
            max_nodes = max(h.size(0) for h in h_list)
            padded = [
                th.cat([h, th.zeros((max_nodes - h.size(0), h.size(1)), device=self.device)],
                       dim=0).unsqueeze(0)
                for h in h_list
            ]
            return th.cat(padded, dim=0)
        # ---------------------------------------------------------

        g = batch.graph
        # feed embedding; both ``e`` and the initial ``h`` are the masked embeddings
        embeds = self.embedding(batch.wordid * batch.mask)
        g.ndata['c'] = th.zeros((g.number_of_nodes(), 2, self.dmodel), device=self.device)
        g.ndata['e'] = embeds * batch.mask.float().unsqueeze(-1)
        g.ndata['h'] = embeds * batch.mask.float().unsqueeze(-1)
        g = dgl.batch([init_s(tree) for tree in dgl.unbatch(g)])
        # propagate. The UDFs are passed explicitly for two reasons:
        # ``dgl.batch`` returns a new graph every round, so registrations
        # would be lost anyway; and DGLGraph.register_*_func no longer exists
        # in DGL >= 0.5.
        for _ in range(self.T_step):
            dgl.prop_nodes_topo(g,
                                message_func=cell.message_func,
                                reduce_func=cell.reduce_func,
                                apply_node_func=cell.apply_node_func)
            states = cell.updateGlobalVec(extract_s(g), extract_h(g))
            g = dgl.batch([update_s(tree, state) for tree, state in zip(dgl.unbatch(g), states)])
        # compute logits
        h = self.dropout(g.ndata.pop('h'))
        logits = self.linear(h)
        return logits
Esempio n. 12
0
    def propagate(self, g, cell, X, h, c):
        """Run ``cell`` over ``g`` in topological order and return the final hidden states.

        ``cell.W_iou(X)`` is stored as the ``iou`` node feature, and ``h`` / ``c``
        seed the hidden and cell states. The resulting ``h`` feature is popped
        from ``g`` and returned.
        """
        projected = cell.W_iou(X)
        g.ndata['iou'] = projected
        g.ndata['h'] = h
        g.ndata['c'] = c

        dgl.prop_nodes_topo(
            g,
            message_func=cell.message_func,
            reduce_func=cell.reduce_func,
            apply_node_func=cell.apply_node_func,
        )

        final_h = g.ndata.pop('h')
        return final_h
Esempio n. 13
0
def run(cell, graph, iou, h, c):
    """Run a Tree-LSTM ``cell`` over ``graph`` in topological order.

    :param cell: object providing ``message_func``, ``reduce_func`` and
        ``apply_node_func`` for the propagation.
    :param graph: DGL graph; its ``iou``, ``h`` and ``c`` node features are
        overwritten.
    :param iou: value stored as the ``iou`` node feature.
    :param h: initial hidden state.
    :param c: initial cell state.
    :return: the ``h`` node feature after propagation (removed from ``graph``).
    """
    # feed inputs and initial states
    graph.ndata["iou"] = iou
    graph.ndata["h"] = h
    graph.ndata["c"] = c
    # propagate; UDFs are passed explicitly because
    # DGLGraph.register_*_func no longer exists in DGL >= 0.5
    dgl.prop_nodes_topo(graph,
                        message_func=cell.message_func,
                        reduce_func=cell.reduce_func,
                        apply_node_func=cell.apply_node_func)
    return graph.ndata.pop("h")
Esempio n. 14
0
    def propagate(self, g, cell, X, h, c):
        """Run ``cell`` over ``g`` in topological order and return the final hidden states.

        :param g: DGL graph; its ``iou``, ``h`` and ``c`` node features are overwritten.
        :param cell: object providing ``W_iou`` and the message/reduce/apply UDFs.
        :param X: node inputs, projected by ``cell.W_iou`` into ``iou``.
        :param h: initial hidden state.
        :param c: initial cell state.
        :return: the ``h`` node feature after propagation (removed from ``g``).
        """
        g.ndata['iou'] = cell.W_iou(X)
        g.ndata['h'] = h
        g.ndata['c'] = c
        # propagate; UDFs are passed explicitly because
        # DGLGraph.register_*_func no longer exists in DGL >= 0.5
        dgl.prop_nodes_topo(g,
                            message_func=cell.message_func,
                            reduce_func=cell.reduce_func,
                            apply_node_func=cell.apply_node_func)

        return g.ndata.pop('h')
Esempio n. 15
0
 def forward(self, batch: BatchedTree):
     """Run ``self.num_stacks`` stacked passes of ``self.cell`` over ``batch``.

     Every stack works on its own ``deepcopy`` of ``batch``. Stack 0 reads
     the original node inputs ``x``. Each later stack uses the final hidden
     states ``h`` of the previous stack as its inputs.

     :param batch: batched trees; ``batch.batch_dgl_graph`` is the DGL graph.
     :return: list of ``self.num_stacks`` propagated copies, one per stack.
     """
     batches = [deepcopy(batch) for _ in range(self.num_stacks)]
     for stack, cur_batch in enumerate(batches):
         graph = cur_batch.batch_dgl_graph
         if stack == 0:
             x = batch.batch_dgl_graph.ndata['x']
         else:
             # Feed the previous stack's output. This must also be what W_ruo
             # reads below: reading ``batch``'s x on every stack would ignore
             # this assignment and give every stack the raw input.
             x = batches[stack - 1].batch_dgl_graph.ndata['h']
             graph.ndata['x'] = x
         graph.ndata['ruo'] = self.cell.W_ruo(self.dropout(x))
         # UDFs are passed explicitly because DGLGraph.register_*_func no
         # longer exists in DGL >= 0.5
         dgl.prop_nodes_topo(graph,
                             message_func=self.cell.message_func,
                             reduce_func=self.cell.reduce_func,
                             apply_node_func=self.cell.apply_node_func)
     return batches
Esempio n. 16
0
    def forward(self, rnn_inputs: th.Tensor, graph: dgl.DGLGraph):
        """Run the Tree-LSTM over ``graph`` seeded from ``rnn_inputs``.

        :param rnn_inputs: per-node input features. ``self.tree_lstm.W_iou``
            projects them into ``iou``, and ``self.linear`` into the initial ``h``.
        :param graph: DGL graph to propagate over (its node features are overwritten).
        :return: final hidden state ``h`` of every node (popped from ``graph``).
        """
        tree_lstm = self.tree_lstm
        graph.ndata['iou'] = tree_lstm.W_iou(rnn_inputs)
        graph.ndata['h'] = self.linear(rnn_inputs)
        # Zero initial cell state. ``torch.autograd.Variable`` is deprecated;
        # a plain tensor (requires_grad defaults to False) is equivalent.
        graph.ndata['c'] = th.zeros((graph.number_of_nodes(), self.output_dim),
                                    dtype=th.float32, device=self.device)
        # UDFs are passed explicitly because DGLGraph.register_*_func no
        # longer exists in DGL >= 0.5
        dgl.prop_nodes_topo(graph,
                            message_func=tree_lstm.message_func,
                            reduce_func=tree_lstm.reduce_func,
                            apply_node_func=tree_lstm.apply_node_func)

        return graph.ndata.pop('h')
    def forward(self, graph: dgl.DGLGraph) -> torch.Tensor:
        """Encode every node of ``graph`` by propagating ``h`` in topological order.

        ``h`` starts at zero. Each node receives its predecessors' ``h`` as
        messages, which ``self.reduce_func`` / ``self.apply_node_func`` combine.

        :param graph: batched dgl graph with node feature ``x``
        :return: encoded nodes [number of nodes, hidden size]
        """
        num_nodes = graph.number_of_nodes()
        # zero initial state sharing dtype/device with ``x``
        graph.ndata['h'] = graph.ndata['x'].new_zeros((num_nodes, self.h_enc))
        copy_h = dgl.function.copy_u('h', 'h')
        dgl.prop_nodes_topo(
            graph,
            message_func=[copy_h],
            reduce_func=self.reduce_func,
            apply_node_func=self.apply_node_func,
        )
        return graph.ndata.pop('h')
Esempio n. 18
0
    def forward(self, batch):
        """Propagate node features ``p`` over the batched tree and label every node.

        :param batch: batch exposing ``graph``, ``wordid`` and ``mask``.
        :return: ``self.W_label`` logits for every node.
        """
        g = batch.graph
        embeds = self.embedding(batch.wordid * batch.mask)

        g.ndata['p'] = self.dropout(embeds)  # add dropout --> improve BestRoot_acc_test
        # Message propagation using node frontiers generated in topological
        # order. The cell's message/reduce functions are passed explicitly
        # (DGLGraph.register_*_func no longer exists in DGL >= 0.5). No
        # apply-node function is passed.
        dgl.prop_nodes_topo(g,
                            message_func=self.cell.message_func,
                            reduce_func=self.cell.reduce_func)
        # pop() gets and removes the node states from the graph to save memory
        p = self.dropout(g.ndata.pop('p'))
        # compute logits
        logits = self.W_label(p)
        return logits
Esempio n. 19
0
    def forward(self, batch, h, c):
        """Compute per-node logits with a Tree-LSTM plus optional tree attention.

        Parameters
        ----------
        batch : dgl.data.SSTBatch
            The data batch; ``batch.graph`` is propagated over and
            ``batch.wordid`` is fed directly to ``self.cell.W_iou``.
        h : Tensor
            Initial hidden state.
        c : Tensor
            Initial cell state.

        Returns
        -------
        logits : Tensor
            The prediction of each node.
        """
        g = batch.graph
        cell = self.cell
        g.register_message_func(cell.message_func)
        g.register_reduce_func(cell.reduce_func)
        g.register_apply_node_func(cell.apply_node_func)
        g.ndata['iou'] = cell.W_iou(batch.wordid)
        g.ndata['h'] = h
        g.ndata['c'] = c
        dgl.prop_nodes_topo(g)
        node_h = self.dropout(g.ndata.pop('h'))

        if self.attention:
            # The first node of every tree (offset by the sizes of the trees
            # before it) is taken as that tree's root.
            first_rows = []
            offset = 0
            for tree_size in g.batch_num_nodes:
                first_rows.append(node_h[offset])
                offset = offset + tree_size
            num_trees = len(first_rows)
            root_h = th.cat(first_rows).reshape(num_trees, -1)
            node_num = th.tensor(batch.graph.batch_num_nodes).to(root_h.device)
            node_h = self.tree_attn(node_h, root_h, node_num)

        node_h = F.relu(self.wh(node_h))
        return self.linear(node_h)
Esempio n. 20
0
 def forward(self, batch, h, c):
     """Compute tree-lstm prediction given a batch.

     :param batch: batch exposing ``graph``, ``wordid`` and ``mask``.
     :param h: initial hidden state.
     :param c: initial cell state.
     :return: ``self.linear`` applied to the (dropped-out) final hidden
         state of every node.
     """
     g = batch.graph
     # feed embedding; ``iou`` is multiplied by the node mask
     embeds = self.embedding(batch.wordid * batch.mask)
     g.ndata['iou'] = self.cell.W_iou(
         self.dropout(embeds)) * batch.mask.float().unsqueeze(-1)
     g.ndata['h'] = h
     g.ndata['c'] = c
     # propagate; UDFs are passed explicitly because
     # DGLGraph.register_*_func no longer exists in DGL >= 0.5
     dgl.prop_nodes_topo(g,
                         message_func=self.cell.message_func,
                         reduce_func=self.cell.reduce_func,
                         apply_node_func=self.cell.apply_node_func)
     # compute logits
     h = self.dropout(g.ndata.pop('h'))
     logits = self.linear(h)
     return logits
    def forward(self, graph: dgl.DGLGraph, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the Tree-LSTM over a (batched) graph and return node states.

        ``graph.ndata['x']`` is projected into ``x_iou`` and ``x_f``.
        ``h``, ``c``, ``Uh_sum`` and ``fc_sum`` start at zero. During
        propagation, incoming ``Uh`` and ``fc`` messages are summed into
        ``Uh_sum`` / ``fc_sum``.

        :param graph: batched DGL graph with node feature ``x``.
        :param device: device on which the zero-initialised states are created.
        :return: ``(h, c)``, both popped from ``graph``.
        """
        # Annotated as ``dgl.DGLGraph``: ``dgl.BatchedDGLGraph`` no longer
        # exists in DGL >= 0.5 (batched graphs are plain DGLGraphs). Unless
        # annotations are postponed, it is resolved when the method is defined.
        number_of_nodes = graph.number_of_nodes()
        graph.ndata['x_iou'] = self.W_iou(graph.ndata['x']) + self.b_iou
        graph.ndata['x_f'] = self.W_f(graph.ndata['x']) + self.b_f
        graph.ndata['h'] = torch.zeros((number_of_nodes, self.h_size), device=device)
        graph.ndata['c'] = torch.zeros((number_of_nodes, self.h_size), device=device)
        graph.ndata['Uh_sum'] = torch.zeros((number_of_nodes, 3 * self.h_size), device=device)
        graph.ndata['fc_sum'] = torch.zeros((number_of_nodes, self.h_size), device=device)

        # All functions are passed explicitly; DGLGraph.register_*_func no
        # longer exists in DGL >= 0.5.
        dgl.prop_nodes_topo(graph,
                            message_func=self.message_func,
                            reduce_func=[fn.sum('Uh', 'Uh_sum'), fn.sum('fc', 'fc_sum')],
                            apply_node_func=self.apply_node_func)

        h = graph.ndata.pop('h')
        c = graph.ndata.pop('c')
        return h, c
    def forward(self, graph: dgl.DGLGraph, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the Tree-LSTM over a (batched) graph and return node states.

        ``graph.ndata['x']`` is projected into ``node_iou`` and ``node_f``.
        ``h``, ``c`` and ``Uh_tilda`` start at zero. Each node receives its
        predecessors' ``h`` and ``c`` as messages.

        :param graph: batched DGL graph with node feature ``x``.
        :param device: device on which the zero-initialised states are created.
        :return: ``(h, c)``, both popped from ``graph``.
        """
        # Annotated as ``dgl.DGLGraph``: ``dgl.BatchedDGLGraph`` no longer
        # exists in DGL >= 0.5 (batched graphs are plain DGLGraphs).
        features = graph.ndata['x']
        nodes_in_batch = graph.number_of_nodes()
        graph.ndata['node_iou'] = self.W_iou(features) + self.b_iou
        graph.ndata['node_f'] = self.W_f(features) + self.b_f
        graph.ndata['h'] = torch.zeros(nodes_in_batch, self.h_size, device=device)
        graph.ndata['c'] = torch.zeros(nodes_in_batch, self.h_size, device=device)
        graph.ndata['Uh_tilda'] = torch.zeros(nodes_in_batch, 3 * self.h_size, device=device)
        # propagate; all functions are passed explicitly because
        # DGLGraph.register_*_func no longer exists in DGL >= 0.5
        dgl.prop_nodes_topo(graph,
                            message_func=[fn.copy_u('h', 'h'), fn.copy_u('c', 'c')],
                            reduce_func=self.reduce_func,
                            apply_node_func=self.apply_node_func)
        # get encoded output
        h = graph.ndata.pop('h')
        c = graph.ndata.pop('c')
        return h, c
Esempio n. 23
0
    def forward(self, batch, h, c):
        """Compute per-node logits with the Tree-LSTM cell.

        :param batch: batch exposing ``graph``, ``wordid`` and ``mask``.
        :param h: initial hidden state.
        :param c: initial cell state.
        :return: ``self.linear`` applied to the (dropped-out) final hidden
            state of every node.
        """
        g = batch.graph
        # presumably (num_nodes, 256)
        embeds = self.embedding(batch.wordid * batch.mask)
        g.ndata['iou'] = self.cell.W_iou(
            self.dropout(embeds)) * batch.mask.float().unsqueeze(-1)
        g.ndata['h'] = h
        g.ndata['c'] = c
        # UDFs are passed explicitly because DGLGraph.register_*_func no
        # longer exists in DGL >= 0.5
        dgl.prop_nodes_topo(g,
                            message_func=self.cell.message_func,
                            reduce_func=self.cell.reduce_func,
                            apply_node_func=self.cell.apply_node_func)

        h = self.dropout(g.ndata.pop('h'))

        logits = self.linear(h)
        return logits
Esempio n. 24
0
def test_prop_nodes_topo():
    # tree: edges point from child to parent (1->0, 2->0, 3->2, 4->2)
    tree = dgl.DGLGraph()
    tree.add_nodes(5)
    tree.add_edge(1, 0)
    tree.add_edge(2, 0)
    tree.add_edge(3, 2)
    tree.add_edge(4, 2)
    # init node feature data
    tree.ndata['x'] = mx.nd.zeros(shape=(5, 2))
    # set all leaf nodes to be ones
    tree.nodes[[1, 3, 4]].data['x'] = mx.nd.ones(shape=(3, 2))
    # Pass the UDFs directly rather than via tree.register_*_func, which no
    # longer exists in DGL >= 0.5; keyword arguments work in old and new DGL.
    dgl.prop_nodes_topo(tree, message_func=mfunc, reduce_func=rfunc)
    # root node gets the sum
    assert np.allclose(tree.nodes[0].data['x'].asnumpy(), np.array([[3., 3.]]))
Esempio n. 25
0
def test_prop_nodes_topo(idtype):
    # A bi-directional chain contains cycles, so topological propagation must fail.
    chain = dgl.graph(nx.path_graph(5), idtype=idtype, device=F.ctx())
    assert U.check_fail(dgl.prop_nodes_topo, chain)

    # Build a 5-node tree whose edges point from child to parent.
    tree = dgl.DGLGraph()
    tree.add_nodes(5)
    for child, parent in ((1, 0), (2, 0), (3, 2), (4, 2)):
        tree.add_edge(child, parent)
    tree = dgl.graph(tree.edges())
    # Every node starts at zero except the leaves 1, 3 and 4, which start at one.
    tree.ndata['x'] = F.zeros((5, 2))
    tree.nodes[[1, 3, 4]].data['x'] = F.ones((3, 2))
    dgl.prop_nodes_topo(tree, message_func=mfunc, reduce_func=rfunc, apply_node_func=None)
    # The root accumulates the sum of the leaves.
    assert F.allclose(tree.nodes[0].data['x'], F.tensor([[3., 3.]]))
Esempio n. 26
0
    def forward(self, g, h, c):
        """Run the Tree-LSTM cell over ``g`` on the device of this module's parameters.

        Side effects: sets ``self.device`` to the device of the first
        parameter and moves ``self.cell`` there.

        :param g: DGL graph with node feature ``x``; ``iou``, ``wf_x``,
            ``iou_mid``, ``h`` and ``c`` are (over)written on it.
        :param h: initial hidden state, moved to ``self.device``.
        :param c: initial cell state, moved to ``self.device``.
        :return: ``g`` itself after propagation; node states stay in ``g.ndata``.
        """
        device = next(self.parameters()).device
        self.device = device
        self.cell.to(self.device)

        # init of data
        x = g.ndata['x']
        g.ndata['iou'] = self.cell.W_iou(x).float().to(self.device)
        g.ndata['wf_x'] = self.cell.W_f(x).float().to(self.device)
        g.ndata['iou_mid'] = torch.zeros(g.number_of_nodes(),
                                         3 * self.h_size, device=self.device)
        g.ndata['h'] = h.to(self.device)
        g.ndata['c'] = c.to(self.device)

        # UDFs are passed explicitly because DGLGraph.register_*_func no
        # longer exists in DGL >= 0.5
        dgl.prop_nodes_topo(g,
                            message_func=self.cell.message_func,
                            reduce_func=self.cell.reduce_func,
                            apply_node_func=self.cell.apply_node_func)

        return g
Esempio n. 27
0
    def forward(self, batch, h, c):
        """Encode each tree and summarise it by its root and its other nodes.

        Parameters
        ----------
        batch : dgl.data.SSTBatch
            The data batch; ``batch.graph`` is propagated over, ``batch.X``
            holds the node inputs and ``batch.isroot`` flags root nodes.
        h : Tensor
            Initial hidden state.
        c : Tensor
            Initial cell state.

        Returns
        -------
        out : Tensor
            One row per tree: the root's final hidden state concatenated
            (along dim 1) with the mean final hidden state of the tree's
            non-root nodes (zeros for a tree that has only a root).
        """
        g = batch.graph
        g.ndata['iou'] = self.cell.W_iou(batch.X)
        g.ndata['h'] = h
        g.ndata['c'] = c
        # propagate; UDFs are passed explicitly because
        # DGLGraph.register_*_func no longer exists in DGL >= 0.5
        dgl.prop_nodes_topo(g,
                            message_func=self.cell.message_func,
                            reduce_func=self.cell.reduce_func,
                            apply_node_func=self.cell.apply_node_func)
        h = g.ndata.pop('h')

        # indexes of root nodes
        head_ids = batch.isroot.nonzero().flatten()
        # h of root nodes
        head_h = th.index_select(h, 0, head_ids)
        # Assumes each tree's nodes occupy a contiguous id range that starts
        # at its root, so ids strictly between two roots are non-root nodes.
        lims_ids = head_ids.tolist() + [g.number_of_nodes()]
        inner_rows = []
        for start, end in zip(lims_ids[:-1], lims_ids[1:]):
            # non-root ids are start+1 .. end-1 inclusive; the slice end is
            # exclusive, so it must be ``end`` (``end - 1`` dropped the last node)
            inner = h[start + 1:end]
            if inner.size(0) == 0:
                # root-only tree: the mean of an empty slice would be NaN
                inner_rows.append(h.new_zeros((1, h[0].numel())))
            else:
                inner_rows.append(th.mean(inner, dim=0).view(1, -1))
        inner_h = th.cat(inner_rows)
        out = th.cat([head_h, inner_h], dim=1)
        return out
Esempio n. 28
0
    def forward(self, graph: dgl.DGLGraph) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply ``self.n_layers`` stacked tree-cell layers to ``graph``.

        Each layer reads ``graph.ndata['x']``. For the first layer this is the
        dropped-out original ``x``; for later layers it is the previous
        layer's output. The layer prepares the graph via ``init_matrices``,
        propagates in topological order and pops the node hidden states
        ``h``. With ``self.residual`` these are added to the layer input and
        passed through ``self.norm[layer]``.

        :param graph: DGL graph with node feature ``x``.
        :return: ``(x, c)`` -- the last layer's output and the ``c`` node
            feature popped from ``graph``.
        """
        x = self.dropout(graph.ndata['x'])

        for layer in range(self.n_layers):
            cell = self.cell[layer]
            graph.ndata['x'] = x
            graph = cell.init_matrices(graph)
            dgl.prop_nodes_topo(
                graph,
                reduce_func=cell.get_reduce_func(),
                message_func=cell.get_message_func(),
                apply_node_func=cell.get_apply_node_func(),
            )
            layer_h = graph.ndata.pop('h')
            x = self.norm[layer](layer_h + x) if self.residual else layer_h

        c = graph.ndata.pop('c')
        return x, c
Esempio n. 29
0
def test_prop_nodes_topo():
    # bi-directional chain: contains cycles, so topological propagation must fail
    g = dgl.DGLGraph(nx.path_graph(5))
    assert U.check_fail(dgl.prop_nodes_topo, g)  # has loop

    # tree: edges point from child to parent (1->0, 2->0, 3->2, 4->2)
    tree = dgl.DGLGraph()
    tree.add_nodes(5)
    tree.add_edge(1, 0)
    tree.add_edge(2, 0)
    tree.add_edge(3, 2)
    tree.add_edge(4, 2)
    # init node feature data
    tree.ndata['x'] = th.zeros((5, 2))
    # set all leaf nodes to be ones
    tree.nodes[[1, 3, 4]].data['x'] = th.ones((3, 2))
    # Pass the UDFs directly rather than via tree.register_*_func, which no
    # longer exists in DGL >= 0.5; keyword arguments work in old and new DGL.
    dgl.prop_nodes_topo(tree, message_func=mfunc, reduce_func=rfunc)
    # root node gets the sum
    assert U.allclose(tree.nodes[0].data['x'], th.tensor([[3., 3.]]))
Esempio n. 30
0
    def forward(self, batch):
        """Encode every tree in the batch and return the hidden state of its node 0.

        :param batch: batch exposing ``graph``, ``wordid`` and ``mask``.
        :return: ``h`` of node 0 of each tree from ``dgl.unbatch``,
            concatenated along dim 0.
        """
        g = batch.graph
        n = g.number_of_nodes()
        embeds = self.embedding(batch.wordid * batch.mask)
        # Zero initial states on the embeddings' device. ``torch.zeros``
        # without ``device`` allocates on the default device (normally CPU),
        # which breaks when the model runs on a GPU.
        h = torch.zeros((n, self.nhid), device=embeds.device)
        c = torch.zeros((n, self.nhid), device=embeds.device)

        g.ndata['iou'] = self.cell.W_iou(
            self.dropout(embeds)) * batch.mask.float().unsqueeze(-1)
        g.ndata['h'] = h
        g.ndata['c'] = c
        # UDFs are passed explicitly because DGLGraph.register_*_func no
        # longer exists in DGL >= 0.5
        dgl.prop_nodes_topo(g,
                            message_func=self.cell.message_func,
                            reduce_func=self.cell.reduce_func,
                            apply_node_func=self.cell.apply_node_func)
        # node 0 of each tree is taken as its root (presumably the dataset's
        # convention -- TODO confirm)
        root_list = [tree.nodes[0].data['h'] for tree in dgl.unbatch(g)]
        output = torch.cat(root_list, 0)
        # free the hidden states held by the graph
        g.ndata.pop('h')
        return output