Exemple #1
0
def test_simple_readout():
    """Check DGL sum/mean/max readouts on a single graph and on a batch.

    Two graphs are built: a 3-node directed cycle carrying node and edge
    features, and a 4-node graph with no edges.  Expected values are
    computed directly with the backend ``F`` and compared against the
    ``dgl`` readout functions, first for the cycle alone and then for the
    batch of both graphs.
    """
    cycle = dgl.DGLGraph()
    cycle.add_nodes(3)
    edgeless = dgl.DGLGraph()
    edgeless.add_nodes(4)  # no edges
    cycle.add_edges([0, 1, 2], [2, 0, 1])

    # Random features; the draw order matches node, node, edge, weight, weight.
    node_x1 = F.randn((3, 5))
    node_x2 = F.randn((4, 5))
    edge_x1 = F.randn((3, 5))
    weight1 = F.randn((3, ))
    weight2 = F.randn((4, ))

    # Plain reductions over the first axis.
    node_sum1 = F.sum(node_x1, 0)
    node_sum2 = F.sum(node_x2, 0)
    edge_sum1 = F.sum(edge_x1, 0)
    node_mean1 = F.mean(node_x1, 0)
    node_mean2 = F.mean(node_x2, 0)
    edge_mean1 = F.mean(edge_x1, 0)
    node_max1 = F.max(node_x1, 0)
    node_max2 = F.max(node_x2, 0)
    edge_max1 = F.max(edge_x1, 0)

    # Weighted reductions: each node row is scaled by its scalar weight.
    col_w1 = F.unsqueeze(weight1, 1)
    col_w2 = F.unsqueeze(weight2, 1)
    weighted_sum1 = F.sum(node_x1 * col_w1, 0)
    weighted_sum2 = F.sum(node_x2 * col_w2, 0)
    weighted_mean1 = weighted_sum1 / F.sum(col_w1, 0)
    weighted_mean2 = weighted_sum2 / F.sum(col_w2, 0)

    cycle.ndata['x'] = node_x1
    edgeless.ndata['x'] = node_x2
    cycle.ndata['w'] = weight1
    edgeless.ndata['w'] = weight2
    cycle.edata['x'] = edge_x1

    # Single-graph readouts.
    assert F.allclose(dgl.sum_nodes(cycle, 'x'), node_sum1)
    assert F.allclose(dgl.sum_nodes(cycle, 'x', 'w'), weighted_sum1)
    assert F.allclose(dgl.sum_edges(cycle, 'x'), edge_sum1)
    assert F.allclose(dgl.mean_nodes(cycle, 'x'), node_mean1)
    assert F.allclose(dgl.mean_nodes(cycle, 'x', 'w'), weighted_mean1)
    assert F.allclose(dgl.mean_edges(cycle, 'x'), edge_mean1)
    assert F.allclose(dgl.max_nodes(cycle, 'x'), node_max1)
    assert F.allclose(dgl.max_edges(cycle, 'x'), edge_max1)

    # Batched readouts: one row per graph in the batch.
    batched = dgl.batch([cycle, edgeless])
    batch_node_sum = dgl.sum_nodes(batched, 'x')
    batch_node_mean = dgl.mean_nodes(batched, 'x')
    batch_node_max = dgl.max_nodes(batched, 'x')
    assert F.allclose(batch_node_sum, F.stack([node_sum1, node_sum2], 0))
    assert F.allclose(batch_node_mean, F.stack([node_mean1, node_mean2], 0))
    assert F.allclose(batch_node_max, F.stack([node_max1, node_max2], 0))

    batch_weighted_sum = dgl.sum_nodes(batched, 'x', 'w')
    batch_weighted_mean = dgl.mean_nodes(batched, 'x', 'w')
    assert F.allclose(batch_weighted_sum,
                      F.stack([weighted_sum1, weighted_sum2], 0))
    assert F.allclose(batch_weighted_mean,
                      F.stack([weighted_mean1, weighted_mean2], 0))

    # The edgeless graph is expected to contribute an all-zero row.
    batch_edge_sum = dgl.sum_edges(batched, 'x')
    batch_edge_mean = dgl.mean_edges(batched, 'x')
    batch_edge_max = dgl.max_edges(batched, 'x')
    no_edge_row = F.zeros(5)
    assert F.allclose(batch_edge_sum, F.stack([edge_sum1, no_edge_row], 0))
    assert F.allclose(batch_edge_mean, F.stack([edge_mean1, no_edge_row], 0))
    assert F.allclose(batch_edge_max, F.stack([edge_max1, no_edge_row], 0))
Exemple #2
0
def test_simple_readout():
    """Check DGL sum/mean node and edge readouts against plain torch math.

    Builds a 3-node directed cycle with node and edge features and a
    4-node graph without edges, then compares ``dgl`` readouts with
    reductions computed directly on the tensors, both per graph and for
    the batch of the two graphs.
    """
    cycle = dgl.DGLGraph()
    cycle.add_nodes(3)
    edgeless = dgl.DGLGraph()
    edgeless.add_nodes(4)  # no edges
    cycle.add_edges([0, 1, 2], [2, 0, 1])

    # Random features; the draw order matches node, node, edge, weight, weight.
    node_x1 = th.randn(3, 5)
    node_x2 = th.randn(4, 5)
    edge_x1 = th.randn(3, 5)
    weight1 = th.randn(3)
    weight2 = th.randn(4)

    # Unweighted sums and means over the first axis.
    node_sum1 = th.sum(node_x1, 0)
    node_sum2 = th.sum(node_x2, 0)
    edge_sum1 = th.sum(edge_x1, 0)
    node_mean1 = th.mean(node_x1, 0)
    node_mean2 = th.mean(node_x2, 0)
    edge_mean1 = th.mean(edge_x1, 0)

    # Weighted sums and means: each node row is scaled by its scalar weight.
    col_w1 = weight1.unsqueeze(1)
    col_w2 = weight2.unsqueeze(1)
    weighted_sum1 = (node_x1 * col_w1).sum(0)
    weighted_sum2 = (node_x2 * col_w2).sum(0)
    weighted_mean1 = weighted_sum1 / col_w1.sum(0)
    weighted_mean2 = weighted_sum2 / col_w2.sum(0)

    cycle.ndata['x'] = node_x1
    edgeless.ndata['x'] = node_x2
    cycle.ndata['w'] = weight1
    edgeless.ndata['w'] = weight2
    cycle.edata['x'] = edge_x1

    # Single-graph readouts.
    assert U.allclose(dgl.sum_nodes(cycle, 'x'), node_sum1)
    assert U.allclose(dgl.sum_nodes(cycle, 'x', 'w'), weighted_sum1)
    assert U.allclose(dgl.sum_edges(cycle, 'x'), edge_sum1)
    assert U.allclose(dgl.mean_nodes(cycle, 'x'), node_mean1)
    assert U.allclose(dgl.mean_nodes(cycle, 'x', 'w'), weighted_mean1)
    assert U.allclose(dgl.mean_edges(cycle, 'x'), edge_mean1)

    # Batched readouts: one row per graph in the batch.
    batched = dgl.batch([cycle, edgeless])
    batch_node_sum = dgl.sum_nodes(batched, 'x')
    batch_node_mean = dgl.mean_nodes(batched, 'x')
    assert U.allclose(batch_node_sum, th.stack([node_sum1, node_sum2], 0))
    assert U.allclose(batch_node_mean, th.stack([node_mean1, node_mean2], 0))

    batch_weighted_sum = dgl.sum_nodes(batched, 'x', 'w')
    batch_weighted_mean = dgl.mean_nodes(batched, 'x', 'w')
    assert U.allclose(batch_weighted_sum,
                      th.stack([weighted_sum1, weighted_sum2], 0))
    assert U.allclose(batch_weighted_mean,
                      th.stack([weighted_mean1, weighted_mean2], 0))

    # The edgeless graph is expected to contribute an all-zero row.
    batch_edge_sum = dgl.sum_edges(batched, 'x')
    batch_edge_mean = dgl.mean_edges(batched, 'x')
    no_edge_row = th.zeros(5)
    assert U.allclose(batch_edge_sum, th.stack([edge_sum1, no_edge_row], 0))
    assert U.allclose(batch_edge_mean, th.stack([edge_mean1, no_edge_row], 0))
Exemple #3
0
    def forward(self, graph, edge_feat, node_feat, g_repr, edge_hidden, node_hidden, graph_hidden):
        """Run one edge/node/graph update pass with hidden state on ``graph``.

        The inputs are written onto ``graph`` as temporary edge/node
        features, the edge and node update callbacks are applied, the
        summed features are passed to ``compute_u_repr``, and finally every
        edge and node feature is removed from ``graph`` again.

        Args:
            graph: DGL graph (or batched graph) to operate on.
            edge_feat: Stored as the ``'edge_feat'`` edge feature.
            node_feat: Stored as the ``'node_feat'`` node feature.
            g_repr: Graph-level representation forwarded to the callbacks.
            edge_hidden: ``edge_hidden[0][0]`` and ``edge_hidden[1][0]`` are
                stored as ``'hidden1'`` / ``'hidden2'`` edge features
                (presumably recurrent (h, c) states -- TODO confirm).
            node_hidden: Same layout as ``edge_hidden``, for nodes.
            graph_hidden: Forwarded unchanged to ``compute_u_repr``.

        Returns:
            Tuple ``(e_feat, h_e, n_feat, h_n, u_out, u_hidden)`` where
            ``h_e`` / ``h_n`` are pairs of the hidden features with a new
            leading dimension added.
        """
        def update_nodes(nodes):
            return self.compute_node_repr(nodes=nodes, graph=graph, g_repr=g_repr)

        def update_edges(edges):
            return self.compute_edge_repr(edges=edges, graph=graph, g_repr=g_repr)

        # Stage the inputs on the graph so the DGL callbacks can read them.
        graph.edata['edge_feat'] = edge_feat
        graph.edata['hidden1'] = edge_hidden[0][0]
        graph.edata['hidden2'] = edge_hidden[1][0]
        graph.ndata['node_feat'] = node_feat
        graph.ndata['hidden1'] = node_hidden[0][0]
        graph.ndata['hidden2'] = node_hidden[1][0]

        graph.apply_edges(update_edges)
        graph.update_all(self.graph_message_func, self.graph_reduce_func, update_nodes)

        # Graph-level update from the summed edge and node features.
        edge_total = dgl.sum_edges(graph, 'edge_feat')
        node_total = dgl.sum_nodes(graph, 'node_feat')
        u_out, u_hidden = self.compute_u_repr(node_total, edge_total, g_repr, graph_hidden)

        # Grab the outputs before the graph is wiped below.
        e_feat = graph.edata['edge_feat']
        n_feat = graph.ndata['node_feat']
        h_e = (graph.edata['hidden1'].unsqueeze(0),
               graph.edata['hidden2'].unsqueeze(0))
        h_n = (graph.ndata['hidden1'].unsqueeze(0),
               graph.ndata['hidden2'].unsqueeze(0))

        # Remove every feature so the graph carries no state between calls.
        for frame in (graph.edata, graph.ndata):
            for key in list(frame.keys()):
                frame.pop(key)

        return e_feat, h_e, n_feat, h_n, u_out, u_hidden
Exemple #4
0
    def forward(self, graph, edge_feat, node_feat, g_repr):
        """Run one edge/node/graph update pass on ``graph``.

        The inputs are written onto ``graph`` as temporary features, the
        edge and node update callbacks are applied, and every edge and node
        feature is removed from ``graph`` before the graph-level update is
        computed from the summed features.

        Args:
            graph: DGL graph (or batched graph) to operate on.
            edge_feat: Stored as the ``'edge_feat'`` edge feature.
            node_feat: Stored as the ``'node_feat'`` node feature.
            g_repr: Graph-level representation forwarded to the callbacks
                and to ``compute_u_repr``.

        Returns:
            Tuple ``(e_out, n_out, u_out)``: the updated edge features, the
            updated node features and the result of ``compute_u_repr``.
        """
        def update_nodes(nodes):
            return self.compute_node_repr(nodes=nodes, graph=graph, g_repr=g_repr)

        def update_edges(edges):
            return self.compute_edge_repr(edges=edges, graph=graph, g_repr=g_repr)

        # Stage the inputs on the graph so the DGL callbacks can read them.
        graph.edata['edge_feat'] = edge_feat
        graph.ndata['node_feat'] = node_feat

        graph.apply_edges(update_edges)
        graph.update_all(self.graph_message_func, self.graph_reduce_func,
                         update_nodes)

        edge_total = dgl.sum_edges(graph, 'edge_feat')
        node_total = dgl.sum_nodes(graph, 'node_feat')

        # Grab the outputs before the graph is wiped below.
        e_out = graph.edata['edge_feat']
        n_out = graph.ndata['node_feat']

        # Remove every feature so the graph carries no state between calls.
        for frame in (graph.edata, graph.ndata):
            for key in list(frame.keys()):
                frame.pop(key)

        # The graph-level update runs after the cleanup, on the saved sums.
        u_out = self.compute_u_repr(node_total, edge_total, g_repr)
        return e_out, n_out, u_out
    def forward(self, g, x, e, snorm_n, snorm_e):
        """Run the edge message-passing layers and map the pooled edge state.

        Args:
            g: DGL graph (or batched graph).
            x: Node features; indexed by the source/destination node ids of
                every edge to build the per-edge inputs.
            e: Edge features; passed to the layers only when ``self.edge_f``
                is truthy.
            snorm_n: Not used by this method.
            snorm_e: Forwarded to every layer as ``snorm_e``.

        Returns:
            ``self.MLP_layer`` applied to the edge readout selected by
            ``self.readout`` (``"sum"``, ``"max"``, otherwise mean).
        """
        # Allocate the initial edge state directly on the target device
        # instead of building it on the CPU and copying it over.
        h = torch.zeros(
            [g.number_of_edges(), self.h_dim],
            dtype=torch.float32,
            device=self.device,
        )
        src, dst = g.all_edges()

        # These inputs are identical for every layer, so gather them once
        # rather than re-indexing ``x`` on each iteration.
        layer_inputs = {'src_feat': x[src], 'snorm_e': snorm_e}
        if self.dst_f:
            layer_inputs['dst_feat'] = x[dst]
        if self.edge_f:
            layer_inputs['e_feat'] = e

        for mpnn in self.layers:
            h = mpnn(g, h_feat=h, **layer_inputs)

        g.edata['h'] = h

        if self.readout == "sum":
            hg = dgl.sum_edges(g, 'h')
        elif self.readout == "max":
            hg = dgl.max_edges(g, 'h')
        else:
            # "mean" and any unrecognised value both use the mean over edges.
            hg = dgl.mean_edges(g, 'h')

        return self.MLP_layer(hg)
Exemple #6
0
    def forward(self, g, x, e, snorm_n, snorm_e):
        """Run paired edge/node layers and map the pooled graph state.

        Args:
            g: DGL graph (or batched graph).
            x: Node features; indexed by the source/destination node ids of
                every edge, and also passed whole to the node layers as
                ``n_feat``.
            e: Edge features; passed to the layers only when ``self.edge_f``
                is truthy.
            snorm_n: Not used by this method (the original note says it is
                used for batching).
            snorm_e: Forwarded to every layer as ``snorm_e``.

        Returns:
            ``self.Global_MLP_layer`` applied to the edge readout and node
            readout concatenated along dimension 1.
        """
        # Allocate the initial states directly on the target device instead
        # of building them on the CPU and copying them over.
        h_node = torch.zeros(
            [g.number_of_nodes(), self.node_in_dim],
            dtype=torch.float32,
            device=self.device,
        )
        h_edge = torch.zeros(
            [g.number_of_edges(), self.h_dim],
            dtype=torch.float32,
            device=self.device,
        )
        src, dst = g.all_edges()

        # These inputs are identical for every layer, so gather them once
        # rather than re-indexing ``x`` on each iteration.
        shared_inputs = {'src_feat': x[src], 'snorm_e': snorm_e}
        if self.dst_f:
            shared_inputs['dst_feat'] = x[dst]
        if self.edge_f:
            shared_inputs['e_feat'] = e

        for edge_layer, node_layer in zip(self.edge_layers, self.node_layers):
            h_edge = edge_layer(g, h_feat=h_edge, **shared_inputs)
            h_node = node_layer(g, h_feat=h_node, n_feat=x, **shared_inputs)

        g.edata['h'] = h_edge
        # NOTE(review): when ``self.node_update`` is falsy the node readout
        # below uses whatever 'h' node feature ``g`` already carries --
        # confirm that callers always provide one in that configuration.
        if self.node_update:
            g.ndata['h'] = h_node

        if self.readout == "sum":
            he = dgl.sum_edges(g, 'h')
            hn = dgl.sum_nodes(g, 'h')
        elif self.readout == "max":
            he = dgl.max_edges(g, 'h')
            hn = dgl.max_nodes(g, 'h')
        else:
            # "mean" and any unrecognised value both use the mean readouts.
            he = dgl.mean_edges(g, 'h')
            hn = dgl.mean_nodes(g, 'h')

        # Graph-level ("global") prediction from both readouts.
        out = self.Global_MLP_layer(torch.cat([he, hn], dim=1))

        # Per-edge ("transition") prediction.
        # NOTE(review): the result is currently not returned; the call is
        # kept so the module is still invoked exactly as before.
        edge_out = self.edge_MLPReadout(h_edge)

        return out
Exemple #7
0
    def forward(self,
                graph,
                edge_feats_u,
                node_feats_u,
                edge_feat_reflected_u,
                mode="train",
                node_probability=None,
                joint_acts=None):
        """Compute graph-level values from individual and pairwise utilities.

        The inputs are staged on ``graph`` as temporary features, the node
        and edge utility callbacks are applied, and the resulting utilities
        are appended to ``self.utils_storage``.  Every edge and node feature
        is removed from ``graph`` before returning.

        Args:
            graph: Batched DGL graph; ``graph.batch_num_nodes`` is read as a
                sequence of per-graph node counts.
            edge_feats_u: Stored as the ``'edge_feat_u'`` edge feature.
            node_feats_u: Stored as the ``'node_feat_u'`` node feature.
            edge_feat_reflected_u: Stored as ``'edge_feat_reflected_u'``.
            mode: When the string contains ``"inference"`` the
                probability-weighted branch is used; otherwise the
                ``joint_acts`` branch.
            node_probability: Stored as the ``'probs'`` node feature in the
                inference branch.
            joint_acts: Per-node indices used to gather from the summed
                utilities in the non-inference branch.

        Returns:
            In the inference branch, the sum of four weighted readouts; in
            the other branch, the per-graph sum of the gathered utilities.
        """
        def clear_graph_features():
            # Remove every feature so the graph carries no state between calls.
            for key in list(graph.edata.keys()):
                graph.edata.pop(key)
            for key in list(graph.ndata.keys()):
                graph.ndata.pop(key)

        graph.edata['edge_feat_u'] = edge_feats_u
        graph.edata['edge_feat_reflected_u'] = edge_feat_reflected_u
        graph.ndata['node_feat_u'] = node_feats_u

        n_weights = torch.zeros([node_feats_u.shape[0], 1])

        zero_indexes, offset = [0], 0
        num_nodes = graph.batch_num_nodes

        # Mark all 0-th index nodes: the first node of each graph in the
        # batch sits at the running sum of the preceding node counts.
        for a in num_nodes[:-1]:
            offset += a
            zero_indexes.append(offset)

        n_weights[zero_indexes] = 1
        graph.ndata['weights'] = n_weights
        graph.ndata['mod_weights'] = 1 - n_weights

        graph.apply_nodes(self.compute_node_data)
        graph.apply_edges(self.compute_edge_data)

        # Record this call's utilities and batch layout.
        self.utils_storage["indiv"].append(
            graph.ndata["indiv_util"].detach().numpy())
        self.utils_storage["pairs"].append(
            graph.edata["util_vals"].detach().numpy())
        self.utils_storage["batch_num_nodes"].append(graph.batch_num_nodes)
        self.utils_storage["batch_num_edges"].append(graph.batch_num_edges)

        if "inference" in mode:
            graph.ndata["probs"] = node_probability
            src, dst = graph.edges()
            src_list, dst_list = src.tolist(), dst.tolist()

            # Mark edges not connected to zero.  A set makes each endpoint
            # check O(1) instead of scanning the index list for every edge.
            zero_index_set = {int(i) for i in zero_indexes}
            e_nc_zero_weight = torch.zeros([edge_feats_u.shape[0], 1])
            all_nc_edges = [
                idx for idx, (u, v) in enumerate(zip(src_list, dst_list))
                if u not in zero_index_set and v not in zero_index_set
            ]
            e_nc_zero_weight[all_nc_edges] = 0.5
            graph.edata["nc_zero_weight"] = e_nc_zero_weight

            graph.apply_edges(self.graph_pair_inference_func)
            graph.update_all(message_func=self.graph_dst_inference_func,
                             reduce_func=self.graph_reduce_func,
                             apply_node_func=self.graph_node_inference_func)

            # The features read below are not set in this method; presumably
            # they are produced by the callbacks above -- TODO confirm.
            total_connected = dgl.sum_nodes(graph, 'util_dst', 'weights')
            total_n_connected = dgl.sum_edges(graph, 'edge_all_sum_prob',
                                              'nc_zero_weight')
            total_expected_others_util = dgl.sum_nodes(
                graph, "expected_indiv_util", "mod_weights").view(-1, 1)
            total_indiv_util_zero = dgl.sum_nodes(graph, "indiv_util",
                                                  "weights")

            returned_values = (total_connected + total_n_connected) + \
                              (total_expected_others_util + total_indiv_util_zero)

            clear_graph_features()
            return returned_values

        m_func = lambda x: self.graph_u_sum(graph, x, joint_acts)
        graph.update_all(message_func=m_func, reduce_func=self.graph_sum_all)

        indiv_u_zeros = graph.ndata['indiv_util']
        u_msg_sum_zeros = 0.5 * graph.ndata['u_msg_sum']

        # Pick, for every node, the summed utility at its joint-action index.
        graph.ndata['utils_sum_all'] = (
            indiv_u_zeros + u_msg_sum_zeros).gather(
                -1,
                torch.Tensor(joint_acts)[:, None].long())
        q_values = dgl.sum_nodes(graph, 'utils_sum_all')

        clear_graph_features()
        return q_values
Exemple #8
0
 def forward(self, graph, feat):
     """Return the sum readout of ``feat`` over the edges of ``graph``.

     ``feat`` is stored as the ``'e'`` edge feature inside
     ``graph.local_scope()`` and ``dgl.sum_edges`` is taken over it.
     """
     with graph.local_scope():
         graph.edata['e'] = feat
         return dgl.sum_edges(graph, 'e')
Exemple #9
0
 def sum_readout(g):
     """Return ``dgl.sum_edges`` of ``g`` over the feature named ``from_field``."""
     # ``from_field`` is not defined here; it is resolved from an outer scope.
     pooled = dgl.sum_edges(g, from_field)
     return pooled