Esempio n. 1
0
def test_topological_nodes(n=100):
    """Cross-check ``dgl.topological_nodes_generator`` against backend ops.

    A random ``n x n`` 0/1 matrix of density ``3 / n`` is cut down to its
    strictly lower triangle (which makes the graph acyclic) and loaded into a
    ``DGLGraph``. DGL's frontiers must match, level by level, the frontiers
    obtained by repeatedly peeling off the still-alive nodes whose remaining
    degree -- computed with ``F.spmm`` against the adjacency matrix -- is zero.
    """
    graph = dgl.DGLGraph()
    rand_mat = sp.random(n, n, 3 / n, data_rvs=lambda size: np.ones(size))
    graph.from_scipy_sparse_matrix(sp.tril(rand_mat, -1).tocoo())

    expected_levels = dgl.topological_nodes_generator(graph)
    adj = graph.adjacency_matrix()

    def peel_levels():
        # Each pass emits the alive nodes with zero remaining degree, then
        # subtracts their contribution from every other node's degree.
        num_nodes = graph.number_of_nodes()
        alive = F.copy_to(F.ones((num_nodes, 1)), F.cpu())
        deg = F.spmm(adj, alive)
        while F.reduce_sum(alive) != 0.:
            ready = F.astype((deg == 0.), F.float32) * alive
            alive = alive - ready
            yield F.copy_to(F.nonzero_1d(F.squeeze(ready, 1)), F.cpu())
            deg -= F.spmm(adj, ready)

    computed_levels = list(peel_levels())

    assert len(expected_levels) == len(computed_levels)
    assert all(
        toset(dgl_level) == toset(ref_level)
        for dgl_level, ref_level in zip(expected_levels, computed_levels)
    )
Esempio n. 2
0
def test_topological_nodes(n=1000):
    """Cross-check ``dgl.topological_nodes_generator`` against MXNet ops.

    A random ``n x n`` 0/1 matrix of density ``10 / n`` is cut down to its
    strictly lower triangle (which makes the graph acyclic) and loaded into a
    ``DGLGraph``. DGL's frontiers must match, level by level, the frontiers
    obtained by repeatedly peeling off the still-alive nodes whose remaining
    degree -- computed with ``mx.nd.dot`` against the adjacency matrix -- is
    zero.
    """
    graph = dgl.DGLGraph()
    rand_mat = sp.random(n, n, 10 / n, data_rvs=lambda size: np.ones(size))
    graph.from_scipy_sparse_matrix(sp.tril(rand_mat, -1).tocoo())

    expected_levels = dgl.topological_nodes_generator(graph)
    adj = graph.adjacency_matrix()

    def peel_levels():
        # Each pass emits the alive nodes with zero remaining degree, then
        # subtracts their contribution from every other node's degree.
        num_nodes = graph.number_of_nodes()
        alive = mx.nd.ones(shape=(num_nodes, 1))
        deg = mx.nd.dot(adj, alive)
        while mx.nd.sum(alive) != 0.:
            ready = (deg == 0.).astype(np.float32) * alive
            alive = alive - ready
            ids = np.nonzero(mx.nd.squeeze(ready).asnumpy())[0]
            yield mx.nd.array(ids, dtype=ids.dtype)
            deg -= mx.nd.dot(adj, ready)

    computed_levels = list(peel_levels())

    assert len(expected_levels) == len(computed_levels)
    assert all(
        toset(dgl_level) == toset(ref_level)
        for dgl_level, ref_level in zip(expected_levels, computed_levels)
    )
 def _is_suitable_tree(self, tree: dgl.DGLGraph) -> bool:
     """Return False if *tree* exceeds a configured size or depth limit.

     Each limit on ``self._config`` (``max_tree_nodes``, ``max_tree_depth``)
     is ignored when it is None. Depth is measured as the number of
     topological levels returned by ``dgl.topological_nodes_generator``.
     """
     node_limit = self._config.max_tree_nodes
     if node_limit is not None and tree.number_of_nodes() > node_limit:
         return False
     depth_limit = self._config.max_tree_depth
     if depth_limit is None:
         return True
     num_levels = len(dgl.topological_nodes_generator(tree))
     return not num_levels > depth_limit
Esempio n. 4
0
def get_root_node_info(dgl_trees: Union[Tuple, List]) -> Tuple:
    """Return ``(root_indices, num_nodes)`` for a sequence of DGL trees.

    The root of each tree is the ``.item()`` of the last frontier yielded by
    ``dgl.topological_nodes_generator``, so that frontier must contain exactly
    one node. Both outputs are numpy arrays aligned with *dgl_trees*.
    """
    roots, sizes = [], []
    for tree in dgl_trees:
        levels = dgl.topological_nodes_generator(tree)
        roots.append(levels[-1].item())
        sizes.append(tree.number_of_nodes())
    return np.array(roots), np.array(sizes)
Esempio n. 5
0
def _get_root_node_info(batch_list):
    list_root_index, list_num_node = [], []

    for tree_dgldigraph in batch_list:
        topological_nodes_list = dgl.topological_nodes_generator(
            tree_dgldigraph)
        root_id_tree_dgldigraph = topological_nodes_list[-1].item()
        list_root_index.append(root_id_tree_dgldigraph)
        all_num_node_tree_dgldigraph = tree_dgldigraph.number_of_nodes()
        list_num_node.append(all_num_node_tree_dgldigraph)

    root_index_np = np.array(list_root_index)
    num_node_np = np.array(list_num_node)

    return root_index_np, num_node_np
Esempio n. 6
0
    def forward(self, agg_graph: dgl.DGLGraph, prop_graph: dgl.DGLGraph,
                new_node_ids: list) -> torch.Tensor:
        """Apply ``self.layers`` in sequence and return the last activation.

        Each layer receives local (non-mutating) views of both graphs, the
        topological frontiers of *prop_graph* (moved to that graph's device)
        and *new_node_ids*. Its output goes through ``self.dropout`` and
        ``self.act``, and the result is written to ``ndata["nfeat"]`` of both
        local graphs before the next layer runs.
        """
        agg_local = agg_graph.local_var()
        prop_local = prop_graph.local_var()
        frontiers = tuple(
            frontier.to(prop_local.device)
            for frontier in dgl.topological_nodes_generator(prop_graph)
        )

        outputs = []
        for layer in self.layers:
            raw = layer(agg_local, prop_local, frontiers, new_node_ids)
            outputs.append(self.act(self.dropout(raw)))
            agg_local.ndata["nfeat"] = prop_local.ndata["nfeat"] = outputs[-1]
        return outputs[-1]
Esempio n. 7
0
    def forward(self, graph: dgl.DGLGraph) -> torch.Tensor:
        """Forward pass for positional embedding

        Levels are visited from the roots downwards. Every child of a node
        receives the parent's embedding shifted right by ``self.n`` columns,
        and the first ``self.n`` columns hold a one-hot code of the child's
        position among its siblings (children beyond the n-th get zeros).

        @param graph: a batched graph with oriented edges from leaves to roots
        @return: positional embedding [n_nodes, n * k]
        """
        node_x = graph.ndata['x']
        width = self.n
        pos_embeds = node_x.new_zeros((graph.number_of_nodes(), self.h_emb))
        for frontier in dgl.topological_nodes_generator(graph, reverse=True):
            for parent in frontier:
                children = graph.in_edges(parent, form='uv')[0]
                num_children = children.shape[0]
                # inherit the parent's code, shifted right by `width` slots
                pos_embeds[children, width:] = pos_embeds[parent, :-width]
                one_hot = node_x.new_zeros((num_children, width))
                diag = torch.arange(0, min(num_children, width), dtype=torch.long)
                one_hot[diag, diag] = 1
                pos_embeds[children, :width] = one_hot
        # TODO: implement parametrized positional embedding with using p
        return pos_embeds
Esempio n. 8
0
    def forward(self, embs, parents, rels, mask=None):
        """Build a tree graph from ``parents`` and run the stacked tree cells over it.

        One DGL node is created per (sequence ``i``, position ``t``) pair, with
        id ``i * embs.size(1) + t``. For every position whose
        ``parents[i, t]`` is not -1, ``g`` gets an edge child -> parent
        carrying the relation id (``"relid"``) and ``self.rel_emb`` of it
        (``"x"``). When ``self.td_cells`` is set, ``g2`` gets the reversed edge
        (parent -> child) embedded with ``self.rev_rel_emb``.

        For each layer ``i``, node states are propagated over ``g`` in
        topological order with ``self.bu_cells[i]`` and, if present, over ``g2``
        with ``self.td_cells[i]``. The ``"h"`` node data of the two passes is
        concatenated along dim 2 and becomes the next layer's input.

        Args:
            embs: input states, indexed as ``embs.size(0)`` sequences of
                ``embs.size(1)`` positions (presumably (batch, seq_len, dim) --
                TODO confirm).
            parents: tensor indexed ``[i, t]`` giving the parent position of
                ``(i, t)`` within sequence ``i``, or -1 for no parent.
            rels: tensor indexed ``[i, t]`` holding the relation id of the edge
                from ``(i, t)`` to its parent.
            mask: not used by this method.

        Returns:
            Tensor viewed as ``(embs.size(0), embs.size(1), -1)`` holding the
            last layer's states (``embs`` itself if ``self.bu_cells`` is empty).
        """
        # build dgl graph

        # add nodes: one node per (sequence, position) pair
        g = dgl.DGLGraph()
        g.add_nodes(embs.size(0) * embs.size(1))

        if self.td_cells is not None:
            # g2 has the same nodes as g and receives the reversed edges
            g2 = dgl.DGLGraph()
            g2.add_nodes(embs.size(0) * embs.size(1))

        # add edges, one DGL call per edge
        # NOTE(review): per-edge add_edge calls are slow for large inputs;
        # collecting the edges and calling add_edges once may be worth it.
        maxtime = embs.size(1)
        for i in range(len(embs)):
            for t in range(embs.size(1)):
                if parents[i, t].item() != -1:
                    # child -> parent edge with the relation id and its embedding
                    g.add_edge(
                        i * maxtime + t, i * maxtime + parents[i, t].item(), {
                            "relid": rels[i, t][None],
                            "x": self.rel_emb(rels[i, t])[None]
                        })
                    if self.td_cells is not None:
                        # reversed parent -> child edge, embedded with rev_rel_emb
                        g2.add_edge(
                            i * maxtime + parents[i, t].item(),
                            i * maxtime + t, {
                                "relid": rels[i, t][None],
                                "x": self.rev_rel_emb(rels[i, t])[None]
                            })

        states = embs

        for i in range(len(self.bu_cells)):
            bu_cell = self.bu_cells[i]
            td_cell = self.td_cells[i] if self.td_cells is not None else None
            # flatten all leading dims into the node dimension
            g.ndata["x"] = states.view(-1, states.size(-1))
            g.ndata["red"] = torch.zeros(g.ndata["x"].size(0),
                                         self.hdim,
                                         device=states.device)

            if td_cell is not None:
                g2.ndata["x"] = states.view(-1, states.size(-1))
                # NOTE(review): the row count is read from g rather than g2;
                # both graphs were created with the same number of nodes.
                g2.ndata["red"] = torch.zeros(g.ndata["x"].size(0),
                                              self.hdim,
                                              device=states.device)

            g_traversal_order = dgl.topological_nodes_generator(g)
            g.prop_nodes(g_traversal_order,
                         message_func=bu_cell.message_func,
                         reduce_func=bu_cell.reduce_func,
                         apply_node_func=bu_cell.apply_node_func)

            if td_cell is not None:
                g2_traversal_order = dgl.topological_nodes_generator(g2)
                g2.prop_nodes(g2_traversal_order,
                              message_func=td_cell.message_func,
                              reduce_func=td_cell.reduce_func,
                              apply_node_func=td_cell.apply_node_func)

            # "h" is presumably written by the cells' apply_node_func;
            # restore the (sequence, position, features) layout
            bu_states = g.ndata["h"].view(states.size(0), states.size(1), -1)
            states = bu_states
            if td_cell is not None:
                td_states = g2.ndata["h"].view(states.size(0), states.size(1),
                                               -1)
                states = torch.cat([bu_states, td_states], 2)
        return states
Esempio n. 9
0
#
# In the case of Tree-LSTM, messages start from the leaves of the tree and
# are propagated/processed upwards until they reach the roots. A visualization
# is as follows:
#
# .. figure:: https://i.loli.net/2018/11/09/5be4b5d2df54d.gif
#    :alt:
#
# DGL defines a generator to perform the topological sort; each item is a
# tensor recording the nodes of one level, from the bottom level up to the
# roots. One can appreciate the degree of parallelism by comparing the
# outputs of the following:
#

# Show the per-level frontiers of a single tree, then of ``graph`` (presumably
# the batch of trees built earlier in this tutorial -- defined outside this
# excerpt). In the batched case each frontier can hold nodes of many trees.
print('Traversing one tree:')
print(dgl.topological_nodes_generator(a_tree))

print('Traversing many trees at the same time:')
print(dgl.topological_nodes_generator(graph))

##############################################################################
# We then call :meth:`~dgl.DGLGraph.prop_nodes` to trigger the message passing:

import dgl.function as fn
import torch as th

# Every node starts with a scalar feature 'a' equal to 1.
graph.ndata['a'] = th.ones(graph.number_of_nodes(), 1)
# Message: copy the source node's 'a'; reduce: sum incoming messages into 'a'.
graph.register_message_func(fn.copy_src('a', 'a'))
graph.register_reduce_func(fn.sum('a', 'a'))

# Frontiers in topological order, to be passed to prop_nodes.
traversal_order = dgl.topological_nodes_generator(graph)
Esempio n. 10
0
def construct_prediction(config, graph):
    """Run ``config.model`` on a list of graph samples and collect targets.

    One of three input pipelines is used, selected by flags on ``config.opt``:

    * ``opt.DGL_input == True`` ("GraphLSTM_dgl"): the model is called with two
      lists of DGL graphs (``sample["g1"]`` / ``sample["g2"]``). Those graphs
      are modified in place (see ``_predict_dgl``).
    * ``opt.PYG_input == True`` ("GraphLSTM_pyg"): samples are converted to
      PyTorch Geometric batches. Per-level topological node orders and
      per-level edge index lists are derived from the batched DGL graphs.
    * otherwise ("BiLSTM"): the dense ``sample["x"]`` matrices are stacked into
      one tensor and fed to the model.

    Args:
        config: object exposing ``model`` (a ``torch.nn.Module``) and ``opt``.
            From ``opt`` this reads ``DGL_input``, ``PYG_input`` (only when
            ``DGL_input`` is not True), ``max_num_node`` and, for the BiLSTM
            pipeline, ``max_prev_node``.
        graph: sequence of per-sample dicts. Every pipeline reads ``"len"``
            (node count) and ``"pos"`` (targets); the pipeline helpers read the
            remaining keys.

    Returns:
        dict with
        - ``"y"``: float32 tensor of the ``"pos"`` targets, shape
          ``(len(graph), opt.max_num_node, 2)``, on the model device;
        - ``"y_pred"``: the model output;
        - ``"len_input"``: per-sample node counts. This is a long tensor on the
          model device for the PyG pipeline, otherwise a float64 numpy array.
    """
    model = config.model
    opt = config.opt
    # All inputs are moved to the device the model's parameters live on.
    device = next(model.parameters()).device

    y, len_input = _collect_targets(graph, opt.max_num_node, device)

    # Exact ``== True`` comparisons are kept so non-bool flag values select
    # the same pipeline as before.
    if opt.DGL_input == True:  # noqa: E712
        y_pred = _predict_dgl(model, graph, device)
    elif opt.PYG_input == True:  # noqa: E712
        len_input = torch.from_numpy(len_input).long().to(device)
        y_pred = _predict_pyg(model, graph, len_input, device)
    else:
        y_pred = _predict_bilstm(model, graph, opt.max_num_node,
                                 opt.max_prev_node, device)

    return {
        "y": y,
        "y_pred": y_pred,
        "len_input": len_input,
    }


def _collect_targets(graph, max_num_node, device):
    """Stack each sample's ``"pos"`` targets and ``"len"`` node counts.

    Returns ``(y, len_input)``. ``y`` is a float32 tensor of shape
    ``(len(graph), max_num_node, 2)`` on *device*; ``len_input`` is a float64
    numpy array of the ``"len"`` values.
    """
    y_input = np.zeros((len(graph), max_num_node, 2))
    len_input = np.zeros(len(graph))
    for idx, sample in enumerate(graph):
        len_input[idx] = sample["len"]
        y_input[idx, :, :] = sample["pos"]
    return torch.from_numpy(y_input).float().to(device), len_input


def _predict_dgl(model, graph, device):
    """GraphLSTM_dgl pipeline: return ``model(g1_list, g2_list)``.

    Each sample's ``"g1"``/``"g2"`` DGL graph is modified in place. Node data
    ``"x"`` is set to the first ``"len"`` rows of ``sample["x"]`` (float32, on
    *device*), and edge data ``"edge_label"`` is moved to *device*.
    """
    graphs1, graphs2 = [], []
    for sample in graph:
        num_nodes = sample["len"]
        for key, bucket in (("g1", graphs1), ("g2", graphs2)):
            dgl_graph = sample[key]
            # each graph gets its own feature tensor
            dgl_graph.ndata["x"] = (
                torch.from_numpy(sample["x"][0:num_nodes, :]).float().to(device))
            dgl_graph.edata["edge_label"] = dgl_graph.edata["edge_label"].to(device)
            bucket.append(dgl_graph)
    return model(graphs1, graphs2)


def _predict_pyg(model, graph, len_input, device):
    """GraphLSTM_pyg pipeline.

    Builds PyTorch Geometric batches from the ``g1``/``g2`` edge arrays and
    takes topological levels from the batched DGL graphs ``sample["g1"]`` /
    ``sample["g2"]``. For each level it lists the edges whose source node lies
    in that level. It then returns
    ``model(g1_batch, g1_order, g1_masks, g2_batch, g2_order, g2_masks, len_input)``.
    """
    from torch_geometric.data import Batch, Data

    data1, data2 = [], []
    dgl_graphs1, dgl_graphs2 = [], []
    total_nodes = 0
    for sample in graph:
        num_nodes = sample["len"]
        # both PyG views share the same node-feature tensor
        node_x = torch.from_numpy(sample["x"][0:num_nodes, :]).float()
        data1.append(Data(
            x=node_x,
            edge_index=torch.from_numpy(sample["g1_edge_index"]).long(),
            edge_attr=torch.from_numpy(sample["g1_edge_label"]).float(),
        ))
        data2.append(Data(
            x=node_x,
            edge_index=torch.from_numpy(sample["g2_edge_index"]).long(),
            edge_attr=torch.from_numpy(sample["g2_edge_label"]).float(),
        ))
        dgl_graphs1.append(sample["g1"])
        dgl_graphs2.append(sample["g2"])
        total_nodes += num_nodes

    g1_batch = Batch.from_data_list(data1)
    g2_batch = Batch.from_data_list(data2)
    g1_order = dgl.topological_nodes_generator(dgl.batch(dgl_graphs1))
    g2_order = dgl.topological_nodes_generator(dgl.batch(dgl_graphs2))
    g1_masks = _edge_order_masks(g1_order, g1_batch.edge_index, total_nodes)
    g2_masks = _edge_order_masks(g2_order, g2_batch.edge_index, total_nodes)

    g1_order = [level.to(device) for level in g1_order]
    g2_order = [level.to(device) for level in g2_order]
    g1_masks = [torch.from_numpy(m).long().to(device) for m in g1_masks]
    g2_masks = [torch.from_numpy(m).long().to(device) for m in g2_masks]
    return model(g1_batch.to(device), g1_order, g1_masks,
                 g2_batch.to(device), g2_order, g2_masks, len_input)


def _edge_order_masks(levels, edge_index, num_nodes):
    """For each level, return the indices of edges whose source is in it.

    The source of an edge is ``edge_index[0]``. Each result is an int64 numpy
    array. Only one O(num_nodes) boolean membership vector is built per level,
    instead of a dense (levels x num_nodes) mask.
    """
    sources = np.asarray(edge_index[0])
    masks = []
    for level in levels:
        in_level = np.zeros(num_nodes, dtype=bool)
        in_level[np.asarray(level)] = True
        masks.append(np.nonzero(in_level[sources])[0])
    return masks


def _predict_bilstm(model, graph, max_num_node, input_size, device):
    """BiLSTM pipeline: run the model on the stacked ``sample["x"]`` matrices.

    The input is a float32 tensor of shape
    ``(len(graph), max_num_node, input_size)`` on *device*.
    """
    x_input = np.zeros((len(graph), max_num_node, input_size))
    for idx, sample in enumerate(graph):
        x_input[idx, :, :] = sample["x"]
    return model(torch.from_numpy(x_input).float().to(device))
Esempio n. 11
0
#
# In the case of Tree-LSTM, messages start from the leaves of the tree and
# are propagated/processed upwards until they reach the roots. A visualization
# is as follows:
#
# .. figure:: https://i.loli.net/2018/11/09/5be4b5d2df54d.gif
#    :alt:
#
# DGL defines a generator to perform the topological sort; each item is a
# tensor recording the nodes of one level, from the bottom level up to the
# roots. One can appreciate the degree of parallelism by comparing the
# outputs of the following:
#

# Show the per-level frontiers of a single tree, then of ``graph`` (presumably
# the batch of trees built earlier in this tutorial -- defined outside this
# excerpt). In the batched case each frontier can hold nodes of many trees.
print('\nTraversing one tree:')
print(dgl.topological_nodes_generator(a_tree))
print('Traversing many trees at the same time:')
print(dgl.topological_nodes_generator(graph))

##############################################################################
# We then call :meth:`~dgl.DGLGraph.prop_nodes` to trigger the message passing:

import dgl.function as fn
import torch as th
# NOTE(review): Transformer_Utils is not used in the visible code.
import Transformer_Utils
# Disabled setup for a TreeLSTMCell-based propagation (kept for reference):
# cell = TreeLSTMCell(256, 256)
# print("num_nodes:", graph.number_of_nodes())
# graph.ndata['h'] = th.ones(graph.number_of_nodes(), 256)
# graph.ndata['c'] = th.ones(graph.number_of_nodes(), 256)
# graph.ndata['iou'] = th.ones(graph.number_of_nodes(), 256*3)
# graph.register_message_func(cell.message_func)
Esempio n. 12
0
    def forward(self, g, encs):
        """Decode trees: teacher-forced over ``g`` in training, by iterative expansion in eval.

        Training (``self.training`` true):
        - ``encs`` are spread onto ``g`` with ``self.spread_encs``.
        - The first topological frontier of ``g`` (called ``roots`` below) is
          propagated with ``self.cell.apply_node_func_root``, and all later
          frontiers with ``self.cell.apply_node_func``.
        - ``g`` is updated in place and returned.

        Evaluation:
        - One single-node tree is created per entry of ``encs``.
        - The root states are computed.
        - Trees are then repeatedly unbatched, expanded with
          ``self.cell.expand`` and passed through ``self.filter``. This stops
          when no nodes remain to process or ``self.max_depth`` levels have
          been computed.
        - The collected ``final_trees`` are returned as one batched graph.

        NOTE(review): ``self.filter``'s return value is unused, so it
        presumably mutates ``nodes_id``/``trees``/``positions``/``final_trees``
        in place -- confirm.

        Args:
            g: batched DGLGraph, only read in training mode (rebound in eval).
            encs: encodings. In eval mode ``len(encs)`` root trees are created
                and ``encs`` is stored as their ``'enc'`` node data.

        Returns:
            A DGLGraph: ``g`` in training mode, the batch of decoded trees in
            eval mode.
        """
        if self.training:

            self.spread_encs(g, encs)

            # topological order
            topo_nodes = dgl.topological_nodes_generator(g)

            # first frontier vs. every later frontier
            roots = topo_nodes[0:1]
            others = topo_nodes[1:]

            #root training computations
            g.register_message_func(self.cell.message_func)
            g.register_reduce_func(self.cell.reduce_func)
            g.register_apply_node_func(self.cell.apply_node_func_root)
            g.prop_nodes(roots)

            #other nodes training computations
            g.register_apply_node_func(self.cell.apply_node_func)
            g.prop_nodes(others)

        else:
            #TODO: instead of unbatch and perform single node expansion, use DGL NodeFlow for batched sampling
            trees = []
            # initial per-node data for every root node
            features = {
                'parent_h': th.zeros(1, 1, self.h_size),
                'parent_output': th.zeros(1, 1, self.cell.num_classes)
            }
            if str(self.cell) == 'DRNNCell':
                # this cell additionally tracks sibling state
                features['sibling_h'] = th.zeros(1, 1, self.h_size)
                features['sibling_output'] = th.zeros(1, 1,
                                                      self.cell.num_classes)
            #create only root trees without labels
            for i in range(len(encs)):
                tree = dgl.DGLGraph()
                tree.add_nodes(1, features)
                trees.append(tree)
            g = dgl.batch(trees)  #batch them
            g.ndata['enc'] = encs  #set root encs
            g.ndata['pos'] = th.zeros(len(
                (g.nodes())), self.max_outdegree)  #auxiliary topological info
            g.ndata['depth'] = th.zeros(len(
                (g.nodes())), self.max_depth + 1)  #auxiliary topological info

            #root computations
            topo_nodes = dgl.topological_nodes_generator(g)
            g.register_message_func(self.cell.message_func)
            g.register_reduce_func(self.cell.reduce_func)
            g.register_apply_node_func(self.cell.apply_node_func_root)
            g.prop_nodes(topo_nodes)

            trees = dgl.unbatch(
                g)  #unbatch to deal with single trees expansions

            nodes_id = []
            #single trees expansions (node 0 is each tree's root)
            for i in range(len(trees)):
                nodes_id.append(self.cell.expand(trees[i], 0))

            # positions[k] = index in final_trees of the k-th tree still in `trees`
            positions = [i for i in range(len(trees))]
            final_trees = [None] * len(trees)

            self.filter(nodes_id, trees, positions,
                        final_trees)  #take only nodes to process

            depth = 0

            #loop expansions of the lower levels
            while nodes_id:

                g = dgl.batch(
                    trees)  # batch again to compute the new nodes' states
                batch_nodes_id = self.tree_node_id_to_batch_node_id(
                    trees, nodes_id)  # map per-tree node ids to batch node ids

                g.register_message_func(self.cell.message_func)
                g.register_reduce_func(self.cell.reduce_func)
                g.register_apply_node_func(self.cell.apply_node_func)
                g.prop_nodes(batch_nodes_id)

                depth += 1

                if depth < self.max_depth:  #if stopping criteria not reached

                    tree_nodes_id = nodes_id.copy()
                    trees = dgl.unbatch(
                        g)  #unbatch to deal with single trees expansions

                    nodes_id = []
                    # single trees expansions: expand every node just computed
                    for i in range(len(trees)):
                        tree_ids = []
                        for j in range(len(tree_nodes_id[i])):
                            id = tree_nodes_id[i][j]
                            tree_ids += self.cell.expand(trees[i], id)
                        nodes_id.append(tree_ids)

                    self.filter(nodes_id, trees, positions,
                                final_trees)  #take only nodes to process

                else:  #if stops
                    for i in range(len(trees)):
                        final_trees[positions[i]] = trees[
                            i]  #put on the final trees the last computed nodes
                    break

            g = dgl.batch(final_trees)

        return g
Esempio n. 13
0
    def forward(self, g, encs):
        """Decode trees: teacher-forced over ``g`` in training, by iterative expansion in eval.

        Training (``self.training`` true):
        - ``encs`` are spread onto ``g`` with ``self.spread_encs``.
        - The first topological frontier of ``g`` (called ``roots`` below) is
          propagated with ``self.cell.apply_node_func_root``, and all later
          frontiers with ``self.cell.apply_node_func``.
        - ``g`` is updated in place and returned.

        Evaluation:
        - One single-node tree is created per entry of ``encs``.
        - The root states are computed.
        - Trees are then repeatedly unbatched, expanded with
          ``self.cell.expand`` and passed through ``self.filter2``. This stops
          when no nodes remain to process or ``self.max_depth`` levels have
          been computed.
        - The collected ``final_trees`` are returned as one batched graph.

        NOTE(review): ``self.filter2``'s return value is unused, so it
        presumably mutates ``nodes_id``/``trees``/``positions``/``final_trees``
        in place -- confirm.

        Args:
            g: batched DGLGraph, only read in training mode (rebound in eval).
            encs: encodings. In eval mode ``len(encs)`` root trees are created
                and ``encs`` is stored as their ``'enc'`` node data.

        Returns:
            A DGLGraph: ``g`` in training mode, the batch of decoded trees in
            eval mode.
        """
        if self.training:

            self.spread_encs(g, encs)

            # topological order
            topo_nodes = dgl.topological_nodes_generator(g)

            # first frontier vs. every later frontier
            roots = topo_nodes[0:1]
            others = topo_nodes[1:]

            #root training computations
            g.register_message_func(self.cell.message_func)
            g.register_reduce_func(self.cell.reduce_func)
            g.register_apply_node_func(self.cell.apply_node_func_root)
            g.prop_nodes(roots)

            #other nodes training computations
            g.register_apply_node_func(self.cell.apply_node_func)
            g.prop_nodes(others)

        else:
            trees = []
            # initial per-node data for every root node
            features = {'parent_h': th.zeros(1, 1, self.h_size), 'parent_output': th.zeros(1, 1, self.cell.num_classes)}
            if str(self.cell) == 'DRNNCell':
                # this cell additionally tracks sibling state
                features['sibling_h'] = th.zeros(1, 1, self.h_size)
                features['sibling_output'] = th.zeros(1, 1, self.cell.num_classes)
            #create only root trees without labels
            for i in range(len(encs)):
                tree = dgl.DGLGraph()
                tree.add_nodes(1, features)
                trees.append(tree)
            g = dgl.batch(trees) #batch them
            g.ndata['enc'] = encs #set root encs
            # auxiliary topological info; NOTE(review): width hard-coded to 10,
            # confirm it matches what the cell expects
            g.ndata['pos'] = th.zeros(len((g.nodes())),10)
            g.ndata['depth'] = th.zeros(len((g.nodes())),10)

            #root computations
            topo_nodes = dgl.topological_nodes_generator(g)
            g.register_message_func(self.cell.message_func)
            g.register_reduce_func(self.cell.reduce_func)
            g.register_apply_node_func(self.cell.apply_node_func_root)
            g.prop_nodes(topo_nodes)

            trees = dgl.unbatch(g) #unbatch to deal with single trees expansions

            nodes_id = []
            #single trees expansions (node 0 is each tree's root)
            for i in range(len(trees)):
                nodes_id.append(self.cell.expand(trees[i], 0))

            # positions[k] = index in final_trees of the k-th tree still in `trees`
            positions = [i for i in range(len(trees))]
            final_trees = [None] * len(trees)

            # NOTE: an earlier version called self.filter(...) here; its author
            # noted that it lost the trees that did not need further expansion.
            self.filter2(nodes_id, trees, positions, final_trees)

            depth = 0

            #loop expansions of the lower levels
            while nodes_id:

                g = dgl.batch(trees)  # batch again to compute the new nodes' data
                batch_nodes_id = self.tree_node_id_to_batch_node_id(trees, nodes_id)  # map per-tree node ids to batch node ids

                g.register_message_func(self.cell.message_func)
                g.register_reduce_func(self.cell.reduce_func)
                g.register_apply_node_func(self.cell.apply_node_func)
                g.prop_nodes(batch_nodes_id)

                depth += 1

                if depth < self.max_depth:

                    tree_nodes_id = nodes_id.copy()
                    trees = dgl.unbatch(g) #unbatch to deal with single trees expansions

                    nodes_id = []
                    # single trees expansions: expand every node just computed
                    for i in range(len(trees)):
                        tree_ids = []
                        for j in range(len(tree_nodes_id[i])):
                            id = tree_nodes_id[i][j]
                            tree_ids+=self.cell.expand(trees[i], id)
                        nodes_id.append(tree_ids)

                    self.filter2(nodes_id, trees, positions, final_trees)

                else:
                    # depth limit reached: move the trees still in `trees` into final_trees
                    for i in range(len(trees)):
                        final_trees[positions[i]] = trees[i]
                    break
            g = dgl.batch(final_trees)

        return g
Esempio n. 14
0
def tree_depth(tree: dgl.DGLGraph) -> int:
    """
    Compute the maximum distance from a leaf to the root.

    This is the number of topological frontiers of *tree* minus one, i.e. the
    number of edges on the longest leaf-to-root path (a single-node tree has
    depth 0).
    """
    frontiers = dgl.topological_nodes_generator(tree)
    return len(frontiers) - 1
Esempio n. 15
0
def get_tree_depth(tree: dgl.DGLGraph) -> int:
    """Return the number of topological levels (frontiers) in *tree*.

    This counts levels, not edges: a single-node tree yields 1.
    """
    levels = dgl.topological_nodes_generator(tree)
    return len(levels)