Ejemplo n.º 1
0
    def to_homogeneous(self,
                       node_attrs: Optional[List[str]] = None,
                       edge_attrs: Optional[List[str]] = None,
                       add_node_type: bool = True,
                       add_edge_type: bool = True) -> Data:
        """Converts a :class:`~torch_geometric.data.HeteroData` object to a
        homogeneous :class:`~torch_geometric.data.Data` object.
        By default, all features with same feature dimensionality across
        different types will be merged into a single representation, unless
        otherwise specified via the :obj:`node_attrs` and :obj:`edge_attrs`
        arguments.
        Furthermore, attributes named :obj:`node_type` and :obj:`edge_type`
        will be added to the returned :class:`~torch_geometric.data.Data`
        object, denoting node-level and edge-level vectors holding the
        node and edge type as integers, respectively.
        The integer assigned to a type is its position in the iteration
        order of the node (respectively edge) stores; the type names are
        stored in that same order in :obj:`_node_type_names` and
        :obj:`_edge_type_names` on the returned object.
        The :obj:`node_type` vector is created on the device of the first
        edge type's :obj:`edge_index` (default device if there are no edge
        types), so that it lives alongside :obj:`edge_type`.

        Args:
            node_attrs (List[str], optional): The node features to combine
                across all node types. These node features need to be of the
                same feature dimensionality. If set to :obj:`None`, will
                automatically determine which node features to combine.
                (default: :obj:`None`)
            edge_attrs (List[str], optional): The edge features to combine
                across all edge types. These edge features need to be of the
                same feature dimensionality. If set to :obj:`None`, will
                automatically determine which edge features to combine.
                (default: :obj:`None`)
            add_node_type (bool, optional): If set to :obj:`False`, will not
                add the node-level vector :obj:`node_type` to the returned
                :class:`~torch_geometric.data.Data` object.
                (default: :obj:`True`)
            add_edge_type (bool, optional): If set to :obj:`False`, will not
                add the edge-level vector :obj:`edge_type` to the returned
                :class:`~torch_geometric.data.Data` object.
                (default: :obj:`True`)

        Returns:
            The homogeneous :class:`~torch_geometric.data.Data` object.
        """
        def _consistent_size(stores: List[BaseStorage]) -> List[str]:
            # Returns the keys that hold a tensor in *every* store and whose
            # shapes agree once the concatenation dimension is left out.
            sizes_dict = defaultdict(list)
            for store in stores:
                for key, value in store.items():
                    if key in ['edge_index', 'adj_t']:
                        continue
                    if isinstance(value, Tensor):
                        dim = self.__cat_dim__(key, value, store)
                        # NOTE(review): for a negative `dim`, `dim + 1:` does
                        # not drop the concatenation dimension (e.g. `-1`
                        # yields `[0:]`) -- confirm whether `__cat_dim__` can
                        # return negative values here.
                        size = value.size()[:dim] + value.size()[dim + 1:]
                        sizes_dict[key].append(tuple(size))
            return [
                k for k, sizes in sizes_dict.items()
                if len(sizes) == len(stores) and len(set(sizes)) == 1
            ]

        def _merge(key: str, stores: List[BaseStorage]) -> Tensor:
            # Concatenates `store[key]` over all stores along the dimension
            # reported by `__cat_dim__`; a single value is returned as-is to
            # avoid a needless copy.
            values = [store[key] for store in stores]
            dim = self.__cat_dim__(key, values[0], stores[0])
            return torch.cat(values, dim) if len(values) > 1 else values[0]

        data = Data(**self._global_store.to_dict())

        # Place `node_type` on the same device as the edge-level tensors
        # instead of unconditionally on the default device:
        first_edge_store = next(iter(self._edge_store_dict.values()), None)
        device = (first_edge_store.edge_index.device
                  if first_edge_store is not None else None)

        # Iterate over all node stores and record the slice information:
        node_slices, total_nodes = {}, 0
        node_type_names, node_types = [], []
        for i, (node_type, store) in enumerate(self._node_store_dict.items()):
            num_nodes = store.num_nodes
            node_slices[node_type] = (total_nodes, total_nodes + num_nodes)
            node_type_names.append(node_type)
            total_nodes += num_nodes

            if add_node_type:
                node_types.append(
                    torch.full((num_nodes, ), i, dtype=torch.long,
                               device=device))
        data._node_type_names = node_type_names

        if len(node_types) > 1:
            data.node_type = torch.cat(node_types, dim=0)
        elif len(node_types) == 1:
            data.node_type = node_types[0]

        # Combine node attributes into a single tensor:
        if node_attrs is None:
            node_attrs = _consistent_size(self.node_stores)
        for key in node_attrs:
            data[key] = _merge(key, self.node_stores)

        # Store the node count explicitly when neither `node_type` nor any
        # merged node-level attribute was added to `data`:
        has_node_level_attr = any(
            key in {'x', 'pos', 'batch'} or 'node' in key
            for key in node_attrs)
        if not has_node_level_attr and not add_node_type:
            data.num_nodes = total_nodes

        # Iterate over all edge stores and shift their indices:
        edge_indices, edge_type_names, edge_types = [], [], []
        for i, (edge_type, store) in enumerate(self._edge_store_dict.items()):
            src, _, dst = edge_type
            edge_type_names.append(edge_type)

            kwargs = {'dtype': torch.long, 'device': store.edge_index.device}
            # 2x1 column holding the start offsets of the source and
            # destination node types; adding it moves both rows of
            # `edge_index` into the concatenated node ordering.
            offset = [[node_slices[src][0]], [node_slices[dst][0]]]
            offset = torch.tensor(offset, **kwargs)
            edge_indices.append(store.edge_index + offset)
            if add_edge_type:
                edge_types.append(torch.full((store.num_edges, ), i, **kwargs))
        data._edge_type_names = edge_type_names

        if len(edge_indices) > 1:
            data.edge_index = torch.cat(edge_indices, dim=-1)
        elif len(edge_indices) == 1:
            data.edge_index = edge_indices[0]

        if len(edge_types) > 1:
            data.edge_type = torch.cat(edge_types, dim=0)
        elif len(edge_types) == 1:
            data.edge_type = edge_types[0]

        # Combine edge attributes into a single tensor:
        if edge_attrs is None:
            edge_attrs = _consistent_size(self.edge_stores)
        for key in edge_attrs:
            data[key] = _merge(key, self.edge_stores)

        return data
Ejemplo n.º 2
0
    def to_homogeneous(self,
                       node_attrs: Optional[List[str]] = None,
                       edge_attrs: Optional[List[str]] = None,
                       add_node_type: bool = True,
                       add_edge_type: bool = True) -> Data:
        """Converts a :class:`~torch_geometric.data.HeteroData` object to a
        homogeneous :class:`~torch_geometric.data.Data` object.
        By default, all features with same feature dimensionality across
        different types will be merged into a single representation, unless
        otherwise specified via the :obj:`node_attrs` and :obj:`edge_attrs`
        arguments.
        Furthermore, attributes named :obj:`node_type` and :obj:`edge_type`
        will be added to the returned :class:`~torch_geometric.data.Data`
        object, denoting node-level and edge-level vectors holding the
        node and edge type as integers, respectively.
        The integer assigned to a type is its position among the keys of the
        node (respectively edge) slices; those keys are stored in the same
        order in :obj:`_node_type_names` and :obj:`_edge_type_names` on the
        returned object.
        Both type vectors are created on the device of the combined
        :obj:`edge_index` (default device if there is none), and
        :obj:`edge_type` is only added when an :obj:`edge_index` exists.

        Args:
            node_attrs (List[str], optional): The node features to combine
                across all node types. These node features need to be of the
                same feature dimensionality. If set to :obj:`None`, will
                automatically determine which node features to combine.
                (default: :obj:`None`)
            edge_attrs (List[str], optional): The edge features to combine
                across all edge types. These edge features need to be of the
                same feature dimensionality. If set to :obj:`None`, will
                automatically determine which edge features to combine.
                (default: :obj:`None`)
            add_node_type (bool, optional): If set to :obj:`False`, will not
                add the node-level vector :obj:`node_type` to the returned
                :class:`~torch_geometric.data.Data` object.
                (default: :obj:`True`)
            add_edge_type (bool, optional): If set to :obj:`False`, will not
                add the edge-level vector :obj:`edge_type` to the returned
                :class:`~torch_geometric.data.Data` object.
                (default: :obj:`True`)

        Returns:
            The homogeneous :class:`~torch_geometric.data.Data` object.
        """
        def _consistent_size(stores: List[BaseStorage]) -> List[str]:
            # Returns the keys that hold a tensor in *every* store and whose
            # shapes agree once the concatenation dimension is left out.
            sizes_dict = defaultdict(list)
            for store in stores:
                for key, value in store.items():
                    if key in ['edge_index', 'adj_t']:
                        continue
                    if isinstance(value, Tensor):
                        dim = self.__cat_dim__(key, value, store)
                        size = value.size()[:dim] + value.size()[dim + 1:]
                        sizes_dict[key].append(tuple(size))
            return [
                k for k, sizes in sizes_dict.items()
                if len(sizes) == len(stores) and len(set(sizes)) == 1
            ]

        def _merge(key: str, stores: List[BaseStorage]) -> Tensor:
            # Concatenates `store[key]` over all stores along the dimension
            # reported by `__cat_dim__`; a single value is returned as-is to
            # avoid a needless copy.
            values = [store[key] for store in stores]
            dim = self.__cat_dim__(key, values[0], stores[0])
            return torch.cat(values, dim) if len(values) > 1 else values[0]

        edge_index, node_slices, edge_slices = to_homogeneous_edge_index(self)
        device = edge_index.device if edge_index is not None else None

        def _type_vector(slices) -> Tensor:
            # Expands per-type `(start, end)` slices into one integer type
            # id per element: type `i` is repeated `end - start` times.
            sizes = [offset[1] - offset[0] for offset in slices.values()]
            sizes = torch.tensor(sizes, dtype=torch.long, device=device)
            type_ids = torch.arange(len(sizes), device=device)
            return type_ids.repeat_interleave(sizes)

        data = Data(**self._global_store.to_dict())
        if edge_index is not None:
            data.edge_index = edge_index
        data._node_type_names = list(node_slices.keys())
        data._edge_type_names = list(edge_slices.keys())

        # Combine node attributes into a single tensor:
        if node_attrs is None:
            node_attrs = _consistent_size(self.node_stores)
        for key in node_attrs:
            data[key] = _merge(key, self.node_stores)

        if not data.can_infer_num_nodes:
            # The end offset of the last slice is used as the node count.
            # Without any node type there is no slice to read, so fall back
            # to zero instead of raising an `IndexError`.
            slices = list(node_slices.values())
            data.num_nodes = slices[-1][1] if slices else 0

        # Combine edge attributes into a single tensor:
        if edge_attrs is None:
            edge_attrs = _consistent_size(self.edge_stores)
        for key in edge_attrs:
            data[key] = _merge(key, self.edge_stores)

        if add_node_type:
            data.node_type = _type_vector(node_slices)

        if add_edge_type and edge_index is not None:
            data.edge_type = _type_vector(edge_slices)

        return data