Example #1
0
 def core_process(self, latent0: GraphBatch, data: GraphBatch) -> GraphBatch:
     """Run ``self.core`` on ``latent0`` stacked feature-wise with ``data``.

     Edge, node and global attributes of the two batches are concatenated
     along ``dim=1`` (``latent0`` first). The topology tensors (``edges``,
     ``node_idx``, ``edge_idx``) come from ``data``. ``self.core`` returns an
     ``(e, x, g)`` triple, which is rewrapped into a ``GraphBatch`` that keeps
     the same topology.
     """
     cat_e = torch.cat([latent0.e, data.e], dim=1)
     cat_x = torch.cat([latent0.x, data.x], dim=1)
     cat_g = torch.cat([latent0.g, data.g], dim=1)
     joined = GraphBatch(cat_x, cat_e, cat_g,
                         data.edges, data.node_idx, data.edge_idx)
     out_e, out_x, out_g = self.core(joined)
     return GraphBatch(out_x, out_e, out_g,
                       joined.edges, joined.node_idx, joined.edge_idx)
Example #2
0
 def _forward_core(self, latent0, data):
     """Concatenate ``latent0`` onto ``data`` feature-wise and apply ``self.core``.

     The ``e``, ``x`` and ``g`` attributes are joined along ``dim=1`` with
     ``latent0`` first. The ``(e, x, g)`` triple produced by ``self.core`` is
     returned as a ``GraphBatch`` carrying ``data``'s
     ``edges``/``node_idx``/``edge_idx``.
     """
     stacked = {
         attr: torch.cat([getattr(latent0, attr), getattr(data, attr)], dim=1)
         for attr in ("e", "x", "g")
     }
     merged = GraphBatch(stacked["x"], stacked["e"], stacked["g"],
                         data.edges, data.node_idx, data.edge_idx)
     core_e, core_x, core_g = self.core(merged)
     return GraphBatch(core_x, core_e, core_g,
                       merged.edges, merged.node_idx, merged.edge_idx)
Example #3
0
def to(batch, device, **kwargs):
    """Return a new ``GraphBatch`` with every tensor of ``batch`` moved via ``Tensor.to``.

    :param batch: graph batch exposing ``x``, ``e``, ``g``, ``edges``,
        ``node_idx`` and ``edge_idx`` tensors.
    :param device: first positional argument forwarded to ``Tensor.to``
        (typically the target device).
    :param kwargs: extra keyword arguments forwarded to ``Tensor.to``
        (e.g. ``non_blocking``). A ``dtype`` keyword is applied only to the
        attribute tensors ``x``, ``e`` and ``g``. The topology tensors
        ``edges``, ``node_idx`` and ``edge_idx`` keep their own dtype.
    :return: a new ``GraphBatch`` with the moved tensors.
    """
    # The topology tensors are integer (torch.long) index tensors elsewhere
    # in this module. Casting them to e.g. a floating dtype would corrupt
    # the graph structure, so ``dtype`` is not forwarded to them.
    index_kwargs = {k: v for k, v in kwargs.items() if k != "dtype"}
    return GraphBatch(
        batch.x.to(device, **kwargs),
        batch.e.to(device, **kwargs),
        batch.g.to(device, **kwargs),
        batch.edges.to(device, **index_kwargs),
        batch.node_idx.to(device, **index_kwargs),
        batch.edge_idx.to(device, **index_kwargs),
    )
Example #4
0
    def forward(self, data, steps, save_all: bool = False):
        """Encode ``data`` and run ``steps`` rounds of core -> decode -> transform.

        The encoded graph is kept as ``latent0``. At every step the current
        latent ``e``/``x``/``g`` attributes are concatenated after
        ``latent0``'s along ``dim=1`` and passed to ``self.core``. The
        resulting latent is decoded with ``self.decoder`` and then passed
        through ``self.output_transform``. The graph topology (``edges``,
        ``node_idx``, ``edge_idx``) is reused unchanged for every step.

        :param data: input graph batch exposing ``edges``, ``node_idx`` and
            ``edge_idx``.
        :param steps: number of processing steps to run.
        :param save_all: if True, return the transformed output of every
            step; otherwise return only the output of the last step.
        :return: list of ``GraphBatch`` outputs (empty when ``steps`` is 0).
        """
        # encode
        e, x, g = self.encoder(data)
        data = GraphBatch(x, e, g, data.edges, data.node_idx, data.edge_idx)

        # graph topology: shared by every intermediate batch below
        meta = (data.edges, data.node_idx, data.edge_idx)
        latent0 = data

        outputs = []
        for _ in range(steps):
            # core processing step, conditioned on the initial encoding
            e = torch.cat([latent0.e, e], dim=1)
            x = torch.cat([latent0.x, x], dim=1)
            g = torch.cat([latent0.g, g], dim=1)
            e, x, g = self.core(GraphBatch(x, e, g, *meta))

            # decode
            _e, _x, _g = self.decoder(GraphBatch(x, e, g, *meta))
            decoded = GraphBatch(_x, _e, _g, *meta)

            # transform
            _e, _x, _g = self.output_transform(decoded)
            gt = GraphBatch(_x, _e, _g, *meta)
            if save_all:
                outputs.append(gt)
            else:
                # only the most recent step is kept
                outputs = [gt]

        return outputs
Example #5
0
 def out_transform(self, data: GraphBatch) -> GraphBatch:
     """Apply ``self.output_transform`` and rewrap its result.

     ``self.output_transform`` returns an ``(e, x, g)`` triple. The triple is
     reassembled into a ``GraphBatch`` that reuses ``data``'s topology
     tensors.
     """
     edge_out, node_out, glob_out = self.output_transform(data)
     topology = (data.edges, data.node_idx, data.edge_idx)
     return GraphBatch(node_out, edge_out, glob_out, *topology)
Example #6
0
 def decode(self, data: GraphBatch) -> GraphBatch:
     """Run ``self.decoder`` on ``data`` and return the result as a ``GraphBatch``.

     The decoder yields ``(e, x, g)``. The returned batch keeps ``data``'s
     ``edges``, ``node_idx`` and ``edge_idx``.
     """
     decoded_e, decoded_x, decoded_g = self.decoder(data)
     return GraphBatch(
         decoded_x,
         decoded_e,
         decoded_g,
         data.edges,
         data.node_idx,
         data.edge_idx,
     )
Example #7
0
def add_edges(data: GraphData, fill_value: Union[float, int],
              kind: str) -> GraphData:
    """Adds edges to the :class:`caldera.data.GraphData`.

    Works per graph. For a batch, graphs are identified by ``node_idx`` /
    ``edge_idx``; a plain ``GraphData`` is treated as a single graph 0.
    Each graph's edges are assumed to be stored contiguously, in graph-index
    order.

    :param data: :class:`caldera.data.GraphData` or
        :class:`caldera.data.GraphBatch` instance.
    :param fill_value: fill value for the attribute rows of the new edges.
        The rows are created with ``data.e``'s dtype and device.
    :param kind: Choose from "self" (for self edges), "complete" (for
        complete graph) or "undirected" (reverse of the existing edges).
    :return: a new ``GraphBatch`` if ``data`` is a batch, otherwise a new
        ``data.__class__`` instance. Node and global attributes are detached
        clones. Edges, edge attributes and ``edge_idx`` are reordered by
        graph index; the ordering within a graph is not guaranteed because
        ``argsort`` is not stable.
    :raises ValueError: if ``kind`` is not a valid choice, or ``data`` is
        neither a ``GraphBatch`` nor a ``GraphData``.
    """
    UNDIRECTED = "undirected"
    COMPLETE = "complete"
    SELF = "self"
    VALID_KIND = [UNDIRECTED, COMPLETE, SELF]
    if kind not in VALID_KIND:
        raise ValueError("'kind' must be one of {}".format(VALID_KIND))

    data_cls = data.__class__
    if issubclass(data_cls, GraphBatch):
        node_idx = data.node_idx
        edge_idx = data.edge_idx
    elif issubclass(data_cls, GraphData):
        # a single graph: every node and every edge belongs to graph 0
        node_idx = torch.zeros(data.x.shape[0], dtype=torch.long,
                               device=data.x.device)
        edge_idx = torch.zeros(data.e.shape[0], dtype=torch.long,
                               device=data.e.device)
    else:
        # ``X.__name__`` is the class name; ``X.__class__.__name__`` would be
        # the metaclass name ("type").
        raise ValueError("data must be a subclass of {} or {}".format(
            GraphBatch.__name__, GraphData.__name__))

    with torch.no_grad():
        # we count the number of nodes in each graph using node_idx
        gidx, n_nodes = torch.unique(node_idx, return_counts=True, sorted=True)
        # Count edges per graph aligned to ``gidx``. ``torch.unique`` on
        # ``edge_idx`` would omit graphs that have no edges, which would
        # misalign every following graph (and leave the loop below empty
        # for a single edgeless graph).
        n_graphs = int(gidx.max()) + 1 if gidx.numel() else 0
        n_edges = torch.bincount(edge_idx, minlength=n_graphs)[gidx]

        eidx = 0
        nidx = 0
        graph_edges_list = []
        new_edges_lengths = torch.zeros(gidx.shape[0], dtype=torch.long,
                                        device=gidx.device)

        for pos, (_n_nodes, _n_edges) in enumerate(zip(n_nodes, n_edges)):
            graph_edges = data.edges[:, eidx:eidx + _n_edges]
            if kind == UNDIRECTED:
                missing_edges = edges_difference(graph_edges.flip(0),
                                                 graph_edges)
            elif kind == COMPLETE:
                all_graph_edges = _create_all_edges(nidx, _n_nodes)
                missing_edges = edges_difference(all_graph_edges, graph_edges)
            else:  # kind == SELF
                self_edges = torch.cat(
                    [torch.arange(nidx, nidx + _n_nodes,
                                  device=data.edges.device).expand(1, -1)] * 2)
                missing_edges = edges_difference(self_edges, graph_edges)
            graph_edges_list.append(missing_edges)

            # Index by position (not by graph id) so the lengths stay paired
            # with ``gidx`` in ``repeat_interleave`` below.
            if not missing_edges.shape[0]:
                new_edges_lengths[pos] = 0
            else:
                new_edges_lengths[pos] = missing_edges.shape[1]

            nidx += _n_nodes
            eidx += _n_edges

        if graph_edges_list:
            new_edges = torch.cat(graph_edges_list, dim=1)
        else:
            # no graphs at all: nothing to add
            new_edges = data.edges.new_empty((data.edges.shape[0], 0))
        new_edge_idx = gidx.repeat_interleave(new_edges_lengths)
        new_edge_attr = torch.full((new_edges.shape[1], data.e.shape[1]),
                                   fill_value=fill_value,
                                   dtype=data.e.dtype,
                                   device=data.e.device)

        edges = torch.cat([data.edges, new_edges], dim=1)
        edge_idx = torch.cat([edge_idx, new_edge_idx])
        edge_attr = torch.cat([data.e, new_edge_attr], dim=0)

        # regroup the appended edges with the rest of their graph
        idx = edge_idx.argsort()

    if issubclass(data_cls, GraphBatch):
        return GraphBatch(
            node_attr=data.x.detach().clone(),
            edge_attr=edge_attr[idx],
            global_attr=data.g.detach().clone(),
            edges=edges[:, idx],
            node_idx=data.node_idx.detach().clone(),
            edge_idx=edge_idx[idx],
        )
    elif issubclass(data_cls, GraphData):
        return data_cls(
            node_attr=data.x.detach().clone(),
            edge_attr=edge_attr[idx],
            global_attr=data.g.detach().clone(),
            edges=edges[:, idx],
        )
Example #8
0
 def _forward_out(self, data):
     """Pass ``data`` through ``self.output_transform`` and rebuild a ``GraphBatch``.

     The transform yields ``(e, x, g)``. The topology tensors are taken from
     ``data`` unchanged.
     """
     transformed = self.output_transform(data)
     edge_attr, node_attr, glob_attr = transformed
     return GraphBatch(
         node_attr, edge_attr, glob_attr,
         data.edges, data.node_idx, data.edge_idx,
     )
Example #9
0
 def _forward_decode(self, data):
     """Decode ``data`` with ``self.decoder``, keeping its graph topology.

     ``self.decoder`` yields ``(e, x, g)``. These become the attributes of a
     new ``GraphBatch`` sharing ``data``'s ``edges``/``node_idx``/``edge_idx``.
     """
     e_dec, x_dec, g_dec = self.decoder(data)
     edges, node_idx, edge_idx = data.edges, data.node_idx, data.edge_idx
     return GraphBatch(x_dec, e_dec, g_dec, edges, node_idx, edge_idx)