    def forward(self, g, x, e, snorm_n, snorm_e):
        # h = self.embedding_h(h)
        # h = self.in_feat_dropout(h)

        # initialize per-edge hidden states to zeros
        h = torch.zeros([g.number_of_edges(), self.h_dim]).float().to(self.device)
        src, dst = g.all_edges()

        # pick the layer call signature based on which feature types the layer consumes
        for mpnn in self.layers:
            if self.edge_f:
                if self.dst_f:
                    h = mpnn(g, src_feat=x[src], dst_feat=x[dst], e_feat=e, h_feat=h, snorm_e=snorm_e)
                else:
                    h = mpnn(g, src_feat=x[src], e_feat=e, h_feat=h, snorm_e=snorm_e)
            else:
                if self.dst_f:
                    h = mpnn(g, src_feat=x[src], dst_feat=x[dst], h_feat=h, snorm_e=snorm_e)
                else:
                    h = mpnn(g, src_feat=x[src], h_feat=h, snorm_e=snorm_e)


        g.edata['h'] = h

        if self.readout == "sum":
            hg = dgl.sum_edges(g, 'h')
        elif self.readout == "max":
            hg = dgl.max_edges(g, 'h')
        elif self.readout == "mean":
            hg = dgl.mean_edges(g, 'h')
        else:
            hg = dgl.mean_edges(g, 'h')  # default readout: mean over edges

        return self.MLP_layer(hg)
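For reference, the edge-level readout pattern used above (write per-edge states into g.edata['h'], then reduce per graph) can be exercised on its own; a minimal sketch, assuming a recent DGL with the dgl.graph constructor:

import dgl
import torch

g1 = dgl.graph(([0, 1, 2], [2, 0, 1]))
g2 = dgl.graph(([0, 1], [1, 0]))
bg = dgl.batch([g1, g2])
bg.edata['h'] = torch.randn(bg.number_of_edges(), 8)

hg = dgl.mean_edges(bg, 'h')  # shape (2, 8): one mean per graph in the batch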
Example #2
def test_simple_readout():
    g1 = dgl.DGLGraph()
    g1.add_nodes(3)
    g2 = dgl.DGLGraph()
    g2.add_nodes(4)  # no edges
    g1.add_edges([0, 1, 2], [2, 0, 1])

    n1 = F.randn((3, 5))
    n2 = F.randn((4, 5))
    e1 = F.randn((3, 5))
    s1 = F.sum(n1, 0)  # node sums
    s2 = F.sum(n2, 0)
    se1 = F.sum(e1, 0)  # edge sums
    m1 = F.mean(n1, 0)  # node means
    m2 = F.mean(n2, 0)
    me1 = F.mean(e1, 0)  # edge means
    w1 = F.randn((3, ))
    w2 = F.randn((4, ))
    max1 = F.max(n1, 0)
    max2 = F.max(n2, 0)
    maxe1 = F.max(e1, 0)
    ws1 = F.sum(n1 * F.unsqueeze(w1, 1), 0)
    ws2 = F.sum(n2 * F.unsqueeze(w2, 1), 0)
    wm1 = F.sum(n1 * F.unsqueeze(w1, 1), 0) / F.sum(F.unsqueeze(w1, 1), 0)
    wm2 = F.sum(n2 * F.unsqueeze(w2, 1), 0) / F.sum(F.unsqueeze(w2, 1), 0)
    g1.ndata['x'] = n1
    g2.ndata['x'] = n2
    g1.ndata['w'] = w1
    g2.ndata['w'] = w2
    g1.edata['x'] = e1

    assert F.allclose(dgl.sum_nodes(g1, 'x'), s1)
    assert F.allclose(dgl.sum_nodes(g1, 'x', 'w'), ws1)
    assert F.allclose(dgl.sum_edges(g1, 'x'), se1)
    assert F.allclose(dgl.mean_nodes(g1, 'x'), m1)
    assert F.allclose(dgl.mean_nodes(g1, 'x', 'w'), wm1)
    assert F.allclose(dgl.mean_edges(g1, 'x'), me1)
    assert F.allclose(dgl.max_nodes(g1, 'x'), max1)
    assert F.allclose(dgl.max_edges(g1, 'x'), maxe1)

    g = dgl.batch([g1, g2])
    s = dgl.sum_nodes(g, 'x')
    m = dgl.mean_nodes(g, 'x')
    max_bg = dgl.max_nodes(g, 'x')
    assert F.allclose(s, F.stack([s1, s2], 0))
    assert F.allclose(m, F.stack([m1, m2], 0))
    assert F.allclose(max_bg, F.stack([max1, max2], 0))
    ws = dgl.sum_nodes(g, 'x', 'w')
    wm = dgl.mean_nodes(g, 'x', 'w')
    assert F.allclose(ws, F.stack([ws1, ws2], 0))
    assert F.allclose(wm, F.stack([wm1, wm2], 0))
    s = dgl.sum_edges(g, 'x')
    m = dgl.mean_edges(g, 'x')
    max_bg_e = dgl.max_edges(g, 'x')
    assert F.allclose(s, F.stack([se1, F.zeros(5)], 0))
    assert F.allclose(m, F.stack([me1, F.zeros(5)], 0))
    assert F.allclose(max_bg_e, F.stack([maxe1, F.zeros(5)], 0))
Example #3
    def forward(self, g, x, e, snorm_n, snorm_e):
        # snorm_n is used when graphs are batched
        # h = self.embedding_h(h)
        # h = self.in_feat_dropout(h)

        # initialize per-node and per-edge hidden states to zeros
        h_node = torch.zeros([g.number_of_nodes(), self.node_in_dim]).float().to(self.device)
        h_edge = torch.zeros([g.number_of_edges(), self.h_dim]).float().to(self.device)
        src, dst = g.all_edges()

        # as in the model above, pick the call signature based on the available feature types
        for edge_layer, node_layer in zip(self.edge_layers, self.node_layers):
            if self.edge_f:
                if self.dst_f:
                    h_edge = edge_layer(g, src_feat=x[src], dst_feat=x[dst], e_feat=e, h_feat=h_edge, snorm_e=snorm_e)
                    h_node = node_layer(g, src_feat=x[src], dst_feat=x[dst], e_feat=e, h_feat=h_node, snorm_e=snorm_e, n_feat=x)
                else:
                    h_edge = edge_layer(g, src_feat=x[src], e_feat=e, h_feat=h_edge, snorm_e=snorm_e)
                    h_node = node_layer(g, src_feat=x[src], e_feat=e, h_feat=h_node, snorm_e=snorm_e, n_feat=x)
            else:
                if self.dst_f:
                    h_edge = edge_layer(g, src_feat=x[src], dst_feat=x[dst], h_feat=h_edge, snorm_e=snorm_e)
                    h_node = node_layer(g, src_feat=x[src], dst_feat=x[dst], h_feat=h_node, snorm_e=snorm_e, n_feat=x)
                else:
                    h_edge = edge_layer(g, src_feat=x[src], h_feat=h_edge, snorm_e=snorm_e)
                    h_node = node_layer(g, src_feat=x[src], h_feat=h_node, snorm_e=snorm_e, n_feat=x)


        g.edata['h'] = h_edge
        if self.node_update:
            g.ndata['h'] = h_node

        # print("g.data:", g.ndata['h'][0].shape)

        if self.readout == "sum":
            he = dgl.sum_edges(g, 'h')
            hn = dgl.sum_nodes(g, 'h')
        elif self.readout == "max":
            he = dgl.max_edges(g, 'h')
            hn = dgl.max_nodes(g, 'h')
        elif self.readout == "mean":
            he = dgl.mean_edges(g, 'h')
            hn = dgl.mean_nodes(g, 'h')
        else:
            he = dgl.mean_edges(g, 'h')  # default readout: mean over edges and nodes
            hn = dgl.mean_nodes(g, 'h')

        # print(torch.cat([he, hn], dim=1).shape)
        # used for the graph-level (global) task

        out = self.Global_MLP_layer(torch.cat([he, hn], dim=1))

        # used for the transition task (computed here but not returned)
        edge_out = self.edge_MLPReadout(h_edge)

        # return self.MLP_layer(he)
        return out
Example #4
def test_simple_readout():
    g1 = dgl.DGLGraph()
    g1.add_nodes(3)
    g2 = dgl.DGLGraph()
    g2.add_nodes(4)  # no edges
    g1.add_edges([0, 1, 2], [2, 0, 1])

    n1 = th.randn(3, 5)
    n2 = th.randn(4, 5)
    e1 = th.randn(3, 5)
    s1 = n1.sum(0)  # node sums
    s2 = n2.sum(0)
    se1 = e1.sum(0)  # edge sums
    m1 = n1.mean(0)  # node means
    m2 = n2.mean(0)
    me1 = e1.mean(0)  # edge means
    w1 = th.randn(3)
    w2 = th.randn(4)
    ws1 = (n1 * w1[:, None]).sum(0)  # weighted node sums
    ws2 = (n2 * w2[:, None]).sum(0)
    wm1 = (n1 * w1[:, None]).sum(0) / w1[:, None].sum(0)  # weighted node means
    wm2 = (n2 * w2[:, None]).sum(0) / w2[:, None].sum(0)
    g1.ndata['x'] = n1
    g2.ndata['x'] = n2
    g1.ndata['w'] = w1
    g2.ndata['w'] = w2
    g1.edata['x'] = e1

    assert U.allclose(dgl.sum_nodes(g1, 'x'), s1)
    assert U.allclose(dgl.sum_nodes(g1, 'x', 'w'), ws1)
    assert U.allclose(dgl.sum_edges(g1, 'x'), se1)
    assert U.allclose(dgl.mean_nodes(g1, 'x'), m1)
    assert U.allclose(dgl.mean_nodes(g1, 'x', 'w'), wm1)
    assert U.allclose(dgl.mean_edges(g1, 'x'), me1)

    g = dgl.batch([g1, g2])
    s = dgl.sum_nodes(g, 'x')
    m = dgl.mean_nodes(g, 'x')
    assert U.allclose(s, th.stack([s1, s2], 0))
    assert U.allclose(m, th.stack([m1, m2], 0))
    ws = dgl.sum_nodes(g, 'x', 'w')
    wm = dgl.mean_nodes(g, 'x', 'w')
    assert U.allclose(ws, th.stack([ws1, ws2], 0))
    assert U.allclose(wm, th.stack([wm1, wm2], 0))
    s = dgl.sum_edges(g, 'x')
    m = dgl.mean_edges(g, 'x')
    assert U.allclose(s, th.stack([se1, th.zeros(5)], 0))
    assert U.allclose(m, th.stack([me1, th.zeros(5)], 0))
Example #5
    def forward(self, dgl_data):
        # stack four graph readouts along a new trailing axis, then fuse them
        # with an element-wise max into a single (batch, feat) tensor
        dgl_feat, _ = torch.max(
            torch.stack([
                dgl.mean_nodes(dgl_data, 'h'),
                dgl.max_nodes(dgl_data, 'h'),
                dgl.mean_edges(dgl_data, 'h'),
                dgl.max_edges(dgl_data, 'h'),
            ], 2), -1)
        return dgl_feat
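A minimal usage sketch for the fused readout above, assuming the forward lives in a module named MaxFusionReadout (a hypothetical name) and that 'h' features exist on both nodes and edges:

import dgl
import torch

g = dgl.graph(([0, 1, 2], [2, 0, 1]))
g.ndata['h'] = torch.randn(3, 5)
g.edata['h'] = torch.randn(3, 5)
bg = dgl.batch([g, g])

readout = MaxFusionReadout()  # hypothetical module wrapping the forward above
feat = readout(bg)  # each readout is (2, 5); stacking gives (2, 5, 4), max over -1 returns (2, 5)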
Example #6
    def forward(self, dgl_data):
        # concatenate mean and max readouts over nodes and/or edges,
        # depending on the configured flags
        if self.getnode and self.getedge:
            dgl_feat = torch.cat([
                dgl.mean_nodes(dgl_data, 'h'),
                dgl.max_nodes(dgl_data, 'h'),
                dgl.mean_edges(dgl_data, 'h'),
                dgl.max_edges(dgl_data, 'h'),
            ], -1)
        elif self.getnode:
            dgl_feat = torch.cat(
                [dgl.mean_nodes(dgl_data, 'h'),
                 dgl.max_nodes(dgl_data, 'h')], -1)
        else:
            dgl_feat = torch.cat(
                [dgl.mean_edges(dgl_data, 'h'),
                 dgl.max_edges(dgl_data, 'h')], -1)
        dgl_predict = self.activate(self.weight_node(dgl_feat))
        return dgl_predict
Example #7
    def compute_disentangle_loss(self):
        assert self.g is not None, "compute_disentangle_loss must be called after the forward pass"

        # compute discrimination loss
        factors_feat = [
            self.graph_to_feat(self.g, self.hidden,
                               f"factor_{latent_i}").squeeze()
            for latent_i in range(self.n_latent)
        ]

        labels = [
            torch.ones(f.shape[0]) * i for i, f in enumerate(factors_feat)
        ]
        labels = torch.cat(tuple(labels), 0).long().cuda()
        factors_feat = torch.cat(tuple(factors_feat), 0)

        pred = self.classifier(factors_feat)
        discrimination_loss = self.loss_fn(pred, labels)

        # list_num_edges = torch.tensor(self.g.batch_num_edges).unsqueeze(1)
        latent_mean = [
            dgl.mean_edges(self.g, f"factor_{latent_i}")
            for latent_i in range(self.n_latent)
        ]
        latent_mean = torch.cat(tuple(latent_mean), dim=1)
        # list_num_edges = list_num_edges.to(latent_sum.device)
        # norm_latent_sum = latent_sum / list_num_edges

        # negative entropy of the mean factor distribution; the uniform
        # distribution attains the entropy maximum, so the loss below is
        # non-negative and vanishes when the factors are used uniformly
        latent_mean_distrib = torch_fn.softmax(latent_mean, dim=1)
        latent_mean_entropy = torch.sum(latent_mean_distrib *
                                        torch.log(latent_mean_distrib),
                                        dim=1)

        uniform = torch_fn.softmax(torch.ones_like(latent_mean), dim=1)
        upper_bound = torch.sum(uniform * torch.log(uniform), dim=1)

        distribution_loss = (latent_mean_entropy - upper_bound)
        distribution_loss = torch.mean(distribution_loss) * 100.0

        return [discrimination_loss, distribution_loss]
Example #8
    def forward(self, graph, feat):
        # local_scope keeps the temporary 'e' field from leaking into the caller's graph
        with graph.local_scope():
            graph.edata['e'] = feat
            readout = dgl.mean_edges(graph, 'e')
            return readout
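A minimal usage sketch, assuming the module above is named MeanEdgeReadout (a hypothetical name) with a no-argument constructor:

import dgl
import torch

g = dgl.graph(([0, 1, 2], [2, 0, 1]))
edge_feat = torch.randn(3, 16)

readout = MeanEdgeReadout()
hg = readout(g, edge_feat)  # shape (1, 16): one row per graph
# g.edata remains untouched: local_scope discarded the temporary 'e' field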
Example #9
    def mean_readout(g):
        # from_field is captured from the enclosing scope
        return dgl.mean_edges(g, from_field)
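For context, a closure like this is typically produced by a small factory that fixes the field name; a minimal sketch, with make_mean_readout as a hypothetical name:

import dgl

def make_mean_readout(from_field):
    # build a readout that averages the given edge field over each graph
    def mean_readout(g):
        return dgl.mean_edges(g, from_field)
    return mean_readout

readout = make_mean_readout('h')  # averages g.edata['h'] per graph in a batch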