Example #1
0
def track_time(batch_size, feat_size, readout_op, type):
    """Benchmark batched-graph readout on the QM7b dataset.

    Args:
        batch_size: Number of graphs taken from the start of ``QM7bDataset``
            and batched together.
        feat_size: Width of the random ``'h'`` feature attached to every node
            (``type == 'node'``) or every edge (``type == 'edge'``).
        readout_op: Reducer name forwarded to DGL as ``op=`` (e.g. ``'sum'``).
        type: ``'node'`` to time ``dgl.readout_nodes``, ``'edge'`` to time
            ``dgl.readout_edges``. The name shadows the builtin but is part of
            the public signature, so it is kept.

    Returns:
        Seconds per readout call: ``utils.Timer``'s ``elapsed_secs`` for the
        timed loop divided by the number of timed calls.

    Raises:
        ValueError: If ``type`` is neither ``'node'`` nor ``'edge'``. This is
            checked before the dataset is loaded.
    """
    num_warmup = 10
    num_timed = 50

    # Resolve and validate the readout kind up front so a bad argument fails
    # fast instead of after the (comparatively expensive) dataset load.
    if type == 'node':
        readout = dgl.readout_nodes
    elif type == 'edge':
        readout = dgl.readout_edges
    else:
        raise ValueError("Unknown type: {!r}".format(type))

    device = utils.get_bench_device()
    ds = dgl.data.QM7bDataset()
    # NOTE(review): presumably ds[a:b] yields (graphs, labels); [0] keeps the graphs.
    graphs = ds[0:batch_size][0]
    g = dgl.batch(graphs).to(device)

    if type == 'node':
        g.ndata['h'] = torch.randn((g.num_nodes(), feat_size), device=device)
    else:
        g.edata['h'] = torch.randn((g.num_edges(), feat_size), device=device)

    # Untimed warm-up calls keep one-time setup costs out of the measurement.
    for _ in range(num_warmup):
        readout(g, 'h', op=readout_op)
    with utils.Timer() as t:
        for _ in range(num_timed):
            readout(g, 'h', op=readout_op)

    return t.elapsed_secs / num_timed
 def forward(self, graph: dgl.DGLGraph):
     """Encode ``graph`` without gradients, then predict from pooled node features.

     ``self.node_gnn(graph)`` runs inside ``torch.no_grad()``, so autograd does
     not record that call. Its return value is ignored; it presumably writes the
     node data ``'feat'`` pooled below (assumption, confirm against node_gnn).
     ``'feat'`` is reduced once per entry of ``self.readout_aggregators``. The
     results are concatenated along the last axis and fed to ``self.output``.
     """
     with torch.no_grad():
         self.node_gnn(graph)

     def pool(aggregator):
         # One graph-level readout of 'feat' using the given reducer name.
         return dgl.readout_nodes(graph, 'feat', op=aggregator)

     pooled = torch.cat(list(map(pool, self.readout_aggregators)), dim=-1)
     return self.output(pooled)
Example #3
0
    def forward(self, g):
        """Run ``self.gnn`` on ``g`` and map pooled node features through ``self.output``.

        The return value of ``self.gnn(g)`` is ignored, so it is expected to
        update ``g`` in place, presumably writing the node data ``'feat'`` read
        below (assumption, confirm against gnn). ``'feat'`` is pooled once per
        name in ``self.readout_aggregators``. The pooled tensors are
        concatenated along the last axis.
        """
        self.gnn(g)
        pooled = torch.cat(
            tuple(
                dgl.readout_nodes(g, 'feat', op=op_name)
                for op_name in self.readout_aggregators
            ),
            dim=-1,
        )
        return self.output(pooled)
Example #4
0
def test_weighted_reduce_readout(g, idtype, reducer):
    """Check that weighted readouts on a batched graph match per-graph readouts.

    Node and edge readouts are each exercised two ways:
    - through the generic ``dgl.readout_nodes`` / ``dgl.readout_edges`` API
      with ``op=reducer``;
    - through the named shortcut ``dgl.<reducer>_nodes`` / ``dgl.<reducer>_edges``.

    Every batched result must equal the row-wise concatenation of the results
    computed on the unbatched component graphs.

    Args:
        g: Graph fixture, presumably a batched graph supplied by the test
            parametrization (not visible here).
        idtype: Index dtype the graph is cast to.
        reducer: Reducer name; ``dgl`` must expose ``<reducer>_nodes`` and
            ``<reducer>_edges`` for it.
    """
    g = g.astype(idtype).to(F.ctx())
    g.ndata['h'] = F.randn((g.number_of_nodes(), 3))
    g.ndata['w'] = F.randn((g.number_of_nodes(), 1))
    g.edata['h'] = F.randn((g.number_of_edges(), 2))
    g.edata['w'] = F.randn((g.number_of_edges(), 1))

    # The component graphs do not depend on the readout under test, so unbatch once.
    subgraphs = dgl.unbatch(g)

    node_shortcut = getattr(dgl, '{}_nodes'.format(reducer))
    edge_shortcut = getattr(dgl, '{}_edges'.format(reducer))
    cases = [
        # Test.1: node readout (generic API, then named shortcut).
        ('readout_nodes',
         lambda graph: dgl.readout_nodes(graph, 'h', 'w', op=reducer)),
        ('{}_nodes'.format(reducer),
         lambda graph: node_shortcut(graph, 'h', 'w')),
        # Test.2: edge readout (generic API, then named shortcut).
        ('readout_edges',
         lambda graph: dgl.readout_edges(graph, 'h', 'w', op=reducer)),
        ('{}_edges'.format(reducer),
         lambda graph: edge_shortcut(graph, 'h', 'w')),
    ]

    for label, readout in cases:
        batched = readout(g)
        per_graph = F.cat([readout(sg) for sg in subgraphs], dim=0)
        assert F.allclose(batched, per_graph), label
Example #5
0
    def forward(self, g):
        """Aggregate scaled edge messages onto nodes, transform them, and sum-pool.

        All temporary features live inside ``g.local_scope()``, so they are
        discarded when the method returns.

        Steps:
        1. Edge data ``'m'`` is multiplied elementwise by
           ``self.dense_rbf(g.edata['rbf'])``.
        2. The products are summed into each destination node as ``'t'``.
        3. ``'t'`` passes through ``self.dense_layers``, each optionally
           followed by ``self.activation``, then through ``self.dense_final``.
        4. The result is read out with ``dgl.readout_nodes``'s default reducer.
        """
        with g.local_scope():
            g.edata['tmp'] = g.edata['m'] * self.dense_rbf(g.edata['rbf'])
            g.update_all(fn.copy_e('tmp', 'x'), fn.sum('x', 't'))

            # Transform on a local tensor; only the final value is stored on g.
            h = g.ndata['t']
            for dense in self.dense_layers:
                h = dense(h)
                if self.activation is not None:
                    h = self.activation(h)
            g.ndata['t'] = self.dense_final(h)
            return dgl.readout_nodes(g, 't')
Example #6
0
    def forward(self, graph: dgl.DGLGraph):
        """Predict from a trainable 2D encoding concatenated with a frozen 3D latent.

        Frozen branch, run under ``torch.no_grad()``:
        - ``self.node_gnn`` runs on a deep copy of ``graph``.
        - The copy's node data ``'feat'`` is pooled once per entry of
          ``self.frozen_readout_aggregators``.
        - The pooled result is projected by ``self.output`` and detached.

        Trainable branch:
        - ``self.node_gnn2D`` runs on ``graph`` itself.
        - Its ``'feat'`` is pooled once per entry of ``self.readout_aggregators``.

        The 2D readouts followed by the frozen latent are concatenated along
        the last axis and passed to ``self.output2D``.
        """
        def pool_feat(g, aggregators):
            # One readout tensor of 'feat' per aggregator name, in order.
            return [dgl.readout_nodes(g, 'feat', op=agg) for agg in aggregators]

        with torch.no_grad():
            # Deep copy so whatever the frozen GNN writes into node data stays
            # off the caller's graph.
            frozen_graph = deepcopy(graph)
            self.node_gnn(frozen_graph)
            frozen_pooled = torch.cat(
                pool_feat(frozen_graph, self.frozen_readout_aggregators), dim=-1)
            frozen_latent = self.output(frozen_pooled).detach()

        self.node_gnn2D(graph)
        head_input = pool_feat(graph, self.readout_aggregators)
        head_input.append(frozen_latent)
        return self.output2D(torch.cat(head_input, dim=-1))
Example #7
0
    def forward(self, graph: dgl.DGLGraph):
        """Encode nodes (and optionally edge distances), message-pass, and pool.

        Steps:
        1. ``self.input_node_func`` is applied to every node.
        2. If ``self.fourier_encodings > 0``, edge data ``'d'`` is replaced by
           its ``fourier_encode_dist`` encoding, with singleton axes removed
           (never the leading edge axis).
        3. Each module in ``self.mp_layers`` is called on ``graph`` in order.
        4. ``self.output_node_func`` is applied to every node.
        5. Node data ``'feat'`` is pooled once per entry of
           ``self.readout_aggregators``; the results are concatenated along the
           last axis and passed to ``self.output``.

        NOTE: ``graph`` is modified in place. In particular, ``edata['d']`` is
        overwritten with its own encoding, so calling ``forward`` twice on the
        same graph would encode the distances twice.
        """
        graph.apply_nodes(self.input_node_func)
        if self.fourier_encodings > 0:
            encoded = fourier_encode_dist(graph.edata['d'], num_encodings=self.fourier_encodings)
            # A bare ``.squeeze()`` drops *every* size-1 axis, including axis 0
            # when the graph has exactly one edge, which breaks the per-edge
            # layout. Drop singleton axes everywhere except axis 0 instead;
            # for any other edge count this matches ``.squeeze()``.
            kept_shape = tuple(
                size for axis, size in enumerate(encoded.shape)
                if axis == 0 or size != 1
            )
            graph.edata['d'] = encoded.reshape(kept_shape)

        for mp_layer in self.mp_layers:
            mp_layer(graph)

        graph.apply_nodes(self.output_node_func)

        readouts_to_cat = [dgl.readout_nodes(graph, 'feat', op=aggr) for aggr in self.readout_aggregators]
        readout = torch.cat(readouts_to_cat, dim=-1)
        return self.output(readout)
Example #8
0
    def forward(self, graph: dgl.DGLGraph):
        """Encode nodes, run message passing, then pool node features for prediction.

        Steps:
        1. ``self.input_node_func`` is applied to every node.
        2. Each module in ``self.mp_layers`` is called on ``graph`` in order.
        3. ``self.output_node_func`` is applied to every node.
        4. Node data ``'feat'`` is pooled once per entry of
           ``self.readout_aggregators``.
        5. The pooled tensors are concatenated along the last axis and passed
           to ``self.output``.

        NOTE: ``graph`` is updated in place. The layers' return values are
        ignored, so they presumably act through the graph's data.
        """
        graph.apply_nodes(self.input_node_func)
        for layer in self.mp_layers:
            layer(graph)
        graph.apply_nodes(self.output_node_func)

        pooled = []
        for aggregator in self.readout_aggregators:
            pooled.append(dgl.readout_nodes(graph, 'feat', op=aggregator))
        graph_repr = torch.cat(pooled, dim=-1)
        return self.output(graph_repr)
Example #9
0
    def forward(self, g):
        """Aggregate scaled edge messages along reversed edges, transform, and pool.

        All temporary features live inside ``g.local_scope()`` and are
        discarded on return.

        Steps:
        1. Edge data ``'m'`` is multiplied elementwise by
           ``self.dense_rbf(g.edata['rbf'])``.
        2. The products are summed over the *reversed* graph (edge data
           copied), so each node collects the values on its outgoing edges in
           ``g``.
        3. The sum is projected by ``self.up_projection``, passed through
           ``self.dense_layers`` (each optionally followed by
           ``self.activation``), then through ``self.dense_final``.
        4. The result is read out with ``'sum'`` if ``self.extensive`` is
           truthy, otherwise ``'mean'``.
        """
        with g.local_scope():
            g.edata['tmp'] = g.edata['m'] * self.dense_rbf(g.edata['rbf'])
            flipped = dgl.reverse(g, copy_edata=True)
            flipped.update_all(fn.copy_e('tmp', 'x'), fn.sum('x', 't'))

            h = self.up_projection(flipped.ndata['t'])
            for dense in self.dense_layers:
                h = dense(h)
                if self.activation is not None:
                    h = self.activation(h)
            g.ndata['t'] = self.dense_final(h)

            if self.extensive:
                pooling = 'sum'
            else:
                pooling = 'mean'
            return dgl.readout_nodes(g, 't', op=pooling)
    def forward(self, graph: dgl.DGLGraph):
        """Run a single message-passing round and pool node features for prediction.

        Steps:
        1. If ``self.fourier_encodings > 0``, edge data ``'d'`` is replaced by
           its ``fourier_encode_dist`` encoding.
        2. ``self.input_edge_func`` is applied to every edge.
        3. One ``update_all`` runs with ``self.message_function`` and the reducer
           built by ``self.reduce_func(msg='m', out='m_sum')``.
        4. ``self.output_node_func`` is applied only when
           ``self.node_wise_output_layers > 0``.
        5. Node data ``'feat'`` is pooled once per entry of
           ``self.readout_aggregators``, concatenated along the last axis, and
           passed to ``self.output``.

        NOTE: ``graph`` is modified in place, including overwriting ``edata['d']``.
        """
        num_encodings = self.fourier_encodings
        if num_encodings > 0:
            distances = graph.edata['d']
            graph.edata['d'] = fourier_encode_dist(distances, num_encodings=num_encodings)
        graph.apply_edges(self.input_edge_func)

        message = self.message_function
        reducer = self.reduce_func(msg='m', out='m_sum')
        graph.update_all(message_func=message, reduce_func=reducer)

        if self.node_wise_output_layers > 0:
            graph.apply_nodes(self.output_node_func)

        pooled = [dgl.readout_nodes(graph, 'feat', op=agg) for agg in self.readout_aggregators]
        return self.output(torch.cat(pooled, dim=-1))
Example #11
0
    def forward(self, graph: dgl.DGLGraph):
        """Start every node from one learned embedding, message-pass, and pool.

        Steps:
        1. ``self.node_embedding`` is one-dimensional (the ``[None, :]`` +
           ``expand(n, -1)`` pattern requires that). It is broadcast to one
           row per node and stored as node data ``'feat'``.
        2. If ``self.fourier_encodings > 0``, edge data ``'d'`` is replaced by
           its ``fourier_encode_dist`` encoding.
        3. ``self.input_edge_func`` is applied to every edge.
        4. Each module in ``self.mp_layers`` is called on ``graph`` in order.
        5. ``self.output_node_func`` is applied only when
           ``self.node_wise_output_layers > 0``.
        6. ``'feat'`` is pooled once per entry of ``self.readout_aggregators``,
           concatenated along the last axis, and passed to ``self.output``.
        """
        num_nodes = graph.number_of_nodes()
        # expand() returns a broadcast view rather than a copy; every row
        # shares storage with node_embedding. NOTE(review): in-place writes to
        # 'feat' downstream would be unsafe. Confirm the layers do not do that.
        graph.ndata['feat'] = self.node_embedding[None, :].expand(num_nodes, -1)

        if self.fourier_encodings > 0:
            distances = graph.edata['d']
            graph.edata['d'] = fourier_encode_dist(
                distances, num_encodings=self.fourier_encodings)
        graph.apply_edges(self.input_edge_func)

        for layer in self.mp_layers:
            layer(graph)

        if self.node_wise_output_layers > 0:
            graph.apply_nodes(self.output_node_func)

        per_aggregator = []
        for aggregator in self.readout_aggregators:
            per_aggregator.append(dgl.readout_nodes(graph, 'feat', op=aggregator))
        return self.output(torch.cat(per_aggregator, dim=-1))
Example #12
0
def track_time(batch_size, feat_size, readout_op, type):
    """Benchmark batched-graph readout on the QM7b dataset.

    Args:
        batch_size: Number of graphs taken from the start of ``QM7bDataset``
            and batched together.
        feat_size: Width of the random ``'h'`` feature attached to every node
            (``type == 'node'``) or every edge (``type == 'edge'``).
        readout_op: Reducer name passed to DGL as ``op=``. It must be passed by
            keyword: the third positional parameter of ``dgl.readout_nodes`` /
            ``dgl.readout_edges`` is ``weight``, not ``op``.
        type: ``'node'`` to time ``dgl.readout_nodes``, ``'edge'`` to time
            ``dgl.readout_edges``. The name shadows the builtin but is part of
            the public signature, so it is kept.

    Returns:
        Average wall-clock seconds per readout call over the timed loop,
        measured after untimed warm-up calls.

    Raises:
        ValueError: If ``type`` is neither ``'node'`` nor ``'edge'``. This is
            checked before the dataset is loaded.
    """
    import time

    num_warmup = 10
    num_timed = 10

    # Resolve and validate the readout kind before any expensive setup.
    if type == 'node':
        readout = dgl.readout_nodes
    elif type == 'edge':
        readout = dgl.readout_edges
    else:
        raise ValueError("Unknown type: {!r}".format(type))

    device = utils.get_bench_device()
    ds = dgl.data.QM7bDataset()
    # NOTE(review): presumably ds[a:b] yields (graphs, labels); [0] keeps the graphs.
    graphs = ds[0:batch_size][0]
    g = dgl.batch(graphs).to(device)

    if type == 'node':
        g.ndata['h'] = torch.randn((g.num_nodes(), feat_size), device=device)
    else:
        g.edata['h'] = torch.randn((g.num_edges(), feat_size), device=device)

    on_cuda = str(device).startswith('cuda')

    def _wait_for_device():
        # CUDA kernels launch asynchronously; block until queued work finishes
        # so the wall-clock readings cover the readout itself.
        if on_cuda:
            torch.cuda.synchronize(device)

    # Untimed warm-up keeps one-time setup costs out of the measurement.
    for _ in range(num_warmup):
        readout(g, 'h', op=readout_op)
    _wait_for_device()

    # perf_counter is monotonic and high-resolution, unlike time.time().
    t0 = time.perf_counter()
    for _ in range(num_timed):
        readout(g, 'h', op=readout_op)
    _wait_for_device()
    t1 = time.perf_counter()

    return (t1 - t0) / num_timed
Example #13
0
 def readout(self, g, features):
     """Mean-pool ``features`` over the nodes of each graph in ``g``.

     Side effect: ``features`` is stored on ``g`` as node data ``"h"`` and is
     left there after the call.
     """
     g.ndata["h"] = features
     return dgl.readout_nodes(g, feat="h", op="mean")