Example #1
def available_datasets():
    return gf.BunchDict(reddit="reddit dataset")
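For reference, `gf.BunchDict` is graphgallery's dict variant whose keys are also reachable as attributes; a minimal sketch of consuming the returned registry (calling the function standalone is an assumption, since it is normally reached through its owning class):

from graphgallery import functional as gf

datasets = gf.BunchDict(reddit="reddit dataset")
# Item access and attribute access are assumed interchangeable on BunchDict.
assert datasets["reddit"] == datasets.reddit
print(list(datasets))  # ['reddit']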
Example #2
    def fit(self, train_data, val_data=None, **kwargs):

        cache = self.cache
        cfg = self.cfg.fit
        cfg.merge_from_dict(kwargs)
        ckpt_cfg = cfg.ModelCheckpoint
        es_cfg = cfg.EarlyStopping
        pb_cfg = cfg.Progbar
        log_cfg = cfg.Logger

        if log_cfg.enabled:
            log_cfg.name = log_cfg.name or self.name
            logger = gg.utils.setup_logger(output=log_cfg.filepath,
                                           name=log_cfg.name)

        model = self.model
        if model is None:
            raise RuntimeError(
                'You must compile your model before training/testing/predicting. Use `trainer.build()`.'
            )

        if not isinstance(train_data, (Sequence, DataLoader, Dataset)):
            train_data = self.train_loader(train_data)

        if cfg.cache_train_data:
            cache.train_data = train_data

        validation = val_data is not None
        if validation:
            if not isinstance(val_data, (Sequence, DataLoader, Dataset)):
                val_data = self.test_loader(val_data)
            if cfg.cache_val_data:
                cache.val_data = val_data

        # Setup callbacks
        callbacks = callbacks_module.CallbackList()
        history = History()
        callbacks.append(history)
        cfg, callbacks = setup_callbacks(cfg, callbacks, validation)
        callbacks.set_model(model)
        model.stop_training = False

        verbose = cfg.verbose
        assert not (
            verbose and log_cfg.enabled
        ), "Progbar and Logger cannot be used together! You must set `verbose=0` when Logger is enabled."

        if verbose:
            if verbose <= 2:
                progbar = Progbar(target=cfg.epochs,
                                  width=pb_cfg.width,
                                  verbose=verbose)
            print("Training...")
        elif log_cfg.enabled:
            logger.info("Training...")

        logs = gf.BunchDict()
        callbacks.on_train_begin()
        try:
            for epoch in range(cfg.epochs):
                if verbose > 2:
                    progbar = Progbar(target=len(train_data),
                                      width=pb_cfg.width,
                                      verbose=verbose - 2)

                callbacks.on_epoch_begin(epoch)
                callbacks.on_train_batch_begin(0)
                train_logs = self.train_step(train_data)
                if hasattr(train_data, 'on_epoch_end'):
                    train_data.on_epoch_end()
                logs.update(train_logs)

                if validation:
                    valid_logs = self.test_step(val_data)
                    logs.update({("val_" + k): v
                                 for k, v in valid_logs.items()})
                    if hasattr(val_data, 'on_epoch_end'):
                        val_data.on_epoch_end()

                callbacks.on_train_batch_end(len(train_data), logs)
                callbacks.on_epoch_end(epoch, logs)

                if verbose > 2:
                    print(f"Epoch {epoch+1}/{cfg.epochs}")
                    progbar.update(len(train_data), logs.items())
                elif verbose:
                    progbar.update(epoch + 1, logs.items())
                elif log_cfg.enabled:
                    logger.info(
                        f"Epoch {epoch+1}/{cfg.epochs}\n{gg.utils.create_table(logs)}"
                    )

                if model.stop_training:
                    if log_cfg.enabled:
                        logger.info(f"Early Stopping at Epoch {epoch}")
                    else:
                        print(f"Early Stopping at Epoch {epoch}",
                              file=sys.stderr)
                    break

            callbacks.on_train_end()
            if ckpt_cfg.enabled:
                if ckpt_cfg.save_weights_only:
                    model.load_weights(ckpt_cfg.path)
                else:
                    self.model = model.load(ckpt_cfg.path)
        finally:
            # runs even if training was interrupted, so checkpoint files are cleaned up
            if ckpt_cfg.enabled and ckpt_cfg.remove_weights:
                self.remove_weights()

        return history
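A hedged usage sketch of the method above; the `trainer` object and the data names are hypothetical, while `epochs` and `verbose` pass through as keyword overrides because `cfg.merge_from_dict(kwargs)` folds them into `cfg.fit`:

trainer.build()  # compile the model first, as the RuntimeError above requires
history = trainer.fit(train_data, val_data=val_data, epochs=100, verbose=1)
# `fit` returns the History callback; assuming the Keras-style History,
# the accumulated per-epoch logs live in its `history` dict.
print(history.history)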
Example #3
def test(self, index):
    index = gf.asarray(index)
    y_true = self.graph.node_label[index]
    y_pred = self.classifier.predict(self.embeddings[index])
    accuracy = accuracy_score(y_true, y_pred)
    return gf.BunchDict(loss=None, accuracy=accuracy)
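Because the metrics come back as a `gf.BunchDict`, they read naturally as attributes; a small sketch (the `model` and `test_index` names are assumptions):

result = model.test(test_index)
# `loss` is None here, since the downstream classifier has no loss to report.
print(f"accuracy = {result.accuracy:.4f}")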
Example #4
def empty_cache(self):
    self.cache = gf.BunchDict()
Example #5
import os
import os.path as osp

from typing import Optional, List
from graphgallery import functional as gf

from .in_memory_dataset import InMemoryDataset
from ..data.io import makedirs, download_file
from ..data.preprocess import process_planetoid_datasets
from ..data.graph import Graph

_DATASETS = gf.BunchDict({"citeseer": "citeseer citation dataset",
                          "cora": "cora citation dataset",
                          "pubmed": "pubmed citation dataset",
                          "nell.0.1": "NELL dataset",
                          "nell.0.01": "NELL dataset",
                          "nell.0.001": "NELL dataset"})

_DATASET_URL = "https://github.com/EdisonLeeeee/" + \
    "GraphData/raw/master/datasets/planetoid"


class Planetoid(InMemoryDataset):
    r"""The citation network datasets "Cora", "CiteSeer" and "PubMed" from the
    `"Revisiting Semi-Supervised Learning with Graph Embeddings"
    <https://arxiv.org/abs/1603.08861>`_ paper.
    Nodes represent documents and edges represent citation links.
    Training, validation and test splits are given by binary masks.

    The original URL is: <https://github.com/kimiyoung/planetoid/raw/master/data>
    """
Example #6
def flips(self):
    # TODO
    return gf.BunchDict(edge_flips=self.edge_flips, nx_flips=self.nx_flips)
Example #7
import os.path as osp
import pickle as pkl
from graphgallery import functional as gf

from ..data import Reader
from ..data.graph import Graph
from ..data.multi_graph import MultiGraph
from .in_memory_dataset import InMemoryDataset

_DATASET = gf.BunchDict(
    deezer="deezer dataset (node-level)",
    facebook="facebook dataset (node-level)",
    github="github dataset (node-level)",
    lastfm="lastfm dataset (node-level)",
    twitch="twitch dataset (node-level)",
    wikipedia="wikipedia dataset (node-level)",
    reddit10k="reddit10k dataset (graph-level)",
)


class KarateClub(InMemoryDataset):
    """Datasets from `Karate Club: An API Oriented Open-source Python Framework for Unsupervised Learning on Graphs`, CIKM 2020
    <https://github.com/benedekrozemberczki/karateclub>
    """

    __node_level_url__ = "https://github.com/EdisonLeeeee/GraphData/raw/master/datasets/karateclub/node_level"
    __graph_level_url__ = "https://github.com/EdisonLeeeee/GraphData/raw/master/datasets/karateclub/graph_level"

    def __init__(self,
                 name,
                 root=None,
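The `_DATASET` mapping near the top of this example doubles as the registry of valid `name` arguments; since `gf.BunchDict` subclasses `dict` (an assumption consistent with the other examples), the names can be listed directly:

# Presumably the node-level names resolve against __node_level_url__
# and "reddit10k" against __graph_level_url__.
for name, description in _DATASET.items():
    print(f"{name}: {description}")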
Example #8
def available_datasets():
    return gf.BunchDict(ppi="ppi dataset")