def track_time(batch_size, feat_size, readout_op, type):
    """Benchmark the average runtime (seconds) of a DGL readout op.

    Batches `batch_size` QM7b graphs, attaches random features of width
    `feat_size`, and times 50 calls of readout_nodes/readout_edges with
    `readout_op` after 10 warm-up calls.

    Raises:
        Exception: if `type` is neither 'node' nor 'edge'.
    """
    device = utils.get_bench_device()
    dataset = dgl.data.QM7bDataset()

    # Build a single batched graph from the first `batch_size` samples.
    g = dgl.batch(dataset[0:batch_size][0]).to(device)

    if type == 'node':
        g.ndata['h'] = torch.randn((g.num_nodes(), feat_size), device=device)
        run = lambda: dgl.readout_nodes(g, 'h', op=readout_op)
    elif type == 'edge':
        g.edata['h'] = torch.randn((g.num_edges(), feat_size), device=device)
        run = lambda: dgl.readout_edges(g, 'h', op=readout_op)
    else:
        raise Exception("Unknown type")

    # Warm-up (kernel launch / cache effects) before the timed section.
    for _ in range(10):
        out = run()
    with utils.Timer() as t:
        for _ in range(50):
            out = run()
    return t.elapsed_secs / 50
def forward(self, graph: dgl.DGLGraph):
    """Encode `graph` with the node GNN (no gradients), pool, and project.

    NOTE(review): indentation was lost in this source; the `no_grad` scope
    below is reconstructed to cover only the GNN call — confirm against the
    original file.
    """
    with torch.no_grad():
        # The GNN writes updated node features onto the graph in place.
        self.node_gnn(graph)
    pooled = [dgl.readout_nodes(graph, 'feat', op=aggr)
              for aggr in self.readout_aggregators]
    return self.output(torch.cat(pooled, dim=-1))
def forward(self, g):
    """Run the GNN, pool node features with each aggregator, and project."""
    self.gnn(g)
    pooled = []
    for aggr in self.readout_aggregators:
        pooled.append(dgl.readout_nodes(g, 'feat', op=aggr))
    return self.output(torch.cat(pooled, dim=-1))
def test_weighted_reduce_readout(g, idtype, reducer):
    """Check weighted node/edge readout on a batched graph.

    For each readout API (generic `readout_*` with `op=`, and the named
    `{reducer}_*` convenience wrappers), the readout over the batched graph
    must equal the per-graph readouts of `dgl.unbatch(g)` concatenated.
    The four original near-identical check sections are factored into one
    helper.
    """
    g = g.astype(idtype).to(F.ctx())
    g.ndata['h'] = F.randn((g.number_of_nodes(), 3))
    g.ndata['w'] = F.randn((g.number_of_nodes(), 1))
    g.edata['h'] = F.randn((g.number_of_edges(), 2))
    g.edata['w'] = F.randn((g.number_of_edges(), 1))

    def _check(readout_fn):
        # Batched readout must match per-graph readouts stacked along dim 0.
        x = readout_fn(g)
        subx = [readout_fn(sg) for sg in dgl.unbatch(g)]
        assert F.allclose(x, F.cat(subx, dim=0))

    # Test.1: node readout
    _check(lambda graph: dgl.readout_nodes(graph, 'h', 'w', op=reducer))
    _check(lambda graph: getattr(dgl, '{}_nodes'.format(reducer))(graph, 'h', 'w'))
    # Test.2: edge readout
    _check(lambda graph: dgl.readout_edges(graph, 'h', 'w', op=reducer))
    _check(lambda graph: getattr(dgl, '{}_edges'.format(reducer))(graph, 'h', 'w'))
def forward(self, g):
    """Gate messages by an RBF projection, aggregate, transform, and pool.

    Runs inside `local_scope` so temporary node/edge features do not leak
    back to the caller's graph.
    """
    with g.local_scope():
        # Weight each incoming message by the learned RBF gate.
        g.edata['tmp'] = g.edata['m'] * self.dense_rbf(g.edata['rbf'])
        g.update_all(fn.copy_e('tmp', 'x'), fn.sum('x', 't'))
        h = g.ndata['t']
        for dense in self.dense_layers:
            h = dense(h)
            if self.activation is not None:
                h = self.activation(h)
        g.ndata['t'] = self.dense_final(h)
        return dgl.readout_nodes(g, 't')
def forward(self, graph: dgl.DGLGraph):
    """Concatenate a frozen 3D latent with trainable 2D readouts and project.

    NOTE(review): indentation was lost in this source; the `no_grad` scope
    below is reconstructed to cover only the frozen 3D branch (the trailing
    `.detach()` and the trainable `node_gnn2D` suggest the 2D branch tracks
    gradients) — confirm against the original file.
    """
    with torch.no_grad():
        # Run the frozen 3D GNN on a copy so `graph` is left untouched.
        graph3D = deepcopy(graph)
        self.node_gnn(graph3D)
        pooled3D = [dgl.readout_nodes(graph3D, 'feat', op=aggr)
                    for aggr in self.frozen_readout_aggregators]
        # Detach so the frozen branch contributes no gradients.
        latent3D = self.output(torch.cat(pooled3D, dim=-1)).detach()

    self.node_gnn2D(graph)
    pooled2D = [dgl.readout_nodes(graph, 'feat', op=aggr)
                for aggr in self.readout_aggregators]
    combined = torch.cat(pooled2D + [latent3D], dim=-1)
    return self.output2D(combined)
def forward(self, graph: dgl.DGLGraph):
    """Message passing with optional Fourier distance encoding, then pool."""
    graph.apply_nodes(self.input_node_func)
    if self.fourier_encodings > 0:
        # Expand scalar edge distances into Fourier features.
        graph.edata['d'] = fourier_encode_dist(
            graph.edata['d'],
            num_encodings=self.fourier_encodings).squeeze()
    for layer in self.mp_layers:
        layer(graph)
    graph.apply_nodes(self.output_node_func)
    pooled = tuple(dgl.readout_nodes(graph, 'feat', op=aggr)
                   for aggr in self.readout_aggregators)
    return self.output(torch.cat(pooled, dim=-1))
def forward(self, graph: dgl.DGLGraph):
    """Apply input transform, message-passing layers, output transform, pool."""
    graph.apply_nodes(self.input_node_func)
    for layer in self.mp_layers:
        layer(graph)
    graph.apply_nodes(self.output_node_func)
    pooled = tuple(dgl.readout_nodes(graph, 'feat', op=aggr)
                   for aggr in self.readout_aggregators)
    return self.output(torch.cat(pooled, dim=-1))
def forward(self, g):
    """RBF-gated aggregation along reversed edges, dense stack, then pool.

    Pools with 'sum' when `self.extensive` (size-extensive target) else
    'mean'. Runs inside `local_scope` so temporaries do not leak.
    """
    with g.local_scope():
        g.edata['tmp'] = g.edata['m'] * self.dense_rbf(g.edata['rbf'])
        # Aggregate on the reversed graph so messages flow to source nodes.
        rev = dgl.reverse(g, copy_edata=True)
        rev.update_all(fn.copy_e('tmp', 'x'), fn.sum('x', 't'))
        h = self.up_projection(rev.ndata['t'])
        for dense in self.dense_layers:
            h = dense(h)
            if self.activation is not None:
                h = self.activation(h)
        g.ndata['t'] = self.dense_final(h)
        pool = 'sum' if self.extensive else 'mean'
        return dgl.readout_nodes(g, 't', op=pool)
def forward(self, graph: dgl.DGLGraph):
    """One round of message passing over edge features, then pool and project."""
    if self.fourier_encodings > 0:
        graph.edata['d'] = fourier_encode_dist(
            graph.edata['d'], num_encodings=self.fourier_encodings)
    graph.apply_edges(self.input_edge_func)
    graph.update_all(
        message_func=self.message_function,
        reduce_func=self.reduce_func(msg='m', out='m_sum'))
    if self.node_wise_output_layers > 0:
        graph.apply_nodes(self.output_node_func)
    pooled = [dgl.readout_nodes(graph, 'feat', op=aggr)
              for aggr in self.readout_aggregators]
    return self.output(torch.cat(pooled, dim=-1))
def forward(self, graph: dgl.DGLGraph):
    """Seed all nodes with one learned embedding, run MP layers, pool, project."""
    # Every node starts from the same learned embedding vector (broadcast).
    graph.ndata['feat'] = self.node_embedding[None, :].expand(
        graph.number_of_nodes(), -1)
    if self.fourier_encodings > 0:
        graph.edata['d'] = fourier_encode_dist(
            graph.edata['d'], num_encodings=self.fourier_encodings)
    graph.apply_edges(self.input_edge_func)
    for layer in self.mp_layers:
        layer(graph)
    if self.node_wise_output_layers > 0:
        graph.apply_nodes(self.output_node_func)
    pooled = [dgl.readout_nodes(graph, 'feat', op=aggr)
              for aggr in self.readout_aggregators]
    return self.output(torch.cat(pooled, dim=-1))
def track_time(batch_size, feat_size, readout_op, type):
    """Benchmark the average runtime (seconds) of a DGL readout op.

    Batches `batch_size` QM7b graphs, attaches random features of width
    `feat_size`, and times 10 readout calls after 10 warm-up calls.

    Raises:
        Exception: if `type` is neither 'node' nor 'edge'.
    """
    device = utils.get_bench_device()
    ds = dgl.data.QM7bDataset()

    # prepare graph
    graphs = ds[0:batch_size][0]
    g = dgl.batch(graphs).to(device)
    if type == 'node':
        g.ndata['h'] = torch.randn((g.num_nodes(), feat_size), device=device)
        # BUG FIX: `readout_op` was passed positionally, which binds it to the
        # `weight` parameter of dgl.readout_nodes; it must go to `op=`.
        # Also warm up before timing, matching the node/edge benchmark style
        # used elsewhere in this codebase.
        for i in range(10):
            out = dgl.readout_nodes(g, 'h', op=readout_op)
        t0 = time.time()
        for i in range(10):
            out = dgl.readout_nodes(g, 'h', op=readout_op)
        t1 = time.time()
    elif type == 'edge':
        g.edata['h'] = torch.randn((g.num_edges(), feat_size), device=device)
        # Same positional-argument fix and warm-up as the 'node' branch.
        for i in range(10):
            out = dgl.readout_edges(g, 'h', op=readout_op)
        t0 = time.time()
        for i in range(10):
            out = dgl.readout_edges(g, 'h', op=readout_op)
        t1 = time.time()
    else:
        raise Exception("Unknown type")
    return (t1 - t0) / 10
def readout(self, g, features):
    """Mean-pool `features` per graph: attach as node data, then read out."""
    g.ndata["h"] = features
    return dgl.readout_nodes(g, "h", op="mean")