def core_process(self, latent0: GraphBatch, data: GraphBatch) -> GraphBatch:
    """Run ``self.core`` on ``latent0`` and ``data`` joined feature-wise.

    Node, edge and global attributes of ``latent0`` and ``data`` are
    concatenated along dim 1 (``latent0`` first). The combined graph keeps
    the topology (``edges``, ``node_idx``, ``edge_idx``) of ``data``.

    :param latent0: graph whose attributes are prepended.
    :param data: graph whose attributes are appended and whose topology is used.
    :return: a ``GraphBatch`` holding the attributes produced by ``self.core``.
    """
    stacked = GraphBatch(
        torch.cat([latent0.x, data.x], dim=1),
        torch.cat([latent0.e, data.e], dim=1),
        torch.cat([latent0.g, data.g], dim=1),
        data.edges,
        data.node_idx,
        data.edge_idx,
    )
    # self.core returns attributes in (edge, node, global) order
    new_e, new_x, new_g = self.core(stacked)
    return GraphBatch(
        new_x, new_e, new_g, stacked.edges, stacked.node_idx, stacked.edge_idx
    )
def _forward_core(self, latent0, data):
    """Concatenate ``latent0`` and ``data`` attributes and apply ``self.core``.

    Each of the node (``x``), edge (``e``) and global (``g``) attributes is
    concatenated along dim 1 with ``latent0`` first. The result reuses the
    topology of ``data``.
    """
    joined = {
        attr: torch.cat([getattr(latent0, attr), getattr(data, attr)], dim=1)
        for attr in ("e", "x", "g")
    }
    merged = GraphBatch(
        joined["x"], joined["e"], joined["g"],
        data.edges, data.node_idx, data.edge_idx,
    )
    edge_out, node_out, glob_out = self.core(merged)
    return GraphBatch(
        node_out, edge_out, glob_out,
        merged.edges, merged.node_idx, merged.edge_idx,
    )
def to(batch, device, **kwargs):
    """Return a new ``GraphBatch`` with every tensor of ``batch`` moved via ``Tensor.to``.

    :param batch: source batch; it is not modified.
    :param device: target passed as the first argument to ``Tensor.to``.
    :param kwargs: extra keyword arguments forwarded to every ``Tensor.to`` call.
    """
    tensors = (
        batch.x,
        batch.e,
        batch.g,
        batch.edges,
        batch.node_idx,
        batch.edge_idx,
    )
    # generator is consumed left-to-right, matching GraphBatch's positional order
    return GraphBatch(*(t.to(device, **kwargs) for t in tensors))
def forward(self, data, steps, save_all: bool = False):
    """Encode ``data``, run ``steps`` rounds of core processing, and decode.

    Each round concatenates the initially encoded graph (``latent0``) with
    the current latent graph along dim 1, passes it through ``self.core``,
    then decodes it with ``self.decoder`` and applies
    ``self.output_transform``.

    :param data: input ``GraphBatch``.
    :param steps: number of core processing rounds.
    :param save_all: if True, keep the transformed output of every round;
        otherwise keep only the output of the last round.
    :return: list of ``GraphBatch`` outputs (empty when ``steps <= 0``).
    """
    # encode
    e, x, g = self.encoder(data)
    data = GraphBatch(x, e, g, data.edges, data.node_idx, data.edge_idx)

    # graph topology, shared by every intermediate GraphBatch below
    edges = data.edges
    node_idx = data.node_idx
    edge_idx = data.edge_idx
    meta = (edges, node_idx, edge_idx)

    latent0 = data
    outputs = []
    for _ in range(steps):
        # core processing step, conditioned on the initial encoding
        e = torch.cat([latent0.e, e], dim=1)
        x = torch.cat([latent0.x, x], dim=1)
        g = torch.cat([latent0.g, g], dim=1)
        data = GraphBatch(x, e, g, *meta)
        e, x, g = self.core(data)

        # decode
        data = GraphBatch(x, e, g, *meta)
        _e, _x, _g = self.decoder(data)
        decoded = GraphBatch(_x, _e, _g, *meta)

        # transform
        _e, _x, _g = self.output_transform(decoded)
        out = GraphBatch(_x, _e, _g, *meta)
        if save_all:
            outputs.append(out)
        else:
            outputs = [out]
    return outputs
def out_transform(self, data: GraphBatch) -> GraphBatch:
    """Apply ``self.output_transform`` and rewrap the result with ``data``'s topology.

    :param data: graph passed to ``self.output_transform``.
    :return: ``GraphBatch`` with transformed attributes and the original
        ``edges``, ``node_idx`` and ``edge_idx``.
    """
    edge_attr, node_attr, global_attr = self.output_transform(data)
    return GraphBatch(
        node_attr,
        edge_attr,
        global_attr,
        data.edges,
        data.node_idx,
        data.edge_idx,
    )
def decode(self, data: GraphBatch) -> GraphBatch:
    """Run ``self.decoder`` on ``data`` and keep its graph topology.

    :param data: latent graph to decode.
    :return: ``GraphBatch`` of decoded attributes with ``data``'s
        ``edges``, ``node_idx`` and ``edge_idx``.
    """
    decoded_edges, decoded_nodes, decoded_globals = self.decoder(data)
    topology = (data.edges, data.node_idx, data.edge_idx)
    return GraphBatch(decoded_nodes, decoded_edges, decoded_globals, *topology)
def add_edges(data: GraphData, fill_value: Union[float, int], kind: str) -> GraphData:
    """Adds edges to the :class:`caldera.data.GraphData` (or ``GraphBatch``).

    For every graph, candidate edges of the requested ``kind`` are built and
    reduced with ``edges_difference`` (presumably: candidates not already in
    the graph — confirm against its implementation). The surviving edges are
    appended, and their attribute rows are filled with ``fill_value``. The
    output edges are reordered by graph index.

    NOTE(review): per-graph slicing of ``data.edges`` assumes edges are
    already grouped by graph in ascending graph order.

    :param data: :class:`caldera.data.GraphData` or
        :class:`caldera.data.GraphBatch` instance.
    :param fill_value: fill value for the edge attribute rows of new edges.
    :param kind: Choose from "self" (for self edges), "complete" (for complete
        graph) or "undirected" (undirected edges).
    :return: a new ``GraphBatch`` for batch input, otherwise a new instance of
        ``type(data)``; node and global attributes are cloned.
    :raises ValueError: if ``kind`` is invalid or ``data`` is neither a
        ``GraphBatch`` nor a ``GraphData``.
    """
    UNDIRECTED = "undirected"
    COMPLETE = "complete"
    SELF = "self"
    VALID_KIND = [UNDIRECTED, COMPLETE, SELF]
    if kind not in VALID_KIND:
        raise ValueError("'kind' must be one of {}".format(VALID_KIND))

    data_cls = data.__class__
    if issubclass(data_cls, GraphBatch):
        node_idx = data.node_idx
        edge_idx = data.edge_idx
    elif issubclass(data_cls, GraphData):
        # a single graph: every node and edge belongs to graph 0
        node_idx = torch.zeros(data.x.shape[0], dtype=torch.long, device=data.x.device)
        edge_idx = torch.zeros(data.e.shape[0], dtype=torch.long, device=data.e.device)
    else:
        # use the class names themselves (``X.__class__.__name__`` is just "type")
        raise ValueError(
            "data must be a subclass of {} or {}".format(
                GraphBatch.__name__, GraphData.__name__
            )
        )

    with torch.no_grad():
        # number of nodes in each graph, in ascending graph-index order
        gidx, n_nodes = torch.unique(node_idx, return_counts=True, sorted=True)
        # Edge counts aligned position-by-position with ``gidx``. Counting via
        # torch.unique(edge_idx) would drop graphs without edges and
        # misalign every later graph in the zip below.
        if gidx.numel():
            n_edges = torch.bincount(edge_idx, minlength=int(gidx.max()) + 1)[gidx]
        else:
            n_edges = gidx.new_zeros(0)

        eidx = 0
        nidx = 0
        graph_edges_list = []
        new_edges_lengths = torch.zeros(
            gidx.shape[0], dtype=torch.long, device=gidx.device
        )
        for pos, (_n_nodes, _n_edges) in enumerate(zip(n_nodes, n_edges)):
            graph_edges = data.edges[:, eidx:eidx + _n_edges]
            if kind == UNDIRECTED:
                missing_edges = edges_difference(graph_edges.flip(0), graph_edges)
            elif kind == COMPLETE:
                all_graph_edges = _create_all_edges(nidx, _n_nodes)
                missing_edges = edges_difference(all_graph_edges, graph_edges)
            elif kind == SELF:
                self_loop_nodes = torch.arange(
                    nidx, nidx + _n_nodes, device=data.edges.device
                ).expand(1, -1)
                self_edges = torch.cat([self_loop_nodes] * 2)
                missing_edges = edges_difference(self_edges, graph_edges)
            graph_edges_list.append(missing_edges)
            # guard kept from the original: an empty result may lack a 2nd dim
            if not missing_edges.shape[0]:
                new_edges_lengths[pos] = 0
            else:
                new_edges_lengths[pos] = missing_edges.shape[1]
            nidx += _n_nodes
            eidx += _n_edges

    if graph_edges_list:
        new_edges = torch.cat(graph_edges_list, dim=1)
    else:
        # no graphs at all: nothing to add
        new_edges = data.edges.new_zeros((data.edges.shape[0], 0))
    new_edge_idx = gidx.repeat_interleave(new_edges_lengths)
    new_edge_attr = torch.full(
        (new_edges.shape[1], data.e.shape[1]),
        fill_value=fill_value,
        device=data.e.device,
    )

    edges = torch.cat([data.edges, new_edges], dim=1)
    edge_idx = torch.cat([edge_idx, new_edge_idx])
    edge_attr = torch.cat([data.e, new_edge_attr], dim=0)

    # group edges by graph; argsort is not guaranteed stable, so the relative
    # order of edges inside one graph may change
    idx = edge_idx.argsort()
    if issubclass(data_cls, GraphBatch):
        return GraphBatch(
            node_attr=data.x.detach().clone(),
            edge_attr=edge_attr[idx],
            global_attr=data.g.detach().clone(),
            edges=edges[:, idx],
            node_idx=data.node_idx.detach().clone(),
            edge_idx=edge_idx[idx],
        )
    elif issubclass(data_cls, GraphData):
        return data_cls(
            node_attr=data.x.detach().clone(),
            edge_attr=edge_attr[idx],
            global_attr=data.g.detach().clone(),
            edges=edges[:, idx],
        )
def _forward_out(self, data):
    """Pass ``data`` through ``self.output_transform``, preserving its topology."""
    outputs = self.output_transform(data)
    # output_transform yields (edge, node, global); GraphBatch wants (node, edge, global)
    node_attr, edge_attr, global_attr = outputs[1], outputs[0], outputs[2]
    return GraphBatch(
        node_attr,
        edge_attr,
        global_attr,
        data.edges,
        data.node_idx,
        data.edge_idx,
    )
def _forward_decode(self, data):
    """Decode ``data`` with ``self.decoder`` and rebuild a ``GraphBatch`` on the same topology."""
    edge_feat, node_feat, glob_feat = self.decoder(data)
    topology = (data.edges, data.node_idx, data.edge_idx)
    return GraphBatch(node_feat, edge_feat, glob_feat, *topology)