Пример #1
0
    def __call__(self, data: HeteroData) -> HeteroData:
        """Add one composite edge type per configured metapath to ``data``.

        For every metapath in ``self.metapaths`` the sparse adjacency matrices
        of its edge types are multiplied in order. The coordinates of the
        product become the ``edge_index`` of a new edge type whose key is
        ``(first element of the first edge type, 'metapath_<j>', last element
        of the last edge type)``, where ``j`` is the metapath's position in
        ``self.metapaths``. The metapath behind each new edge type is recorded
        in ``data.metapath_dict``.

        If ``self.drop_orig_edges`` is set, the edge types present on entry
        are deleted afterwards, except that edge types whose first and last
        elements are equal are kept when ``self.keep_same_node_type`` is set.

        Args:
            data: Heterogeneous graph to transform; it is modified in place.

        Returns:
            The same ``data`` object that was passed in.

        Raises:
            AssertionError: If an edge type of a metapath is not among the
                edge types ``data`` had on entry.
        """
        # Snapshot taken before any metapath edge type is added, so the
        # optional clean-up below only ever touches pre-existing edge types.
        original_edge_types = data.edge_types
        data.metapath_dict = {}

        def build_adjacency(relation):
            # Sparse adjacency matrix of a single edge type.
            store = data[relation]
            return SparseTensor.from_edge_index(
                edge_index=store.edge_index,
                sparse_sizes=store.size())

        for j, metapath in enumerate(self.metapaths):
            # Every hop must refer to an edge type that existed on entry.
            for edge_type in metapath:
                assert data._to_canonical(
                    edge_type) in original_edge_types, f"'{edge_type}' not present"

            # Chain the hops left to right via sparse matrix products.
            composed = build_adjacency(metapath[0])
            for relation in metapath[1:]:
                composed = composed @ build_adjacency(relation)

            # Only the coordinates of the product are kept; values are ignored.
            row, col, _ = composed.coo()
            metapath_key = (metapath[0][0], f'metapath_{j}', metapath[-1][-1])
            data[metapath_key].edge_index = torch.vstack((row, col))
            data.metapath_dict[metapath_key] = metapath

        if self.drop_orig_edges:
            for original in original_edge_types:
                same_endpoints = original[0] == original[-1]
                if not (self.keep_same_node_type and same_endpoints):
                    del data[original]

        return data
Пример #2
0
    def __call__(self, data: HeteroData) -> HeteroData:
        """Add one composite edge type per configured metapath to ``data``.

        For every metapath in ``self.metapaths`` the sparse adjacency matrices
        of its edge types are multiplied in order. Each matrix carries the
        value returned by ``self._get_edge_weight`` as its edge attribute.
        When ``self.max_sample`` is not ``None``, ``self.sample_adj`` is
        applied to the first matrix and to every intermediate product
        (presumably to bound their size -- confirm against ``sample_adj``).

        The coordinates of the final product become the ``edge_index`` of a
        new edge type whose key is ``(first element of the first edge type,
        'metapath_<j>', last element of the last edge type)``; its values are
        stored as ``edge_weight`` when ``self.weighted`` is set. The metapath
        behind each new edge type is recorded in ``data.metapath_dict``.

        Optional clean-up afterwards:

        * ``self.drop_orig_edges``: delete the edge types present on entry,
          keeping those whose first and last elements are equal when
          ``self.keep_same_node_type`` is set.
        * ``self.drop_unconnected_nodes``: delete every node type that is not
          an endpoint of any remaining edge type.

        Args:
            data: Heterogeneous graph to transform; it is modified in place.

        Returns:
            The same ``data`` object that was passed in.

        Raises:
            AssertionError: If an edge type of a metapath is not among the
                edge types ``data`` had on entry.
        """
        # Snapshot taken before any metapath edge type is added, so the
        # optional clean-up below only ever touches pre-existing edge types.
        original_edge_types = data.edge_types
        data.metapath_dict = {}

        def build_adjacency(relation):
            # Sparse adjacency matrix of a single edge type, with the weight
            # looked up first and attached as the edge attribute.
            weight = self._get_edge_weight(data, relation)
            store = data[relation]
            return SparseTensor.from_edge_index(
                edge_index=store.edge_index,
                sparse_sizes=store.size(), edge_attr=weight)

        def maybe_sample(adj):
            # Sampling is applied only when a sample limit is configured.
            if self.max_sample is None:
                return adj
            return self.sample_adj(adj)

        for j, metapath in enumerate(self.metapaths):
            # Every hop must refer to an edge type that existed on entry.
            for edge_type in metapath:
                assert data._to_canonical(
                    edge_type) in original_edge_types, f"'{edge_type}' not present"

            # Chain the hops left to right, sampling after every step.
            composed = maybe_sample(build_adjacency(metapath[0]))
            for relation in metapath[1:]:
                composed = maybe_sample(composed @ build_adjacency(relation))

            row, col, values = composed.coo()
            metapath_key = (metapath[0][0], f'metapath_{j}', metapath[-1][-1])
            data[metapath_key].edge_index = torch.vstack((row, col))
            if self.weighted:
                data[metapath_key].edge_weight = values
            data.metapath_dict[metapath_key] = metapath

        if self.drop_orig_edges:
            for original in original_edge_types:
                same_endpoints = original[0] == original[-1]
                if not (self.keep_same_node_type and same_endpoints):
                    del data[original]

        # Remove node types that no remaining edge type touches.
        if self.drop_unconnected_nodes:
            connected = {
                endpoint
                for remaining in data.edge_types
                for endpoint in (remaining[0], remaining[-1])
            }
            # Snapshot of the node types before any of them is deleted.
            candidates = data.node_types
            for node_type in candidates:
                if node_type not in connected:
                    del data[node_type]

        return data