Example #1
def graph_random_splits(dataset: InMemoryDataset,
                        train_ratio: float = 0.2,
                        val_ratio: float = 0.4,
                        seed: _typing.Optional[int] = ...):
    r"""Splitting graph dataset with specific ratio for train/val/test.

    Parameters
    ----------
    dataset: ``InMemoryStaticGraphSet``

    train_ratio : float
        the portion of data that used for training.

    val_ratio : float
        the portion of data that used for validation.

    seed : int
        random seed for splitting dataset.
    """
    _rng_state = torch.get_rng_state()
    if isinstance(seed, int):
        torch.manual_seed(seed)
    perm = torch.randperm(len(dataset))
    train_index = perm[:int(len(dataset) * train_ratio)]
    val_index = (
        perm[int(len(dataset) *
                 train_ratio):int(len(dataset) * (train_ratio + val_ratio))])
    test_index = perm[int(len(dataset) * (train_ratio + val_ratio)):]
    dataset.train_index = train_index.tolist()
    dataset.val_index = val_index.tolist()
    dataset.test_index = test_index.tolist()
    torch.set_rng_state(_rng_state)
    return dataset
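
A minimal usage sketch (hedged: the `dataset` variable is an assumed, already-loaded InMemoryDataset; the ratios and seed are illustrative):

# Hypothetical usage: 20% train, 40% validation, remaining 40% test,
# with a fixed seed so the split is reproducible.
dataset = graph_random_splits(dataset, train_ratio=0.2, val_ratio=0.4, seed=42)
print(len(dataset.train_index), len(dataset.val_index), len(dataset.test_index))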
Example #2
def graph_cross_validation(dataset: InMemoryDataset,
                           n_splits: int = 10,
                           shuffle: bool = True,
                           random_seed: _typing.Optional[int] = ...,
                           stratify: bool = False) -> InMemoryDataset:
    r"""Cross validation for graph classification data

    Parameters
    ----------
    dataset : InMemoryDataset
        dataset with multiple graphs.

    n_splits : int
        the number of folds to split.

    shuffle : bool
        whether to shuffle before splitting, passed to
        sklearn.model_selection.KFold / StratifiedKFold.

    random_seed : int
        random_state passed to sklearn.model_selection.KFold / StratifiedKFold.

    stratify : bool
        whether to use StratifiedKFold (True) or plain KFold (False).
    """
    if not isinstance(n_splits, int):
        raise TypeError
    elif not n_splits > 0:
        raise ValueError
    if not isinstance(shuffle, bool):
        raise TypeError
    if not (random_seed in (Ellipsis, None) or isinstance(random_seed, int)):
        raise TypeError
    elif isinstance(random_seed, int) and random_seed >= 0:
        _random_seed: int = random_seed
    else:
        _random_seed: int = random.randrange(0, 65536)
    if not isinstance(stratify, bool):
        raise TypeError

    # sklearn raises an error when random_state is set while shuffle is False,
    # so only forward the seed when shuffling is enabled.
    _random_state: _typing.Optional[int] = _random_seed if shuffle else None
    if stratify:
        kf = StratifiedKFold(n_splits=n_splits,
                             shuffle=shuffle,
                             random_state=_random_state)
    else:
        kf = KFold(n_splits=n_splits,
                   shuffle=shuffle,
                   random_state=_random_state)
    dataset_y = [
        g.data['y' if 'y' in g.data else 'label'].item() for g in dataset
    ]
    idx_list = [(train_index.tolist(), test_index.tolist())
                for train_index, test_index in kf.split(
                    np.zeros(len(dataset)), np.array(dataset_y))]

    dataset.folds = idx_list
    dataset.train_index = idx_list[0][0]
    dataset.val_index = idx_list[0][1]
    return dataset
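
A usage sketch, assuming `dataset` is a multi-graph InMemoryDataset whose items carry a 'y' or 'label' entry (the variable name and parameters are illustrative):

# Hypothetical usage: 10-fold stratified cross validation with a fixed seed.
dataset = graph_cross_validation(dataset, n_splits=10, shuffle=True,
                                 random_seed=0, stratify=True)
# dataset.folds holds all (train, validation) index pairs; fold 0 is active.
print(len(dataset.folds), len(dataset.train_index), len(dataset.val_index))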
Example #3
def to_dgl_dataset(
    dataset: _typing.Union[Dataset, _typing.Iterable[GeneralStaticGraph]]
) -> Dataset[_typing.Union[dgl.DGLGraph, _typing.Tuple[dgl.DGLGraph,
                                                       torch.Tensor]]]:
    def _transform(
        general_static_graph: GeneralStaticGraph
    ) -> _typing.Union[dgl.DGLGraph, _typing.Tuple[dgl.DGLGraph,
                                                   torch.Tensor]]:
        if not isinstance(general_static_graph, GeneralStaticGraph):
            raise TypeError
        if 'label' in general_static_graph.data:
            label: _typing.Optional[
                torch.Tensor] = general_static_graph.data['label']
        elif 'y' in general_static_graph.data:
            label: _typing.Optional[
                torch.Tensor] = general_static_graph.data['y']
        else:
            label: _typing.Optional[torch.Tensor] = None
        # torch.is_tensor already covers the isinstance check on torch.Tensor
        if label is not None and torch.is_tensor(label):
            return conversion.general_static_graph_to_dgl_graph(
                general_static_graph), label
        else:
            return conversion.general_static_graph_to_dgl_graph(
                general_static_graph)

    transformed_datalist: _typing.MutableSequence[_typing.Union[
        dgl.DGLGraph, _typing.Tuple[dgl.DGLGraph, torch.Tensor]]] = []
    for item in dataset:
        if isinstance(item, GeneralStaticGraph):
            transformed_datalist.append(_transform(item))
        elif isinstance(item, dgl.DGLGraph):
            transformed_datalist.append(item)
        elif (isinstance(item, _typing.Sequence) and len(item) == 2
              and isinstance(item[0], dgl.DGLGraph)
              and isinstance(item[1], torch.Tensor)):
            transformed_datalist.append(tuple(item))
        else:
            raise ValueError(f"Illegal data item as {item}")

    return (InMemoryDataset(transformed_datalist, dataset.train_index,
                            dataset.val_index, dataset.test_index,
                            dataset.schema) if isinstance(
                                dataset, InMemoryDataset) else
            InMemoryDataset(transformed_datalist))
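
A brief conversion sketch, assuming the DGL backend and an iterable of GeneralStaticGraph instances (names are illustrative):

# Hypothetical usage: graphs carrying a 'label'/'y' entry become
# (dgl.DGLGraph, torch.Tensor) pairs, the rest plain DGLGraph objects.
dgl_dataset = to_dgl_dataset(dataset)
first_item = next(iter(dgl_dataset))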
Example #4
def split_edges(dataset: _typing.Iterable,
                train_ratio: float,
                val_ratio: _typing.Optional[float] = ...) -> Dataset:
    # train_ratio must always lie in (0, 1); when val_ratio is given,
    # the combined ratio must as well.
    if not 0 < train_ratio < 1:
        raise ValueError
    if isinstance(val_ratio, float) and not 0 < train_ratio + val_ratio < 1:
        raise ValueError
    if (autogl.backend.DependentBackend.is_pyg()
            and not (isinstance(val_ratio, float) and 0 < val_ratio < 1)):
        raise ValueError(
            "For PyG as backend, val_ratio MUST be specific float between 0 and 1, "
            "i.e. 0 < val_ratio < 1")
    transformed_data_list = [
        split_edges_for_data(item, train_ratio, val_ratio) for item in dataset
    ]
    return (InMemoryDataset(transformed_data_list, dataset.train_index,
                            dataset.val_index, dataset.test_index,
                            dataset.schema)
            if isinstance(dataset, Dataset) else
            InMemoryDataset(transformed_data_list))
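
A minimal sketch of an edge split for link prediction (ratios are illustrative; with PyG as backend val_ratio must be a float strictly between 0 and 1):

# Hypothetical usage: 85% of edges for training, 5% for validation,
# leaving the remaining 10% for testing.
edge_split_dataset = split_edges(dataset, train_ratio=0.85, val_ratio=0.05)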
Example #5
def to_pyg_dataset(
    dataset: _typing.Union[Dataset, _typing.Iterable[GeneralStaticGraph]]
) -> Dataset[torch_geometric.data.Data]:
    transformed_datalist: _typing.MutableSequence[
        torch_geometric.data.Data] = []
    for item in dataset:
        if isinstance(item, torch_geometric.data.Data):
            transformed_datalist.append(item)
        elif isinstance(item, GeneralStaticGraph):
            transformed_datalist.append(
                conversion.static_graph_to_pyg_data(item))
        elif (isinstance(item, _typing.Mapping)
              and all(isinstance(k, str) and isinstance(v, torch.Tensor)
                      for k, v in item.items())):
            transformed_datalist.append(torch_geometric.data.Data(**item))
        else:
            raise NotImplementedError(
                f"Unsupported data item {type(item)}<{item}> to convert as "
                f"{torch_geometric.data.Data}")
    return (InMemoryDataset(transformed_datalist, dataset.train_index,
                            dataset.val_index, dataset.test_index,
                            dataset.schema) if isinstance(
                                dataset, InMemoryDataset) else
            InMemoryDataset(transformed_datalist))
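
A usage sketch for the PyG-side counterpart, assuming torch_geometric is installed and `dataset` is iterable (names are illustrative):

# Hypothetical usage: every item is converted to torch_geometric.data.Data;
# existing train/val/test indices are carried over for InMemoryDataset inputs.
pyg_dataset = to_pyg_dataset(dataset)
first_data = next(iter(pyg_dataset))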
Example #6
def set_fold(dataset: InMemoryDataset, fold_id: int) -> InMemoryDataset:
    r"""Set fold for graph dataset consist of multiple graphs.

    Parameters
    ----------
    dataset: `autogl.data.InMemoryDataset`
        dataset with multiple graphs.
    fold_id: `int`
        The fold in to use, MUST be in [0, dataset.n_splits)

    Returns
    -------
    `autogl.data.InMemoryDataset`
        The reference of original dataset.
    """
    if not (hasattr(dataset, 'folds') and dataset.folds is not None):
        raise ValueError("Dataset does NOT contain folds")
    if not 0 <= fold_id < len(dataset.folds):
        raise ValueError(
            f"Fold id {fold_id} exceeds the total number of "
            f"cross validation splits {len(dataset.folds)}"
        )
    dataset.train_index = dataset.folds[fold_id].train_index
    dataset.val_index = dataset.folds[fold_id].val_index
    return dataset
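
A sketch of iterating over the folds produced by graph_cross_validation in Example #2 (the training step is a placeholder, not part of the API):

# Hypothetical usage: activate each fold in turn before training/evaluation.
for fold_id in range(len(dataset.folds)):
    dataset = set_fold(dataset, fold_id)
    # dataset.train_index / dataset.val_index now point at this fold
    # ... train and evaluate a model on the selected fold here ...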
Example #7
def graph_get_split(
    dataset,
    mask: str = "train",
    is_loader: bool = True,
    batch_size: int = 128,
    num_workers: int = 0,
    shuffle: bool = False
) -> _typing.Union[torch.utils.data.DataLoader, _typing.Iterable]:
    r"""Get train/test dataset/dataloader after cross validation.

    Parameters
    ----------
    dataset:
        dataset with multiple graphs.

    mask : str

    is_loader : bool
        return original dataset or data loader

    batch_size : int
        batch_size for generating Dataloader
    num_workers : int
        number of workers parameter for data loader
    shuffle: bool
        whether shuffle the dataloader
    """
    if not isinstance(mask, str):
        raise TypeError
    elif mask.lower() not in ("train", "val", "test"):
        raise ValueError
    if not isinstance(is_loader, bool):
        raise TypeError
    if not isinstance(batch_size, int):
        raise TypeError
    elif not batch_size > 0:
        raise ValueError
    if not isinstance(num_workers, int):
        raise TypeError
    elif not num_workers >= 0:
        raise ValueError

    if mask.lower() == "train":
        optional_dataset_split = dataset.train_split
        if optional_dataset_split is None:
            raise ValueError(f"Provided dataset do NOT have {mask} split")
        else:
            sub_dataset = InMemoryDataset(
                optional_dataset_split,
                train_index=list(range(len(optional_dataset_split))))
    elif mask.lower() == "val":
        optional_dataset_split = dataset.val_split
        if optional_dataset_split is None:
            raise ValueError(f"Provided dataset do NOT have {mask} split")
        else:
            sub_dataset = InMemoryDataset(
                optional_dataset_split,
                val_index=list(range(len(optional_dataset_split))))
    elif mask.lower() == "test":
        optional_dataset_split = dataset.test_split
        if optional_dataset_split is None:
            raise ValueError(f"Provided dataset do NOT have {mask} split")
        else:
            sub_dataset = InMemoryDataset(
                optional_dataset_split,
                test_index=list(range(len(optional_dataset_split))))
    else:
        raise ValueError(
            f"The provided mask parameter must be a str in ['train', 'val', 'test'], "
            f"illegal provided value is [{mask}]")
    if not is_loader:
        return sub_dataset
    if not (_backend.DependentBackend.is_dgl()
            or _backend.DependentBackend.is_pyg()):
        raise RuntimeError("Unsupported backend")
    elif _backend.DependentBackend.is_dgl():
        try:
            # newer DGL releases expose GraphDataLoader under dgl.dataloading
            from dgl.dataloading import GraphDataLoader
        except ImportError:
            # older releases keep it under dgl.dataloading.pytorch
            from dgl.dataloading.pytorch import GraphDataLoader
        return GraphDataLoader(sub_dataset,
                               batch_size=batch_size,
                               num_workers=num_workers,
                               shuffle=shuffle)
    elif _backend.DependentBackend.is_pyg():
        _sub_dataset: _typing.Any = optional_dataset_split
        import torch_geometric
        if int(torch_geometric.__version__.split('.')[0]) >= 2:
            # torch_geometric 2.x moved DataLoader to torch_geometric.loader
            from torch_geometric.loader import DataLoader
        else:
            from torch_geometric.data import DataLoader
        return DataLoader(_sub_dataset,
                          batch_size=batch_size,
                          num_workers=num_workers,
                          shuffle=shuffle)
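
A usage sketch, assuming the dataset already carries a train split and either the DGL or PyG backend is available (batch size and shuffling are illustrative):

# Hypothetical usage: fetch the training split as a shuffled DataLoader ...
train_loader = graph_get_split(dataset, mask="train", is_loader=True,
                               batch_size=32, shuffle=True)
# ... or as a plain sub-dataset without batching.
train_subset = graph_get_split(dataset, mask="train", is_loader=False)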