コード例 #1
0
    def process(self, **kwargs):
        """Process the inputs and transform them into tensors.

        Only keyword arguments are accepted; they are forwarded to the
        user-defined method ``process_step``.

        Commonly used keyword arguments:
        --------------------------------
        adj_transform: string, Callable function, or a tuple with function and dict arguments.
            Transform for the adjacency matrix.
        attr_transform: string, Callable function, or a tuple with function and dict arguments.
            Transform for the attribute matrix.
        graph_transform: string, Callable function, or a tuple with function and dict arguments.
            Transform for the entire graph; it is used before 'adj_transform' and 'attr_transform'.
        Any other arguments are passed into your method ``process_step``.

        Returns
        -------
        self, so calls can be chained.
        """
        process_cfg = self.cfg.process
        _, step_params = gf.wrapper(self.process_step)(**kwargs)
        process_cfg.merge_from_dict(step_params)

        # Every merged config entry whose name ends with "transform" is
        # resolved through gf.get and stored on self.transform under that name.
        transform_entries = (
            (name, spec)
            for name, spec in process_cfg.items()
            if name.endswith("transform")
        )
        for name, spec in transform_entries:
            setattr(self.transform, name, gf.get(spec))

        self.is_processed = True
        return self
コード例 #2
0
ファイル: trainer.py プロジェクト: EdisonLeeeee/GraphGallery
    def make_data(self, graph, graph_transform=None, device=None, **kwargs):
        """Prepare the input graph and transform the data into tensors.

        Only keyword arguments are accepted besides ``graph``; the extra ones
        are forwarded to the user-defined method ``data_step``.

        Commonly used keyword arguments:
        --------------------------------
        graph: graphgallery graph classes.
        graph_transform: string, Callable function,
            or a tuple with function and dict arguments.
            Transform for the entire graph; it is applied first.
        device: device for preparing data; if None, it defaults to `self.device`.
        adj_transform: string, Callable function,
            or a tuple with function and dict arguments.
            Transform for the adjacency matrix.
        attr_transform: string, Callable function,
            or a tuple with function and dict arguments.
            Transform for the attribute matrix.
        Any other arguments are passed into the method ``data_step``.

        Returns
        -------
        self, so calls can be chained.
        """
        self.graph = gf.get(graph_transform)(graph)
        data_cfg = self.cfg.data
        self.data_device = (
            gf.device(device, self.backend) if device is not None else self.device
        )
        # NOTE: the raw `device` argument (possibly None) is what gets recorded
        # in the config, not the resolved `self.data_device`.
        data_cfg.device = device

        _, step_params = gf.wrapper(self.data_step)(**kwargs)
        step_params['graph_transform'] = graph_transform
        data_cfg.merge_from_dict(step_params)

        # Resolve every "*transform" parameter and attach it to self.transform.
        for name, spec in step_params.items():
            if not name.endswith("transform"):
                continue
            setattr(self.transform, name, gf.get(spec))
        return self
コード例 #3
0
ファイル: trainer.py プロジェクト: EdisonLeeeee/GraphGallery
    def build(self, **kwargs):
        """Build the model from keyword arguments of the user-defined ``model_step``.

        Note:
        -----
        This method should be called after `make_data`; otherwise a
        RuntimeError is raised.

        Commonly used keyword arguments:
        --------------------------------
        hids: int or a list of them,
            hidden units for each hidden layer.
        acts: string or a list of them,
            activation functions for each layer.
        dropout: float scalar,
            dropout used in the model.
        lr: float scalar,
            learning rate used for the model.
        weight_decay: float scalar,
            weight decay used for the model weights.
        bias: bool,
            whether to use bias in each layer.
        use_tfn: bool,
            only used for the TensorFlow backend. If `True` (the default here
            when the argument is not given), `self.model.use_tfn()` is called,
            which decorates model training and testing with `tf.function`
            (see `graphgallery.nn.modes.TFKeras`). This can accelerate training
            and inference, but it may cause several errors.
            For other backends the argument is discarded.
        Any other arguments are passed into your method ``model_step``.

        Returns
        -------
        self, so calls can be chained.
        """
        if self._graph is None:
            raise RuntimeError("Please call 'trainer.make_data(graph)' first.")

        wrapped_step = gf.wrapper(self.model_step)
        enable_tfn = kwargs.get("use_tfn", True)

        if self.backend == "tensorflow":
            # Build inside the requested TF device scope.
            with tf.device(self.device):
                self.model, kwargs = wrapped_step(**kwargs)
                if enable_tfn:
                    self.model.use_tfn()
        else:
            # Non-TensorFlow backends do not understand `use_tfn`.
            kwargs.pop("use_tfn", None)
            built_model, kwargs = wrapped_step(**kwargs)
            self.model = built_model.to(self.device)

        self.cfg.model.merge_from_dict(kwargs)
        return self
コード例 #4
0
    def setup_graph(self, graph, graph_transform=None, device=None, **kwargs):
        """Prepare the input graph and transform the data into tensors.

        Besides ``graph``, only keyword arguments are accepted; the extra ones
        are forwarded to the user-defined method ``data_step``.

        Commonly used keyword arguments:
        --------------------------------
        graph: graphgallery graph instance.
            the input graph
        graph_transform: string, Callable function,
            or a tuple with function and dict arguments.
            Transform for the entire graph; it is applied first.
        device: device for preparing data; if None, it defaults to `self.device`.
        adj_transform: string, Callable function,
            or a tuple with function and dict arguments.
            Transform for the adjacency matrix.
        feat_transform: string, Callable function,
            or a tuple with function and dict arguments.
            Transform for the attribute (feature) matrix.
        attr_transform: deprecated alias of 'feat_transform'.
        Any other arguments are passed into the method ``data_step``.

        Returns
        -------
        self, so calls can be chained.

        Raises
        ------
        TypeError
            If both the deprecated 'attr_transform' and 'feat_transform' are
            given. Previously the 'feat_transform' value was silently dropped.
        """
        attr_transform = kwargs.pop("attr_transform", None)

        # Handle the deprecated alias before mutating any state, so a
        # conflicting call fails without clearing the cache.
        if attr_transform:
            if kwargs.get("feat_transform") is not None:
                raise TypeError("Got both 'attr_transform' and 'feat_transform'; "
                                "please only pass 'feat_transform'.")
            # stacklevel=2 attributes the warning to the caller's line rather
            # than to this method.
            warnings.warn("Argument 'attr_transform' is deprecated and will removed in future version, "
                          "please use 'feat_transform' instead.", stacklevel=2)
            kwargs['feat_transform'] = attr_transform

        self.cache_clear()

        self.graph = gf.get(graph_transform)(graph)
        if device is not None:
            self.data_device = torch.device(device)
        else:
            self.data_device = self.device
        _, kwargs = gf.wrapper(self.data_step)(**kwargs)
        kwargs['graph_transform'] = graph_transform

        # Resolve every "*transform" parameter and attach it to self.transform.
        for k, v in kwargs.items():
            if k.endswith("transform"):
                setattr(self.transform, k, gf.get(v))

        return self
コード例 #5
0
    def build(self, **kwargs):
        """Build the model, then configure optimizer, scheduler, loss and metrics.

        Only keyword arguments are accepted; they are forwarded to the
        user-defined method ``model_step``.

        Note:
        -----
        This method should be called after `setup_graph`; otherwise a
        RuntimeError is raised.

        Commonly used keyword arguments:
        --------------------------------
        hids: int or a list of them,
            hidden units for each hidden layer.
        acts: string or a list of them,
            activation functions for each layer.
        dropout: float scalar,
            dropout used in the model.
        lr: float scalar,
            learning rate used for the model.
        weight_decay: float scalar,
            weight decay used for the model weights.
        bias: bool,
            whether to use bias in each layer.
        Any other arguments are passed into your method ``model_step``.

        Returns
        -------
        self, so calls can be chained.
        """
        if self._graph is None:
            raise RuntimeError("Please call 'trainer.setup_graph(graph)' first.")

        # The parameter dict returned by the wrapper is not used here.
        model, _ = gf.wrapper(self.model_step)(**kwargs)
        self._model = model.to(self.device)

        self.optimizer = self.config_optimizer()
        self.scheduler = self.config_scheduler(self.optimizer)
        self.loss = self.config_loss()

        # Normalise whatever config_metrics returns into a list of metrics:
        # None -> no metrics, tuple -> list of its items, single metric -> [metric].
        # (Previously a tuple became [tuple] and None became [None].)
        metrics = self.config_metrics()
        if metrics is None:
            metrics = []
        elif isinstance(metrics, tuple):
            metrics = list(metrics)
        elif not isinstance(metrics, list):
            metrics = [metrics]

        self.metrics = metrics
        return self
コード例 #6
0
 def build(self, **kwargs):
     """Build the model via ``model_builder`` and then the classifier.

     Keyword arguments are forwarded through ``gf.wrapper`` to
     ``self.model_builder``; the parameters it returns are merged into
     ``self.cfg.model``.

     Returns
     -------
     self, so calls can be chained.
     """
     wrapped_builder = gf.wrapper(self.model_builder)
     self.model, model_params = wrapped_builder(**kwargs)
     self.cfg.model.merge_from_dict(model_params)
     self.classifier = self.classifier_builder()
     return self