Exemple #1
0
def get_index_mapper(graph, next_graph):
    """Build the index lists relating ``graph`` to ``next_graph``.

    When both inputs are batched graphs, every pair of member graphs is
    passed to ``_get_index_mapper_list`` and the per-pair results are
    concatenated. The two offsets given to the helper are the running
    counts of ``NODE_ALLY`` nodes found in the member graphs that precede
    the current pair in each batch. In every other case the helper is
    called once on the inputs with both offsets set to 0.

    :param graph: current graph (single graph or ``dgl.BatchedDGLGraph``).
    :param next_graph: next graph (single graph or ``dgl.BatchedDGLGraph``).
    :return: tuple ``(cur_idx, next_idx)`` as produced by
        ``_get_index_mapper_list`` (concatenated lists in the batched case).
    """
    # isinstance (rather than an exact type comparison) also accepts
    # subclasses of the batched graph type.
    both_batched = (isinstance(graph, dgl.BatchedDGLGraph)
                    and isinstance(next_graph, dgl.BatchedDGLGraph))
    if not both_batched:
        cur_idx, next_idx = _get_index_mapper_list(graph, next_graph, 0, 0)
        return cur_idx, next_idx

    cur_idx = []
    next_idx = []
    curr_offset = 0
    next_offset = 0
    for g, ng in zip(dgl.unbatch(graph), dgl.unbatch(next_graph)):
        curr_count = len(get_filtered_node_index_by_type(g, NODE_ALLY))
        next_count = len(get_filtered_node_index_by_type(ng, NODE_ALLY))
        ci, ni = _get_index_mapper_list(g, ng, curr_offset, next_offset)
        cur_idx.extend(ci)
        next_idx.extend(ni)
        # Offsets advance only after the pair has been mapped.
        curr_offset += curr_count
        next_offset += next_count

    return cur_idx, next_idx
Exemple #2
0
def make_step_modifications(action, steps):
    """Apply the selected actions to the second half of each step.

    ``action`` is recorded in the module-level ``selected_actions`` list and
    reshaped to one row per sample with ``len(action_list)`` columns. For
    each sample the callables in ``action_list`` whose rounded score is
    non-zero are applied in order to ``[graph, mp]`` taken from ``g2``/``mp2``;
    if none qualifies, the arg-max entry is applied instead.

    :param action: tensor of action scores.
    :param steps: sequence ``(g1, mp1, g2, mp2)``; ``g1`` and ``g2`` are
        batched graphs, ``mp1`` and ``mp2`` are indexable per sample.
    :return: result of ``my_collate`` on the list of
        ``[[g1_k, mp1_k, new_g_k, new_mp_k], 0]`` entries.
    """
    selected_actions.append(action)
    with torch.no_grad():
        g1, mp1, g2, mp2 = steps

        glist2 = dgl.unbatch(g2)
        glist1 = dgl.unbatch(g1)
        per_sample = torch.reshape(action, (-1, len(action_list)))

        collected = []
        for k in range(len(mp2)):
            scores = per_sample[k]
            chosen = torch.nonzero(torch.round(scores))
            if chosen.size()[0] == 0:
                # Nothing rounds to a non-zero value: fall back to the max.
                chosen = torch.tensor([torch.argmax(scores, 0)])

            state = [glist2[k], mp2[k]]
            for act in chosen:
                state = action_list[act](state)
            _g, _mp = state

            collected.append([[glist1[k], mp1[k], _g, _mp], 0])

        return my_collate(collected)
def test_batch_unbatch_frame(idtype):
    """Check that frames of batched and unbatched graphs are independent.

    Writing into the features of a batched graph must not alter the source
    graphs, and unbatching must return the batched graph's current features.
    Also covers the bug in https://github.com/dmlc/dgl/issues/1475.
    """
    t1 = tree1(idtype)
    t2 = tree2(idtype)
    N1, E1 = t1.number_of_nodes(), t1.number_of_edges()
    N2, E2 = t2.number_of_nodes(), t2.number_of_edges()
    D = 10
    sources = ((t1, N1, E1), (t2, N2, E2))

    for tree, n_nodes, n_edges in sources:
        tree.ndata['h'] = F.randn((n_nodes, D))
        tree.edata['h'] = F.randn((n_edges, D))

    b1 = dgl.batch([t1, t2])
    b2 = dgl.batch([t2])
    # Zero the leading rows of each batched graph's features.
    for batched, n_nodes, n_edges in ((b1, N1, E1), (b2, N2, E2)):
        batched.ndata['h'][:n_nodes] = F.zeros((n_nodes, D))
        batched.edata['h'][:n_edges] = F.zeros((n_edges, D))

    # The source graphs must be unaffected by the writes above.
    for tree, n_nodes, n_edges in sources:
        assert not F.allclose(tree.ndata['h'], F.zeros((n_nodes, D)))
        assert not F.allclose(tree.edata['h'], F.zeros((n_edges, D)))

    g1, g2 = dgl.unbatch(b1)
    _g2, = dgl.unbatch(b2)
    for zeroed, n_nodes, n_edges in ((g1, N1, E1), (_g2, N2, E2)):
        assert F.allclose(zeroed.ndata['h'], F.zeros((n_nodes, D)))
        assert F.allclose(zeroed.edata['h'], F.zeros((n_edges, D)))
    # The second member of b1 was never overwritten.
    assert F.allclose(g2.ndata['h'], t2.ndata['h'])
    assert F.allclose(g2.edata['h'], t2.edata['h'])
Exemple #4
0
    def directed_tree_loss(self, all_edus, l_trees_graph, r_trees_graph,
                           trees_graph, roots):
        """Add this batch's tree loss to ``self.total_score`` and return it.

        The EDUs are embedded, a pull over every node of ``trees_graph`` is
        run with temporary ``'h'``/``'ch_h'`` node features (removed again
        afterwards), and dense adjacency matrices of the reversed left/right
        tree graphs are passed to ``self.logistic_loss`` together with the
        compatibility matrix and the root scores. The loss is divided by the
        first dimension of the embedding tensor before being accumulated.

        :return: the updated ``self.total_score``.
        """
        h_cat, doc_embed, doc_lengths = self.edu_embed_model(all_edus)
        per_doc_embeds = self.split_node_embed(h_cat, doc_lengths)
        batch, seq_len, _ = h_cat.shape
        flat_embeds = th.cat(per_doc_embeds)

        trees_graph.ndata['h'] = flat_embeds
        trees_graph.ndata['ch_h'] = th.zeros_like(flat_embeds)
        trees_graph.register_message_func(self.message_func)
        trees_graph.register_reduce_func(
            lambda x: self.reduce_func(x, doc_embed, batch, seq_len))
        trees_graph.pull(trees_graph.nodes())
        del trees_graph.ndata['h']
        del trees_graph.ndata['ch_h']

        def dense_adjacency(subgraph):
            # Dense adjacency of the reversed graph with a leading axis of 1.
            sparse = subgraph.reverse().adjacency_matrix(
                transpose=True, ctx=th.device(self.config[DEVICE]))
            return sparse.to_dense().unsqueeze(0)

        left_adj, right_adj = [], []
        for left_sub, right_sub in zip(dgl.unbatch(l_trees_graph),
                                       dgl.unbatch(r_trees_graph)):
            left_adj.append(dense_adjacency(left_sub))
            right_adj.append(dense_adjacency(right_sub))

        left_adj = th.cat(left_adj)
        right_adj = th.cat(right_adj)
        compat_matrix = self.get_compat_matrix(h_cat)
        root_scores = self.root_clf(h_cat).view(h_cat.shape[0], -1)
        batch_loss = self.logistic_loss(compat_matrix,
                                        (left_adj, right_adj),
                                        (root_scores, roots))
        self.total_score += batch_loss / batch
        return self.total_score
    def directed_tree_loss(self, all_edus, l_trees_graph, r_trees_graph,
                           trees_graph, roots):
        """Compute the tree loss of one batch and store it in ``total_score``.

        Dense adjacency matrices of the reversed left/right tree graphs are
        passed to ``self.logistic_loss`` together with the compatibility
        matrix and the root scores. ``self.total_score`` is reset to 0 before
        the loss (divided by the first dimension of the embedding tensor) is
        added, so the return value covers this batch only. ``trees_graph`` is
        accepted but not used here.

        :return: ``self.total_score`` after the update.
        """
        h_cat, doc_embed, doc_lengths = self.edu_embed_model(all_edus)
        # Result unused; the call is kept as in the original code path.
        self.split_node_embed(h_cat, doc_lengths)
        batch, seq_len, _ = h_cat.shape

        def dense_adjacency(subgraph):
            # Dense adjacency of the reversed graph with a leading axis of 1.
            sparse = subgraph.reverse().adjacency_matrix(
                transpose=True, ctx=th.device(self.config[DEVICE]))
            return sparse.to_dense().unsqueeze(0)

        left_adj, right_adj = [], []
        for left_sub, right_sub in zip(dgl.unbatch(l_trees_graph),
                                       dgl.unbatch(r_trees_graph)):
            left_adj.append(dense_adjacency(left_sub))
            right_adj.append(dense_adjacency(right_sub))

        self.total_score = 0
        left_adj = th.cat(left_adj)
        right_adj = th.cat(right_adj)
        compat_matrix = self.get_compat_matrix(h_cat)
        root_scores = self.root_clf(h_cat).view(h_cat.shape[0], -1)
        batch_loss = self.logistic_loss(compat_matrix,
                                        (left_adj, right_adj),
                                        (root_scores, roots))
        self.total_score += batch_loss / batch
        return self.total_score
    def forward(self, bg, bg_out_hr, feature_name='energy', out_name='neu_energy'):
        """Broadcast a node feature of ``bg`` onto the nodes of ``bg_out_hr``.

        For each pair of member graphs, the ``feature_name`` values of the
        source graph are flattened to one value per node and repeated
        according to the source graph's ``'broadcast'`` node data; the result
        is stored as a column under ``out_name`` on the target graph.

        :param bg: batched source graph.
        :param bg_out_hr: batched target graph.
        :param feature_name: node feature read from the source graphs.
        :param out_name: node feature written on the target graphs.
        :return: the target member graphs re-batched with ``dgl.batch``.
        """
        filled_graphs = []

        with bg.local_scope(), bg_out_hr.local_scope():
            sources = dgl.unbatch(bg)
            targets = dgl.unbatch(bg_out_hr)

            for position, source in enumerate(sources):
                target = targets[position]

                values = source.ndata[feature_name]
                values = torch.reshape(values, (values.shape[0],))
                repeats = source.ndata['broadcast']

                expanded = torch.repeat_interleave(values, repeats, dim=0)
                target.ndata[out_name] = expanded[:, None]

                filled_graphs.append(target)

        return dgl.batch(filled_graphs)
Exemple #7
0
    def forward(self, g1, g2, mode='pairs'):
        """Compute soft-Hausdorff distances between graphs of two batches.

        mode:
            'pairs'     -- the i-th graph of ``g1`` is compared with the i-th
                           graph of ``g2``.
            'retrieval' -- the first graph of ``g1`` is compared with every
                           graph of ``g2``.

        Returns the per-graph distances stacked into one tensor, one entry
        per graph of ``g2``. Any other mode raises ``NameError``.
        """

        def attach_std(members, batched):
            # Give every member its own 'std' entry taken from the batch.
            for idx, member in enumerate(members):
                member.gdata = {}
                member.gdata['std'] = batched.gdata['std'][idx]

        g1_list = dgl.unbatch(g1)
        if len(g1_list) > 1:
            attach_std(g1_list, g1)

        g2_list = dgl.unbatch(g2)
        attach_std(g2_list, g2)

        distances = []
        for idx, target in enumerate(g2_list):
            if mode == 'pairs':
                source = g1_list[idx]
            elif mode == 'retrieval':
                source = g1_list[0]
            else:
                raise NameError(mode + ' not implemented!')
            distances.append(self.soft_hausdorff(source, target))
        return torch.stack(distances)
Exemple #8
0
    def forward(self, graphs, need_weights=False):
        """Run the attention blocks over the nodes of a batched graph.

        Node inputs combine the per-graph encodings from ``self.encoder``
        with the embedding of the ``'atomic'`` node data, either concatenated
        or summed depending on ``self.concatenate_encoding``. The inputs are
        padded to a common length, passed through every block in
        ``self.blocks``, then cut back to each graph's node count and
        concatenated.

        :param graphs: batched graph; its ``'h'`` node data is overwritten.
        :param need_weights: when true, also return the second output of
            every block.
        :return: node features, or ``(node features, list of block outputs)``.
        """
        per_graph_codes = [
            self.encoder.encode(member) for member in dgl.unbatch(graphs)
        ]
        encoding = torch.cat(per_graph_codes, dim=0)
        atom_ids = graphs.ndata['atomic'].type(torch.long)
        embedding = self.node_embedder(atom_ids)
        if self.concatenate_encoding:
            node_input = torch.cat((encoding, embedding.squeeze()), dim=-1)
        else:
            node_input = encoding + embedding.squeeze()
        graphs.ndata['h'] = node_input

        # Padded layout: (longest graph, number of graphs, features).
        h = torch.nn.utils.rnn.pad_sequence(
            [member.ndata['h'] for member in dgl.unbatch(graphs)])

        attentions_ = []
        for block in self.blocks:
            h, weights = block(h)
            if need_weights:
                attentions_.append(weights)

        # Drop the padding rows again before concatenating.
        h = torch.cat(
            [h[:count, i, :]
             for i, count in enumerate(graphs.batch_num_nodes())],
            dim=0)

        if not need_weights:
            return h
        return h, attentions_
    def forward(self, batch):
        """Compute tree-lstm prediction given a batch.

        Parameters
        ----------
        batch : dgl.data.SSTBatch
            The data batch. The attributes ``graph``, ``wordid`` and ``mask``
            are read from it.

        Returns
        -------
        logits : Tensor
            Output of ``self.linear`` applied to the node feature ``'h'``
            (after dropout) of the batched graph.
        """
        #----------utils function---------------
        def InitS(tree):
            # Set 's' on every node to the mean of 'e' over the tree's nodes.
            tree.ndata['s'] = tree.ndata['e'].mean(dim=0).repeat(tree.number_of_nodes(), 1)
            return tree

        def updateS(tree, state):
            # Copy the 1-D state vector onto every node of the tree as 's'.
            assert state.dim() == 1
            tree.ndata['s'] = state.repeat(tree.number_of_nodes(), 1)
            return tree

        def extractS(batchTree):
            # 's' holds the same vector on every node, so row 0 is enough.
            # [dmodel] --> [[dmodel]] --> [tree, dmodel] --> [tree, 1, dmodel]
            s_list = [tree.ndata.pop('s')[0].unsqueeze(0) for tree in dgl.unbatch(batchTree)]
            return th.cat(s_list, dim=0).unsqueeze(1)

        def extractH(batchTree):
            # Zero-pad each tree's 'h' to the largest node count, then stack.
            # [nodes, dmodel] --> [nodes, dmodel]--> [max_nodes, dmodel]--> [tree*_max_nodes, dmodel] --> [tree, max_nodes, dmodel]
            h_list = [tree.ndata.pop('h') for tree in dgl.unbatch(batchTree)]
            max_nodes = max([h.size(0) for h in h_list])
            h_list = [th.cat([h, th.zeros([max_nodes-h.size(0), h.size(1)]).to(self.device)], dim=0).unsqueeze(0) for h in h_list]
            return th.cat(h_list, dim=0)
        #-----------------------------------------

        g = batch.graph
        # feed embedding
        embeds = self.embedding(batch.wordid * batch.mask)
        g.ndata['c'] = th.zeros((g.number_of_nodes(), 2, self.dmodel)).to(self.device)
        # 'e' and 'h' both start as the masked embeddings.
        g.ndata['e'] = embeds*batch.mask.float().unsqueeze(-1)
        g.ndata['h'] = embeds*batch.mask.float().unsqueeze(-1)
        # Re-batch after giving every tree its initial 's' feature.
        g = dgl.batch([InitS(gg) for gg in dgl.unbatch(g)])
        # propagate
        for i in range(self.T_step):
            g.register_message_func(self.cell.message_func)
            g.register_reduce_func(self.cell.reduce_func)
            g.register_apply_node_func(self.cell.apply_node_func)
            dgl.prop_nodes_topo(g)
            # One new state per tree, written back as 's' before re-batching.
            States = self.cell.updateGlobalVec(extractS(g), extractH(g) )
            g = dgl.batch([updateS(tree, state) for (tree, state) in zip(dgl.unbatch(g), States)])
        # compute logits
        h = self.dropout(g.ndata.pop('h'))
        logits = self.linear(h)
        return logits
Exemple #10
0
    def forward(self, bg, bg_u):
        """Scatter pooled node features back onto larger, unpooled graphs.

        For each pair of member graphs ``(g, g_u)`` a zero tensor with
        ``g.ndata['parent_node'][0]`` rows and ``self.dim`` columns is
        created, and the ``'energy'`` rows of ``g`` are scattered into it at
        the row positions given by ``g.ndata['_ID']``. A new graph with that
        many nodes and the edges of ``g_u`` carries the result.

        :param bg: batched graph holding the pooled features.
        :param bg_u: batched graph supplying edges, ``'parent_node'`` and
            ``'_ID'`` for the output graphs.
        :return: the new graphs batched with ``dgl.batch``; each has node
            data ``'energy'``, ``'parent_node'`` and ``'_ID'``.
        """
        output_gr = []

        with bg.local_scope():
            with bg_u.local_scope():

                graph_list = dgl.unbatch(bg)
                graph_list_u = dgl.unbatch(bg_u)

                for ig, g in enumerate(graph_list):
                    g_u = graph_list_u[ig]

                    n_unpooled_node = g.ndata['parent_node'][0]

                    selected_nodes = g.ndata['_ID'][:, None]
                    pooled_node_features = g.ndata['energy']

                    # scatter_ needs an index tensor with the same shape as
                    # the source, so the ID column is expanded across columns.
                    expanded_node = selected_nodes.expand_as(
                        pooled_node_features).to(dev)
                    pooled_node_features = pooled_node_features.to(dev)

                    x = torch.zeros(n_unpooled_node, self.dim, device=dev)
                    x.scatter_(0, expanded_node, pooled_node_features)

                    g_new = dgl.DGLGraph()
                    g_new.add_nodes(n_unpooled_node)

                    src, dst = g_u.edges()
                    g_new.add_edges(src, dst)

                    g_new.ndata['energy'] = x
                    g_new.ndata['parent_node'] = g_u.ndata['parent_node'][
                        0] * torch.ones([g_new.number_of_nodes()],
                                        dtype=torch.int)
                    # Copy an existing tensor with detach().clone() instead of
                    # torch.tensor(tensor), which PyTorch warns against.
                    g_new.ndata['_ID'] = (
                        g_u.ndata['_ID'].detach().clone().to(torch.int64))

                    output_gr.append(g_new)

        return dgl.batch(output_gr)
Exemple #11
0
    def forward(self, fact_batch_graph, img_batch_graph, sem_batch_graph):
        """Apply ``self.gcn`` to each aligned triple of member graphs.

        The three batched graphs are unbatched; for every fact graph the
        image and semantic graphs at the same position are passed along.

        :return: the graphs returned by ``self.gcn``, batched again.
        """
        fact_graphs = dgl.unbatch(fact_batch_graph)
        img_graphs = dgl.unbatch(img_batch_graph)
        sem_graphs = dgl.unbatch(sem_batch_graph)

        new_fact_graphs = [
            self.gcn(fact_graph, img_graphs[idx], sem_graphs[idx])
            for idx, fact_graph in enumerate(fact_graphs)
        ]
        return dgl.batch(new_fact_graphs)
Exemple #12
0
def test_batch_propagate():
    """Propagate along two edge frontiers of a batch of two trees.

    Messages carry the source node's ``'h'`` and are summed at the
    destination. Node ids of the second tree appear with an added 5 in the
    batched graph.
    """
    t1 = tree1()
    t2 = tree2()

    bg = dgl.batch([t1, t2])

    def send_source_state(edges):
        return {'m': edges.src['h']}

    def sum_mailbox(nodes):
        return {'h': F.sum(nodes.mailbox['m'], 1)}

    bg.register_message_func(send_source_state)
    bg.register_reduce_func(sum_mailbox)

    # Each entry is one propagation step given as (sources, destinations).
    order = [
        ([3, 4, 2 + 5, 0 + 5], [1, 1, 4 + 5, 4 + 5]),  # step 1
        ([1, 2, 4 + 5, 3 + 5], [0, 0, 1 + 5, 1 + 5]),  # step 2
    ]

    bg.prop_edges(order)
    t1, t2 = dgl.unbatch(bg)

    assert F.asnumpy(t1.ndata['h'][0]) == 9
    assert F.asnumpy(t2.ndata['h'][1]) == 5
def test_batch_propagate(idtype):
    """Propagate along two edge frontiers of a batch of two trees.

    The message and reduce functions are passed directly to ``prop_edges``:
    messages carry the source node's ``'h'`` and are summed at the
    destination. Node ids of the second tree appear with an added 5 in the
    batched graph.
    """
    t1 = tree1(idtype)
    t2 = tree2(idtype)

    bg = dgl.batch([t1, t2])

    def _mfunc(edges):
        return {'m': edges.src['h']}

    def _rfunc(nodes):
        return {'h': F.sum(nodes.mailbox['m'], 1)}

    # Each entry is one propagation step given as (sources, destinations).
    first_step = ([3, 4, 2 + 5, 0 + 5], [1, 1, 4 + 5, 4 + 5])
    second_step = ([1, 2, 4 + 5, 3 + 5], [0, 0, 1 + 5, 1 + 5])
    order = [first_step, second_step]

    bg.prop_edges(order, _mfunc, _rfunc)
    t1, t2 = dgl.unbatch(bg)

    assert F.asnumpy(t1.ndata['h'][0]) == 9
    assert F.asnumpy(t2.ndata['h'][1]) == 5
Exemple #14
0
    def run_episode(self):
        """Run a neighbour-majority baseline on every graph of the environment.

        After resetting the environment, each member graph starts from
        actions obtained by thresholding the signals at 0.5. For the
        remaining ``self.T - 1`` steps every node adopts 1 when the average
        previous action of its neighbours exceeds 0.5, else 0. The accuracy
        against ``self.env.world`` is recorded at every step and, when
        ``self.normalize`` is set, divided by a normal-CDF benchmark.

        :return: ``(mean, standard error)`` of the accuracy trajectories
            across graphs, both taken along axis 0.
        """
        _, _, _ = self.env.reset()
        trajectories = []
        for graph in dgl.unbatch(self.env.graph):
            adjacency = nx.to_numpy_array(graph.to_networkx())
            signals = np.transpose(self.env.signals)
            action = np.zeros(signals.shape)
            action[signals > 0.5] = 1

            accuracy = [(action.T == self.env.world).mean()]

            for _ in range(self.T - 1):
                previous = deepcopy(action)
                # Row-normalised adjacency product = mean neighbour action.
                degree = adjacency.sum(axis=1)[:, None]
                neighbor_average = adjacency.dot(previous) / degree
                action = np.zeros(previous.shape)
                action[neighbor_average > 0.5] = 1
                accuracy.append((action.T == self.env.world).mean())

            spread = np.sqrt(self.env.var / graph.number_of_nodes())
            benchmark = norm.cdf(0.5, loc=0, scale=spread)
            trajectories.append(
                accuracy / benchmark if self.normalize else accuracy)

        trajectories = np.array(trajectories)
        mean_accuracy = np.mean(trajectories, axis=0)
        standard_error = np.std(trajectories, axis=0) / np.sqrt(
            trajectories.shape[0])
        return mean_accuracy, standard_error
    def predict_ppo(self, non_fixed_variables, last_visited, temperature):
        """
        Compute the PPO action probabilities for one node of the CP search.
        :param non_fixed_variables: variables that are not fixed yet (i.e., must_visit)
        :param last_visited: the last city visited
        :param temperature: softmax temperature, used to favour exploration
        :return: flat numpy vector with the probability of selecting each action
        """

        self.update_graph_state(non_fixed_variables, last_visited)
        prediction = self.model(self.input_graph, graph_pooling=False)

        first_graph = dgl.unbatch(prediction)[0]
        scores = first_graph.ndata["n_feat"].squeeze(-1)

        # Mask with ones at the positions of the non-fixed variables.
        available_tensor = torch.zeros([self.n_city])
        available_tensor[non_fixed_variables] = 1

        # Shift by |min|, then subtract the largest masked score.
        scores = scores + torch.abs(torch.min(scores))
        scores = scores - torch.max(scores * available_tensor)

        probabilities = ActorCritic.masked_softmax(scores,
                                                   available_tensor,
                                                   dim=0,
                                                   temperature=temperature)
        return probabilities.data.cpu().numpy().flatten()
    def train(self, x, y):
        """
        Run one optimisation step on the smooth-L1 loss between f(x) and y
        :param x: the input; an iterable of pairs whose first items are graphs
        :param y: the true value of y
        :return: the loss value as a Python float
        """

        self.model.train()

        graph, _ = zip(*x)
        prediction = self.model(dgl.batch(graph), graph_pooling=False)
        node_features = [
            member.ndata["n_feat"] for member in dgl.unbatch(prediction)
        ]
        y_pred = torch.stack(node_features).squeeze(dim=2)

        y_tensor = torch.FloatTensor(np.array(y))
        if self.args.mode == 'gpu':
            y_tensor = y_tensor.contiguous().cuda()

        loss = F.smooth_l1_loss(y_pred, y_tensor)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()
Exemple #17
0
    def forward(self, data, efeat):
        """Run the configured GNN and gather the outputs of selected nodes.

        :param data: node features passed to ``self.gnn_object``.
        :param efeat: edge types for ``"rgcn"`` or edge features for
            ``"mpnn"``; not used for ``"gat"``.
        :return: the selected rows reshaped to
            ``(1 + len(self.central_grid_nodes)) * feature_width`` columns.
        :raises ValueError: if ``self.gnn_type`` is not one of ``"rgcn"``,
            ``"gat"`` or ``"mpnn"``.
        """
        if self.gnn_type in ("rgcn", "mpnn"):
            # rgcn: efeat holds edge types; mpnn: efeat holds edge features.
            x = self.gnn_object(data, efeat)
        elif self.gnn_type == "gat":
            x = self.gnn_object(data)
        else:
            # Without this branch ``x`` would be unbound further down.
            raise ValueError(
                "unsupported gnn_type: {!r}".format(self.gnn_type))

        if not self.robot_node_indexes:
            # One index per member graph of self.g: the graph's first node
            # in the batch plus ``self.grid_nodes``.
            indexes = []
            n_nodes = 0
            for g in dgl.unbatch(self.g):
                indexes.append(n_nodes + self.grid_nodes)
                n_nodes += g.number_of_nodes()
        else:
            indexes = self.robot_node_indexes

        logits = torch.squeeze(x, 1).to(device=data.device)
        output = logits[indexes].to(device=data.device)
        outputS = output.shape
        nfeats = (1 + len(self.central_grid_nodes)) * outputS[1]
        # Regroup the selected rows into rows of ``nfeats`` values.
        newShape = [(outputS[0] * outputS[1]) // nfeats, nfeats]
        output = output.view(newShape)
        return output
Exemple #18
0
    def forward(self, cand_batch, mol_tree_batch):
        """Encode candidate graphs into one vector each.

        The candidates are converted with ``mol2dgl``, batched, processed by
        ``self.run`` together with their line graph, and unbatched again.
        Each graph is represented by the mean of its node feature ``'h'``.
        The running totals of samples, nodes, edges and passes on ``self``
        are updated.

        :return: the per-graph mean vectors stacked along dimension 0.
        """
        (cand_graphs, tree_mess_src_edges,
         tree_mess_tgt_edges, tree_mess_tgt_nodes) = mol2dgl(cand_batch,
                                                             mol_tree_batch)

        n_samples = len(cand_graphs)

        batched = batch(cand_graphs)
        cand_line_graph = line_graph(batched, no_backtracking=True)

        n_nodes = len(batched.nodes)
        n_edges = len(batched.edges)

        batched = self.run(batched, cand_line_graph,
                           tree_mess_src_edges, tree_mess_tgt_edges,
                           tree_mess_tgt_nodes, mol_tree_batch)

        readouts = [
            member.get_n_repr()['h'].mean(0) for member in unbatch(batched)
        ]
        g_repr = torch.stack(readouts, 0)

        # Bookkeeping counters.
        self.n_samples_total += n_samples
        self.n_nodes_total += n_nodes
        self.n_edges_total += n_edges
        self.n_passes += 1

        return g_repr
Exemple #19
0
 def extractS(batchTree):
     """Pop ``'s'`` from every tree of the batch and stack its first rows.

     Shapes: [dmodel] --> [1, dmodel] per tree --> [tree, dmodel]
     --> [tree, 1, dmodel].
     """
     first_rows = []
     for tree in dgl.unbatch(batchTree):
         state = tree.ndata.pop('s')
         first_rows.append(state[0].unsqueeze(0))
     stacked = th.cat(first_rows, dim=0)
     return stacked.unsqueeze(1)
Exemple #20
0
    def forward(self, g, features, etypes):
        """Compute dueling Q-values from the first node of every graph.

        The graph is stored on ``self.g``, its ``'norm'`` edge data is moved
        to the device of ``features``, and the GNN layers are applied in
        turn. For each member graph the output row of its first node is
        collected and passed to the value and advantage heads.

        :return: ``value + advantage - mean(advantage)``, with the mean taken
            over dimension 1.
        """
        self.g = g

        h = features
        self.g.edata['norm'] = self.g.edata['norm'].to(device=features.device)

        for layer in self.gnn_layers:
            h = layer(self.g, h, etypes)

        unbatched = dgl.unbatch(self.g)
        gnn_output = torch.Tensor(size=(len(unbatched), self.gnn_output)).to(
            device=features.device)

        first_node = 0
        for row, member in enumerate(unbatched):
            # Only the first node of each member graph is read out.
            gnn_output[row, :] = h[first_node, :]
            first_node += member.number_of_nodes()

        value = self.value_layers(gnn_output)
        advantage = self.advantage_layers(gnn_output)

        Q = value + advantage - torch.mean(advantage, dim=1, keepdim=True)
        return Q
Exemple #21
0
 def decode(self,
            *,
            sample: Dict[str, Any],
            prefix: str = "",
            metadata: Optional[Dict[str, Any]] = None) -> None:
     """Print the ranked predictions for every graph of a batched sample.

     The flat index array of the sample is split into one ``[start, end)``
     range per graph by bisecting on the cumulative node count. Within each
     range the offsets are ordered by descending value of column 1 of
     ``sample["forward"]``, incremented by one and printed after ``prefix``
     and the matching entry of ``sample["metadata"]``. When ``metadata``
     contains a ``"metadata"`` entry, the formatted value ``2 ** score`` of
     every offset is stored there under the same entry.
     """
     graphs = unbatch(sample["typed_dgl_graph"].graph)
     numpy_indexes = sample["indexes"].indexes.cpu().numpy()

     bounds = []
     range_start = 0
     nodes_so_far = 0
     for graph in graphs:
         nodes_so_far += graph.number_of_nodes()
         # First position whose index lies beyond this graph's last node.
         range_end = bisect_right(numpy_indexes, nodes_so_far - 1)
         bounds.append((range_start, range_end))
         range_start = range_end

     for (lo, hi), path in zip(bounds, sample["metadata"]):
         path_probas = sample["forward"][lo:hi, 1]
         path_indexes = sample["indexes"].offsets[lo:hi]
         ranking = path_probas.argsort(descending=True)
         predictions = path_indexes[ranking]
         if metadata is not None and "metadata" in metadata:
             per_index = {}
             for index, proba in zip(path_indexes.tolist(),
                                     path_probas.tolist()):
                 per_index[index] = ["%.8f" % (2**proba)]
             metadata["metadata"][path] = per_index
         predictions += 1
         rendered = " ".join(map(str, predictions.numpy()))
         print("%s%s %s" % (prefix, path, rendered))
Exemple #22
0
def translate_gt_graph_to_adj(gt_graph):
    """Split every graph of a batch into one adjacency matrix per edge value.

    For each member graph and each distinct value of its ``'feat'`` edge
    data, an ``(n_node, n_node)`` matrix is built that holds 1 at
    ``[src, dst]`` for edges carrying that value and 0 for the other edges.

    :param gt_graph: batched graph with ``'feat'`` edge data.
    :return: list with one ``(matrices, values)`` tuple per member graph,
        where ``matrices[i]`` belongs to ``values[i]``.
    """
    gt_adjs = []
    for gt_g in dgl.unbatch(gt_graph):
        n_node = gt_g.number_of_nodes()
        srt, dst = (ends.detach().cpu().numpy() for ends in gt_g.edges())

        edge_factor = gt_g.edata['feat'].detach().cpu().numpy()
        assert srt.shape[0] == edge_factor.shape[0]

        matrices = []
        values = []
        for edge_id in set(edge_factor):
            # Indicator over edges: 1 where the edge carries this value.
            indicator = np.zeros_like(edge_factor)
            indicator[np.where(edge_factor == edge_id)[0]] = 1.0

            adjacency = np.zeros((n_node, n_node))
            adjacency[srt, dst] = indicator

            matrices.append(adjacency)
            values.append(edge_id)

        gt_adjs.append((matrices, values))

    return gt_adjs
    def evaluate(self, state_for_action, state_for_value, action, available_tensor):
        """
        Evaluate an action with respect to the current policy
        :param state_for_action: state used to compute the actor output
        :param state_for_value: state used to compute the critic output; it
        holds the same state as state_for_action but is a different object
        :param action: the action that is evaluated
        :param available_tensor: the actions that are possible
        :return: the log-probabilities of the action, the critic evaluation of the state, the entropy value
        """

        if self.args.mode == "gpu":
            available_tensor = available_tensor.cuda()

        actor_out = self.action_layer(state_for_action, graph_pooling=False)
        node_scores = [
            member.ndata["n_feat"] for member in dgl.unbatch(actor_out)
        ]
        scores = torch.stack(node_scores).squeeze(-1)

        # Per row: shift by |min|, then subtract the largest masked score.
        row_min = torch.min(scores, 1, keepdim=True)[0]
        scores = scores + torch.abs(row_min)
        masked_max = torch.max(scores * available_tensor, 1, keepdim=True)[0]
        scores = scores - masked_max
        action_probs = self.masked_softmax(scores, available_tensor, dim=1)

        dist = Categorical(action_probs)
        action_log_probs = dist.log_prob(action)
        dist_entropy = dist.entropy()

        state_value = self.value_layer(state_for_value, graph_pooling=True)
        return action_log_probs, torch.squeeze(state_value), dist_entropy
Exemple #24
0
    def forward(self, graph):
        """Embed, encode and pool a batched graph.

        The ``'input_ids'``, ``'position_ids'`` and ``'segment_ids'`` node
        data are embedded and then removed from the graph. The flattened
        embeddings become the node feature ``'h'`` consumed by
        ``self.encoder``. The ``'h'`` row of the first node of every member
        graph is stacked and passed through ``self.pooler``.

        :return: ``(encoded graph, pooled output)``.
        """
        embedding_output = self.embeddings(graph.ndata['input_ids'],
                                           graph.ndata['position_ids'],
                                           graph.ndata['segment_ids'])

        # The raw id features are dropped once they have been embedded.
        for consumed in ('input_ids', 'position_ids', 'segment_ids'):
            graph.ndata.pop(consumed)

        hidden_size = embedding_output.size(-1)
        graph.ndata['h'] = embedding_output.view(-1, hidden_size)

        graph = self.encoder(graph)

        first_node_states = [
            member.ndata['h'][0] for member in dgl.unbatch(graph)
        ]
        pooled_output = self.pooler(torch.stack(first_node_states, 0))
        return graph, pooled_output
Exemple #25
0
def predict(net, loader, dataset, batch_size=50, naf_obj=None, progbar=None):
    """Label every sentence produced by ``loader`` with frames and roles.

    For each batched graph, ``net.label`` supplies per-node frame and role
    labels with their scores. They are written onto the tokens of each
    sentence (``ROLE``, ``pROLE``, ``FRAME``, ``pFRAME``), frames are
    assembled with ``make_frames`` and, when ``naf_obj`` is given, written
    to it. Nothing is returned.

    :param net: model providing ``eval()`` and ``label()``.
    :param loader: iterable of batched graphs.
    :param dataset: provides ``conllu(graph)`` returning the sentence tokens.
    :param batch_size: step reported to ``progbar`` after each batch.
    :param naf_obj: optional target for ``write_frames_to_naf``.
    :param progbar: optional progress bar with ``next()`` and ``finish()``.
    """
    net.eval()
    with torch.no_grad():
        for gs in loader:
            frame_labels, role_labels, frame_chance, role_chance = \
                net.label(gs)

            # Labels are indexed over all nodes of the batch.
            node_offset = 0
            for g in dgl.unbatch(gs):
                sentence = dataset.conllu(g)
                for i, token in enumerate(sentence):
                    position = i + node_offset
                    token.ROLE = role_labels[position]
                    token.pROLE = role_chance[position]

                    token.FRAME = frame_labels[position]
                    token.pFRAME = frame_chance[position]
                node_offset += len(g)

                # match the predicate and roles by some simple graph
                # traversal rules
                frames, _orphans = make_frames(sentence)

                if naf_obj:
                    write_frames_to_naf(naf_obj, frames, sentence)

            if progbar:
                progbar.next(batch_size)

    if progbar:
        progbar.finish()
Exemple #26
0
def getMaskForBatch(subgraph):
    """Return the batch-wide index of the first node of every member graph.

    :param subgraph: batched graph.
    :return: list of cumulative node offsets, starting at 0.
    """
    node_counts = [g.number_of_nodes() for g in dgl.unbatch(subgraph)]
    indexes = []
    offset = 0
    for count in node_counts:
        indexes.append(offset)
        offset += count
    return indexes
    def act(self, graph_state, available_tensor):
        """
        Sample an action from the probabilities output by the current actor
        :param graph_state: the current state
        :param available_tensor: [0,1]-vector of available actions
        :return: the selected action, its log-probability, and the action probabilities
        """

        if self.args.mode == "gpu":
            available_tensor = available_tensor.cuda()

        batched_graph = dgl.batch([graph_state])

        self.action_layer.eval()
        with torch.no_grad():
            actor_out = self.action_layer(batched_graph, graph_pooling=False)
            first_graph = dgl.unbatch(actor_out)[0]
            scores = first_graph.ndata["n_feat"].squeeze(-1)

            # Post-process for numerically stable masked probabilities:
            # shift by |min|, then subtract the largest masked score.
            scores = scores + torch.abs(torch.min(scores))
            scores = scores - torch.max(scores * available_tensor)
            action_probs = self.masked_softmax(scores, available_tensor, dim=0)

            dist = Categorical(action_probs)
            action = dist.sample()

        return action, dist.log_prob(action), action_probs
Exemple #28
0
    def clone(self):
        """Return a copy of this mini-batch.

        ``X`` may be a list of tensors, a single tensor or a ``dgl.DGLGraph``.
        Tensors are detached and cloned; a graph is rebuilt by unbatching and
        batching it again. For list and tensor inputs the masks are detached
        clones of ``self.masks``; for a graph they are read from the node
        data of the rebuilt graph. ``Y`` is detached and cloned, ``lengths``
        and ``ids`` are shallow-copied and ``data`` is deep-copied.

        :return: a new ``MiniBatch``.
        :raises AssertionError: if ``self.P`` is set or ``self.X`` has an
            unsupported type.
        """
        assert self.P is None, "clone not implemented for field P."
        if isinstance(self.X, (list, torch.Tensor)):
            if isinstance(self.X, list):
                X = [x.detach().clone() for x in self.X]
            else:
                X = self.X.detach().clone()
            # The masks are cloned for tensor inputs too; previously ``M``
            # was left unassigned in that case.
            M = {
                key: value.detach().clone()
                for key, value in self.masks.items()
            }
        elif isinstance(self.X, dgl.DGLGraph):
            X = dgl.batch(dgl.unbatch(self.X))
            M = {mask: X.ndata[mask] for mask in self.masks.keys()}
        else:
            # Raised explicitly so the check also holds under ``python -O``.
            raise AssertionError(
                "unhandled type to clone: {}".format(type(self.X)))

        return MiniBatch(
            X,
            self.Y.detach().clone(),  # tensor
            copy.copy(self.lengths),  # list
            M,  # dict of tensors
            None,  # P: cloning it is not implemented (see assert above)
            copy.deepcopy(self.data)
            if self.data is not None else None,  # deep copy, or None
            copy.copy(self.ids) if self.ids is not None else None,  # list
        )
Exemple #29
0
def getMaskForBatch(subgraph):
    """List, for each graph in the batch, the index of its first node.

    :param subgraph: batched graph.
    :return: list of starting node indexes within the batched graph.
    """
    starts = []
    nodes_seen = 0
    for count in (g.number_of_nodes() for g in dgl.unbatch(subgraph)):
        starts.append(nodes_seen)
        nodes_seen += count
    return starts
    def test_all(self, dataset: AllDataset, output_dir: str = "test_result"):
        """Run the model on the test split and dump the results as CSV files.

        The model is loaded and switched to eval mode, every batch of
        ``dataset.test`` is passed through ``self.forward``, and each data
        frame returned by ``graph_and_info_to_df_list`` is written to
        ``<output_dir>/<k>.csv`` with ``k`` counting up from 1. The model is
        switched back to train mode afterwards and the elapsed time printed.

        :param dataset: dataset whose ``test`` split is evaluated.
        :param output_dir: directory for the CSV files; created if missing.
        """
        if os.path.exists(output_dir):
            print(f'output dir {os.path.abspath(output_dir)} exists !')
        else:
            os.makedirs(output_dir)
            print(
                f"make new dir {os.path.abspath(output_dir)}, and write files into it."
            )
        self.load()
        self.eval()
        data_loader = GraphDataLoader(dataset.test,
                                      collate_fn=collate,
                                      batch_size=10,
                                      shuffle=False,
                                      drop_last=False)
        start_time = time.time()
        file_name_index = 1
        for bhg, info in data_loader:
            self.forward(bhg)
            for cg, cd in zip(dgl.unbatch(bhg), info):
                for df in graph_and_info_to_df_list(cg, cd):
                    target = os.path.join(output_dir,
                                          str(file_name_index) + ".csv")
                    df.to_csv(target, index=False)
                    file_name_index += 1

        self.train()
        print(
            f"test time is :{time.time() - start_time:6.2f} s | num_samples : {len(dataset.test)}"
        )