def graph_random_splits(
        dataset: InMemoryDataset,
        train_ratio: float = 0.2,
        val_ratio: float = 0.4,
        seed: _typing.Optional[int] = ...
):
    r"""Splitting graph dataset with specific ratio for train/val/test.

    Parameters
    ----------
    dataset : ``InMemoryDataset``
        dataset with multiple graphs to split.
    train_ratio : float
        the portion of data that used for training, in ``(0, 1)``.
    val_ratio : float
        the portion of data that used for validation, in ``(0, 1)``.
    seed : int
        random seed for splitting dataset; when not an int the current
        global RNG state is used.

    Returns
    -------
    the same dataset instance, with ``train_index``/``val_index``/``test_index``
    assigned in place.

    Raises
    ------
    ValueError
        if ratios are out of range or sum to more than 1.
    """
    # Reject ratios that would silently yield empty or overlapping splits.
    if not (0 < train_ratio < 1 and 0 < val_ratio < 1
            and train_ratio + val_ratio <= 1):
        raise ValueError(
            f"Illegal ratios: train_ratio={train_ratio}, val_ratio={val_ratio}; "
            "each must lie in (0, 1) and their sum must not exceed 1"
        )
    # Preserve the global RNG state so seeding here never leaks to callers,
    # even if an exception is raised while splitting.
    _rng_state = torch.get_rng_state()
    try:
        if isinstance(seed, int):
            torch.manual_seed(seed)
        perm = torch.randperm(len(dataset))
        num_train = int(len(dataset) * train_ratio)
        num_train_val = int(len(dataset) * (train_ratio + val_ratio))
        dataset.train_index = perm[:num_train].tolist()
        dataset.val_index = perm[num_train:num_train_val].tolist()
        dataset.test_index = perm[num_train_val:].tolist()
    finally:
        torch.set_rng_state(_rng_state)
    return dataset
def graph_cross_validation(
        dataset: InMemoryDataset,
        n_splits: int = 10,
        shuffle: bool = True,
        random_seed: _typing.Optional[int] = ...,
        stratify: bool = False
) -> InMemoryDataset:
    r"""Cross validation for graph classification data.

    Parameters
    ----------
    dataset : InMemoryDataset
        dataset with multiple graphs.
    n_splits : int
        the number of folds to split.
    shuffle : bool
        shuffle or not for sklearn.model_selection.StratifiedKFold
    random_seed : int
        random_state for sklearn.model_selection.StratifiedKFold;
        only honoured when ``shuffle`` is True (sklearn forbids setting
        ``random_state`` otherwise).
    stratify : bool
        whether to use stratified splitting on the graph labels.

    Returns
    -------
    InMemoryDataset
        the same dataset instance, with ``folds`` assigned and
        ``train_index``/``val_index`` set to the first fold.

    Raises
    ------
    TypeError, ValueError
        on malformed arguments.
    """
    if not isinstance(n_splits, int):
        raise TypeError(f"n_splits must be int, got {type(n_splits)}")
    elif not n_splits > 0:
        raise ValueError("n_splits must be a positive integer")
    if not isinstance(shuffle, bool):
        raise TypeError(f"shuffle must be bool, got {type(shuffle)}")
    if not (random_seed in (Ellipsis, None) or isinstance(random_seed, int)):
        raise TypeError(f"random_seed must be int, None or Ellipsis, got {type(random_seed)}")
    elif isinstance(random_seed, int) and random_seed >= 0:
        _random_seed: int = random_seed
    else:
        _random_seed: int = random.randrange(0, 65536)
    if not isinstance(stratify, bool):
        raise TypeError(f"stratify must be bool, got {type(stratify)}")
    # sklearn raises ValueError when random_state is set while shuffle=False,
    # so only pass the seed when shuffling is requested.
    _random_state: _typing.Optional[int] = _random_seed if shuffle else None
    if stratify:
        kf = StratifiedKFold(
            n_splits=n_splits, shuffle=shuffle, random_state=_random_state)
        # Labels are only required for stratified splitting; fetching them
        # unconditionally would crash unlabeled datasets used with plain KFold.
        dataset_y = np.array([
            g.data['y' if 'y' in g.data else 'label'].item() for g in dataset
        ])
    else:
        kf = KFold(n_splits=n_splits, shuffle=shuffle, random_state=_random_state)
        dataset_y = None  # ignored by KFold.split
    idx_list = [
        (train_index.tolist(), test_index.tolist())
        for train_index, test_index in kf.split(np.zeros(len(dataset)), dataset_y)
    ]
    dataset.folds = idx_list
    # Default to the first fold until the caller selects another via set_fold.
    dataset.train_index = idx_list[0][0]
    dataset.val_index = idx_list[0][1]
    return dataset
def to_dgl_dataset(
        dataset: _typing.Union[Dataset, _typing.Iterable[GeneralStaticGraph]]
) -> Dataset[_typing.Union[dgl.DGLGraph, _typing.Tuple[dgl.DGLGraph, torch.Tensor]]]:
    """Convert a dataset of general static graphs into an ``InMemoryDataset``
    of DGL graphs, pairing each graph with its label tensor when one is
    stored under the ``'label'`` or ``'y'`` key of the graph data.
    """
    def _convert(graph: GeneralStaticGraph) -> _typing.Union[
            dgl.DGLGraph, _typing.Tuple[dgl.DGLGraph, torch.Tensor]]:
        if not isinstance(graph, GeneralStaticGraph):
            raise TypeError
        # 'label' takes precedence over 'y' when both keys are present.
        if 'label' in graph.data:
            label: _typing.Optional[torch.Tensor] = graph.data['label']
        elif 'y' in graph.data:
            label: _typing.Optional[torch.Tensor] = graph.data['y']
        else:
            label: _typing.Optional[torch.Tensor] = None
        dgl_graph = conversion.general_static_graph_to_dgl_graph(graph)
        if label is not None and isinstance(label, torch.Tensor) and torch.is_tensor(label):
            return dgl_graph, label
        return dgl_graph

    converted: _typing.MutableSequence[_typing.Union[
        dgl.DGLGraph, _typing.Tuple[dgl.DGLGraph, torch.Tensor]]] = []
    for element in dataset:
        if isinstance(element, GeneralStaticGraph):
            converted.append(_convert(element))
        elif isinstance(element, dgl.DGLGraph):
            # Already a DGL graph: pass through untouched.
            converted.append(element)
        elif (isinstance(element, _typing.Sequence) and len(element) == 2
              and isinstance(element[0], dgl.DGLGraph)
              and isinstance(element[1], torch.Tensor)):
            # A (graph, label) pair: normalize to a tuple.
            converted.append(tuple(element))
        else:
            raise ValueError(f"Illegal data item as {element}")
    # Carry over split indices and schema when the input was an InMemoryDataset.
    if isinstance(dataset, InMemoryDataset):
        return InMemoryDataset(
            converted, dataset.train_index, dataset.val_index,
            dataset.test_index, dataset.schema)
    return InMemoryDataset(converted)
def split_edges(
        dataset: _typing.Iterable, train_ratio: float,
        val_ratio: _typing.Optional[float] = ...
) -> Dataset:
    """Split the edges of every graph in ``dataset`` into train/val(/test)
    subsets with the given ratios, returning a new ``InMemoryDataset``.

    For the PyG backend ``val_ratio`` must be an explicit float in (0, 1).
    Split indices and schema are carried over when the input is a ``Dataset``.
    """
    # Validate ratios: the combined ratio check only applies when val_ratio
    # is an explicit float; train_ratio must lie in (0, 1) regardless.
    if isinstance(val_ratio, float) and not 0 < train_ratio + val_ratio < 1:
        raise ValueError
    elif not 0 < train_ratio < 1:
        raise ValueError
    if (autogl.backend.DependentBackend.is_pyg()
            and not (isinstance(val_ratio, float) and 0 < val_ratio < 1)):
        raise ValueError(
            "For PyG as backend, val_ratio MUST be specific float between 0 and 1, "
            "i.e. 0 < val_ratio < 1")
    split_graphs = [
        split_edges_for_data(graph, train_ratio, val_ratio) for graph in dataset
    ]
    if isinstance(dataset, Dataset):
        return InMemoryDataset(
            split_graphs, dataset.train_index, dataset.val_index,
            dataset.test_index, dataset.schema)
    return InMemoryDataset(split_graphs)
def to_pyg_dataset(
        dataset: _typing.Union[Dataset, _typing.Iterable[GeneralStaticGraph]]
) -> Dataset[torch_geometric.data.Data]:
    """Convert a dataset of graphs into an ``InMemoryDataset`` of
    ``torch_geometric.data.Data`` instances.

    Accepted item forms: an existing ``Data`` object (passed through),
    a ``GeneralStaticGraph`` (converted), or a mapping of ``str`` keys to
    ``torch.Tensor`` values (unpacked into a new ``Data``).
    """
    def _as_pyg_data(element) -> torch_geometric.data.Data:
        if isinstance(element, torch_geometric.data.Data):
            return element
        if isinstance(element, GeneralStaticGraph):
            return conversion.static_graph_to_pyg_data(element)
        if (isinstance(element, _typing.Mapping)
                and all([(isinstance(k, str) and isinstance(v, torch.Tensor))
                         for k, v in element.items()])):
            return torch_geometric.data.Data(**element)
        raise NotImplementedError(
            f"Unsupported data item {type(element)}<{element}> to convert as "
            f"{torch_geometric.data.Data}")

    converted: _typing.MutableSequence[torch_geometric.data.Data] = [
        _as_pyg_data(element) for element in dataset
    ]
    # Carry over split indices and schema when the input was an InMemoryDataset.
    if isinstance(dataset, InMemoryDataset):
        return InMemoryDataset(
            converted, dataset.train_index, dataset.val_index,
            dataset.test_index, dataset.schema)
    return InMemoryDataset(converted)
def set_fold(dataset: InMemoryDataset, fold_id: int) -> InMemoryDataset:
    r"""Set fold for graph dataset consist of multiple graphs.

    Parameters
    ----------
    dataset: `autogl.data.InMemoryDataset`
        dataset with multiple graphs whose ``folds`` were assigned by
        cross-validation splitting (e.g. ``graph_cross_validation``).
    fold_id: `int`
        The fold in to use, MUST be in [0, dataset.n_splits)

    Returns
    -------
    `autogl.data.InMemoryDataset`
        The reference of original dataset.

    Raises
    ------
    ValueError
        if the dataset has no folds or ``fold_id`` is out of range.
    """
    if not (hasattr(dataset, 'folds') and dataset.folds is not None):
        raise ValueError("Dataset do NOT contain folds")
    if not 0 <= fold_id < len(dataset.folds):
        raise ValueError(
            f"Fold id {fold_id} exceed total cross validation split number {len(dataset.folds)}"
        )
    fold = dataset.folds[fold_id]
    # graph_cross_validation stores each fold as a (train_index, val_index)
    # tuple; also accept fold objects exposing train_index/val_index
    # attributes for backward compatibility.
    if hasattr(fold, 'train_index') and hasattr(fold, 'val_index'):
        dataset.train_index = fold.train_index
        dataset.val_index = fold.val_index
    else:
        dataset.train_index, dataset.val_index = fold[0], fold[1]
    return dataset
def graph_get_split(
        dataset, mask: str = "train",
        is_loader: bool = True, batch_size: int = 128,
        num_workers: int = 0, shuffle: bool = False
) -> _typing.Union[torch.utils.data.DataLoader, _typing.Iterable]:
    r"""Get train/val/test dataset/dataloader after cross validation.

    Parameters
    ----------
    dataset:
        dataset with multiple graphs, exposing ``train_split``/``val_split``/
        ``test_split`` attributes.
    mask : str
        which split to fetch, one of ``"train"``, ``"val"``, ``"test"``
        (case-insensitive).
    is_loader : bool
        return original dataset or data loader
    batch_size : int
        batch_size for generating Dataloader
    num_workers : int
        number of workers parameter for data loader
    shuffle: bool
        whether shuffle the dataloader

    Raises
    ------
    TypeError, ValueError
        on malformed arguments or a missing split.
    RuntimeError
        when neither DGL nor PyG backend is available.
    """
    if not isinstance(mask, str):
        raise TypeError(f"mask must be str, got {type(mask)}")
    _mask = mask.lower()
    # Single mask validation (the original code checked this twice).
    if _mask not in ("train", "val", "test"):
        raise ValueError(
            f"The provided mask parameter must be a str in ['train', 'val', 'test'], "
            f"illegal provided value is [{mask}]")
    if not isinstance(is_loader, bool):
        raise TypeError(f"is_loader must be bool, got {type(is_loader)}")
    if not isinstance(batch_size, int):
        raise TypeError(f"batch_size must be int, got {type(batch_size)}")
    elif not batch_size > 0:
        raise ValueError("batch_size must be a positive integer")
    if not isinstance(num_workers, int):
        raise TypeError(f"num_workers must be int, got {type(num_workers)}")
    elif not num_workers >= 0:
        raise ValueError("num_workers must be non-negative")
    # Fetch the requested split; the three per-mask branches collapse into
    # one attribute lookup plus one keyword-built InMemoryDataset.
    optional_dataset_split = getattr(dataset, f"{_mask}_split")
    if optional_dataset_split is None:
        raise ValueError(f"Provided dataset do NOT have {mask} split")
    _indices = list(range(len(optional_dataset_split)))
    sub_dataset = InMemoryDataset(
        optional_dataset_split, **{f"{_mask}_index": _indices})
    if not is_loader:
        return sub_dataset
    if _backend.DependentBackend.is_dgl():
        # DGL >= 0.8 exposes GraphDataLoader at dgl.dataloading; older
        # versions keep it under dgl.dataloading.pytorch.
        try:
            from dgl.dataloading import GraphDataLoader
        except ImportError:
            from dgl.dataloading.pytorch import GraphDataLoader
        return GraphDataLoader(
            sub_dataset, batch_size=batch_size,
            num_workers=num_workers, shuffle=shuffle)
    elif _backend.DependentBackend.is_pyg():
        # PyG loads the raw split directly (matches original behavior).
        _sub_dataset: _typing.Any = optional_dataset_split
        import torch_geometric
        if int(torch_geometric.__version__.split('.')[0]) >= 2:
            # version 2.x moved DataLoader to torch_geometric.loader
            from torch_geometric.loader import DataLoader
        else:
            from torch_geometric.data import DataLoader
        return DataLoader(
            _sub_dataset, batch_size=batch_size,
            num_workers=num_workers, shuffle=shuffle)
    raise RuntimeError("Unsupported backend")