예제 #1
0
    def test(self, data, **kwargs):
        """Run the built model over `data` and return its metrics as scalars.

        Parameters:
        ----------
        data: a `Sequence`, or any input accepted by `self.test_sequence`
            Inputs that are not already a `Sequence` are wrapped with
            `self.test_sequence` first. The resulting sequence is stored
            in `self.cache.test_data`.
        kwargs: overrides merged into `self.cfg.test` before testing.

        Return:
        ----------
        logs: `gf.BunchDict`
            The entries produced by `self.test_step`, each converted with
            `.numpy().item()`.

        Raise:
        ----------
        RuntimeError: if `self.model` is falsy (the model has not been built).
        """
        if not self.model:
            raise RuntimeError(
                'You must compile your model before training/testing/predicting. Use `trainer.build()`.'
            )

        cache = self.cache
        test_cfg = self.cfg.test
        test_cfg.merge_from_dict(kwargs)

        # An existing sequence is used untouched; anything else gets wrapped.
        sequence = data if isinstance(data, Sequence) else self.test_sequence(data)
        cache.test_data = sequence

        if test_cfg.verbose:
            print("Testing...")

        bar = Progbar(target=len(sequence),
                      width=test_cfg.Progbar.width,
                      verbose=test_cfg.verbose)

        logs = gf.BunchDict(**self.test_step(sequence))
        # Build all scalar values first, then write them back in one update.
        scalars = {name: value.numpy().item() for name, value in logs.items()}
        logs.update(scalars)

        bar.update(len(sequence), logs.items())
        return logs
예제 #2
0
    def test(self, data, verbose=1):
        """Evaluate the built model on `data` and return the collected metrics.

        Note:
        ----------
        You must compile your model before training/testing/predicting.
        Use `model.build()`.

        Parameters:
        ----------
        data: Numpy array-like, `list` or `graphgallery.Sequence`
            The index of objects (or sequence) that will be tested.
            Inputs that are not a `Sequence` are wrapped with
            `self.test_sequence`. The sequence used is kept in
            `self.test_data`.
        verbose: int
            A falsy value suppresses the "Testing..." banner; the value is
            also forwarded to the progress bar. (default :obj: `1`)

        Return:
        ----------
        logs: `BunchDict`
            The metrics reported by `self.test_step` (presumably the loss
            and accuracy of the forward pass -- confirm against `test_step`).

        Raise:
        ----------
        RuntimeError: if `self.model` is falsy (the model has not been built).
        """

        if not self.model:
            raise RuntimeError(
                'You must compile your model before training/testing/predicting. Use `model.build()`.'
            )

        if isinstance(data, Sequence):
            test_data = data
        else:
            test_data = self.test_sequence(data)

        self.test_data = test_data

        if verbose:
            print("Testing...")

        # The metric names come from the keys returned by `test_step`, so the
        # model's `metrics_names` attribute is not consulted here.
        progbar = Progbar(target=len(test_data),
                          width=20,
                          verbose=verbose)
        logs = BunchDict(**self.test_step(test_data))
        progbar.update(len(test_data), logs.items())
        return logs
예제 #3
0
    def evaluate(self, test_data, verbose=1):
        """Evaluate the built model on `test_data` and return its metrics.

        Parameters:
        ----------
        test_data: a `DataLoader`, a `Dataset`, or any input accepted by
            `self.config_test_data`
            Inputs of any other type are converted with
            `self.config_test_data` first.
        verbose: int
            A falsy value suppresses the "Testing..." banner; the value is
            also forwarded to the progress bar. (default :obj: `1`)

        Return:
        ----------
        logs: `gf.BunchDict`
            The entries produced by `self.test_step`, each passed through
            `self.to_item`.

        Raise:
        ----------
        RuntimeError: if `self.model` is falsy (the model has not been built).
        """
        if not self.model:
            raise RuntimeError(
                'You must compile your model before training/testing/predicting. Use `trainer.build()`.'
            )

        ready = isinstance(test_data, (DataLoader, Dataset))
        if not ready:
            test_data = self.config_test_data(test_data)

        if verbose:
            print("Testing...")

        bar = Progbar(target=len(test_data), verbose=verbose)

        logs = gf.BunchDict(**self.test_step(test_data))
        # Convert every value first, then write them back in one update.
        converted = {name: self.to_item(value) for name, value in logs.items()}
        logs.update(converted)

        bar.update(len(test_data), logs)
        return logs
예제 #4
0
    def train(self, train_data, val_data=None, **kwargs):
        """Train the built model, optionally validating after every epoch.

        All options are read from `self.cfg.train` after `kwargs` has been
        merged into it: `epochs`, `verbose`, and the sub-sections
        `ModelCheckpoint`, `EarlyStopping` and `Progbar`.

        Parameters:
        ----------
        train_data: a `Sequence`, or any input accepted by `self.train_sequence`
            Stored in `self.cache.train_data` after wrapping.
        val_data: a `Sequence`, any input accepted by `self.test_sequence`,
            or None
            When given, validation runs after each epoch and its metrics are
            logged with a `val_` prefix. (default :obj: `None`)
        kwargs: overrides merged into `self.cfg.train`.

        Verbosity (`cfg.verbose`):
        ----------
        0: silent; 1-2: one progress bar over all epochs; >2: a fresh
        progress bar per epoch, created with `verbose - 2`.

        Return:
        ----------
        The `History` callback that recorded the per-epoch logs.

        Raise:
        ----------
        RuntimeError: if `self.model` is None (the model has not been built).
        """
        cache = self.cache
        cfg = self.cfg.train
        cfg.merge_from_dict(kwargs)
        ckpt_cfg = cfg.ModelCheckpoint
        es_cfg = cfg.EarlyStopping
        pb_cfg = cfg.Progbar

        model = self.model
        if model is None:
            raise RuntimeError(
                'You must compile your model before training/testing/predicting. Use `trainer.build()`.'
            )

        if not isinstance(train_data, Sequence):
            train_data = self.train_sequence(train_data)

        cache.train_data = train_data

        validation = val_data is not None

        if validation:
            if not isinstance(val_data, Sequence):
                val_data = self.test_sequence(val_data)
            cache.val_data = val_data
        elif ckpt_cfg.enabled and ckpt_cfg.monitor.startswith("val_"):
            # Without validation data no `val_*` metric is ever logged, so
            # monitor the training counterpart instead.
            requested = ckpt_cfg.monitor
            ckpt_cfg.monitor = requested[4:]
            warnings.warn(
                f"The metric '{requested}' is invalid without validation "
                f"and has been automatically replaced with '{ckpt_cfg.monitor}'.",
                UserWarning)

        callbacks = callbacks_module.CallbackList()

        history = History()
        callbacks.append(history)

        if es_cfg.enabled:
            assert es_cfg.monitor.startswith("val")
            # `patience` must be the configured patience, not the name of
            # the monitored metric.
            es_callback = EarlyStopping(monitor=es_cfg.monitor,
                                        patience=es_cfg.patience,
                                        mode=es_cfg.mode,
                                        verbose=es_cfg.verbose)
            callbacks.append(es_callback)

        if ckpt_cfg.enabled:
            if not ckpt_cfg.path.endswith(gg.file_ext()):
                ckpt_cfg.path += gg.file_ext()
            makedirs_from_filepath(ckpt_cfg.path)

            mc_callback = ModelCheckpoint(
                ckpt_cfg.path,
                monitor=ckpt_cfg.monitor,
                save_best_only=ckpt_cfg.save_best_only,
                save_weights_only=ckpt_cfg.save_weights_only,
                verbose=ckpt_cfg.verbose)
            callbacks.append(mc_callback)

        callbacks.set_model(model)
        model.stop_training = False

        epochs = cfg.epochs
        verbose = cfg.verbose
        if verbose:
            if verbose <= 2:
                # One bar for the whole run, advanced once per epoch.
                progbar = Progbar(target=epochs,
                                  width=pb_cfg.width,
                                  verbose=verbose)
            print("Training...")

        logs = gf.BunchDict()
        callbacks.on_train_begin()
        try:
            for epoch in range(epochs):
                if verbose > 2:
                    # One bar per epoch, sized by the number of batches.
                    progbar = Progbar(target=len(train_data),
                                      width=pb_cfg.width,
                                      verbose=verbose - 2)

                callbacks.on_epoch_begin(epoch)
                callbacks.on_train_batch_begin(0)
                train_logs = self.train_step(train_data)
                train_data.on_epoch_end()
                logs.update(train_logs)

                if validation:
                    valid_logs = self.test_step(val_data)
                    logs.update({("val_" + k): v
                                 for k, v in valid_logs.items()})
                    val_data.on_epoch_end()

                callbacks.on_train_batch_end(len(train_data), logs)
                callbacks.on_epoch_end(epoch, logs)

                if verbose > 2:
                    print(f"Epoch {epoch+1}/{epochs}")
                    progbar.update(len(train_data), logs.items())
                elif verbose:
                    progbar.update(epoch + 1, logs.items())

                if model.stop_training:
                    print(f"Early Stopping at Epoch {epoch}", file=sys.stderr)
                    break

            callbacks.on_train_end()
            if ckpt_cfg.enabled:
                # Restore what the checkpoint callback saved.
                if ckpt_cfg.save_weights_only:
                    model.load_weights(ckpt_cfg.path)
                else:
                    self.model = model.load(ckpt_cfg.path)

        finally:
            # to avoid unexpected termination of the model
            if ckpt_cfg.enabled and ckpt_cfg.remove_weights:
                self.remove_weights()

        return history
예제 #5
0
    def train(self,
              train_data,
              val_data=None,
              epochs=200,
              early_stopping=None,
              verbose=1,
              save_best=True,
              ckpt_path=None,
              as_model=False,
              monitor='val_accuracy',
              early_stop_metric='val_loss',
              callbacks=None,
              **kwargs):
        """Train the model for the input `train_data` of nodes or `sequence`.

        Note:
        ----------
        You must compile your model before training/testing/predicting. Use `model.build()`.

        Parameters:
        ----------
        train_data: Numpy array-like, `list`, Integer scalar or `graphgallery.Sequence`
            The index of objects (or sequence) that will be used during training.
        val_data: Numpy array-like, `list`, Integer scalar or
            `graphgallery.Sequence`, optional
            The index of objects (or sequence) that will be used for validation.
            (default :obj: `None`, i.e., do not use validation during training)
        epochs: Positive integer
            The number of epochs of training. (default :obj: `200`)
        early_stopping: Positive integer or None
            The early stopping patience during training.
            (default :obj: `None`, i.e., do not use early stopping during training)
        verbose: int in {0, 1, 2, 3, 4}
            'verbose=0': not verbose;
            'verbose=1': Progbar (one line, detailed);
            'verbose=2': Progbar (one line, omitted);
            'verbose=3': Progbar (multi line, detailed);
            'verbose=4': Progbar (multi line, omitted);
            (default :obj: 1)
        save_best: bool
            Whether to checkpoint the best weights/model according to
            `monitor` and reload that checkpoint once training finishes.
            (default :bool: `True`)
        ckpt_path: String or None
            The path of the saved weights/model. When falsy, `self.ckpt_path`
            is used; otherwise `self.ckpt_path` is updated to this value.
        as_model: bool
            Whether to save the whole model or weights only, if `True`, the
            `self.custom_objects` must be specified if you are using custom
            `layer` or `loss` and so on.
        monitor: String
            One of the evaluation metrics, e.g., val_loss, val_accuracy, loss,
            accuracy; it determines which metric is used for `save_best`.
            If it is not among the available metric names, the last available
            name is used instead and a `UserWarning` is issued.
            (default :obj: `val_accuracy`)
        early_stop_metric: String
            One of the evaluation metrics, e.g., val_loss, val_accuracy, loss,
            accuracy; it determines which metric is used for early stopping.
            (default :obj: `val_loss`)
        callbacks: a list of callbacks or a `CallbackList`. (default :obj: `None`)
        kwargs: other keyword parameters. Only `es_verbose` (verbosity of the
            early stopping callback, default `1`) is consumed here; anything
            else is handed to `raise_if_kwargs`.

        Return:
        ----------
        A `History` object. Its `History.history` attribute is
            a record of training loss values and metrics values
            at successive epochs, as well as validation loss values
            and validation metrics values (if applicable).

        Raise:
        ----------
        ValueError: if `verbose` is not an integer in [0, 4].
        RuntimeError: if the model has not been built, or if it exposes no
            `metrics_names`.
        """
        # Take the one supported extra option out first: `raise_if_kwargs`
        # presumably rejects any keyword that is still left (confirm against
        # its definition), which would make `es_verbose` unusable otherwise.
        es_verbose = kwargs.pop('es_verbose', 1)
        raise_if_kwargs(kwargs)
        if not (isinstance(verbose, int) and 0 <= verbose <= 4):
            raise ValueError("'verbose=0': not verbose, "
                             "'verbose=1': Progbar(one line, detailed), "
                             "'verbose=2': Progbar(one line, omitted), "
                             "'verbose=3': Progbar(multi line, detailed), "
                             "'verbose=4': Progbar(multi line, omitted), "
                             f"but got {verbose}")
        model = self.model
        # Check if model has been built
        if model is None:
            raise RuntimeError(
                'You must compile your model before training/testing/predicting. Use `model.build()`.'
            )

        metrics_names = getattr(model, "metrics_names", None)
        # FIXME: This would return '[]' for tensorflow>=2.2.0
        # See <https://github.com/tensorflow/tensorflow/issues/37990>
        # metrics_names = ['loss', 'accuracy']
        if not metrics_names:
            raise RuntimeError("Please specify the attribute 'metrics_names' for the model.")
        if not isinstance(train_data, Sequence):
            train_data = self.train_sequence(train_data)

        self.train_data = train_data

        validation = val_data is not None

        if validation:
            if not isinstance(val_data, Sequence):
                val_data = self.test_sequence(val_data)
            self.val_data = val_data
            metrics_names = metrics_names + ["val_" + metric for metric in metrics_names]

        if not isinstance(callbacks, callbacks_module.CallbackList):
            callbacks = callbacks_module.CallbackList(callbacks)

        history = History()
        callbacks.append(history)

        if early_stopping:
            es_callback = EarlyStopping(monitor=early_stop_metric,
                                        patience=early_stopping,
                                        mode='auto',
                                        verbose=es_verbose)
            callbacks.append(es_callback)

        if save_best:
            if not ckpt_path:
                ckpt_path = self.ckpt_path
            else:
                self.ckpt_path = ckpt_path

            makedirs_from_filepath(ckpt_path)

            if not ckpt_path.endswith(gg.file_ext()):
                ckpt_path = ckpt_path + gg.file_ext()

            if monitor not in metrics_names:
                # Format the warning before rebinding `monitor`, so that it
                # names both the rejected metric and its replacement.
                fallback = metrics_names[-1]
                warnings.warn(f"'{monitor}' are not included in the metrics names. default to '{fallback}'.",
                              UserWarning)
                monitor = fallback

            mc_callback = ModelCheckpoint(ckpt_path,
                                          monitor=monitor,
                                          save_best_only=True,
                                          save_weights_only=not as_model,
                                          verbose=0)
            callbacks.append(mc_callback)

        callbacks.set_model(model)
        model.stop_training = False

        if verbose:
            if verbose <= 2:
                # One bar for the whole run, advanced once per epoch.
                progbar = Progbar(target=epochs,
                                  width=20,
                                  verbose=verbose)
            print("Training...")

        logs = BunchDict()
        callbacks.on_train_begin()
        try:
            for epoch in range(epochs):
                if verbose > 2:
                    # One bar per epoch, sized by the number of batches.
                    progbar = Progbar(target=len(train_data),
                                      width=20,
                                      verbose=verbose - 2)

                callbacks.on_epoch_begin(epoch)
                callbacks.on_train_batch_begin(0)
                train_logs = self.train_step(train_data)
                train_data.on_epoch_end()

                logs.update(train_logs)

                if validation:
                    valid_logs = self.test_step(val_data)
                    logs.update({("val_" + k): v for k, v in valid_logs.items()})
                    val_data.on_epoch_end()

                callbacks.on_train_batch_end(len(train_data), logs)
                callbacks.on_epoch_end(epoch, logs)

                if verbose > 2:
                    print(f"Epoch {epoch+1}/{epochs}")
                    progbar.update(len(train_data), logs.items())
                elif verbose:
                    progbar.update(epoch + 1, logs.items())

                if model.stop_training:
                    print(f"Early Stopping at Epoch {epoch}", file=sys.stderr)
                    break

            callbacks.on_train_end()
            if save_best:
                # Only a run that checkpointed has something to restore.
                self.load(ckpt_path, as_model=as_model)
        finally:
            # to avoid unexpected termination of the model
            self.remove_weights()

        return history
예제 #6
0
class ProgbarLogger(Callback):
    """Callback that reports metrics on stdout through a `Progbar`.

    With ``0 < verbose <= 2`` a single bar spans all epochs and advances once
    per epoch. With ``verbose > 2`` the bar is driven by batch updates and is
    finalized (then discarded) at the end of every epoch.

    TODO: on_[test/predict]_[begin/end] haven't been tested.
    """

    def __init__(self):
        super().__init__()
        self.verbose = 1
        self.epochs = 1
        self.target = None
        self.progbar = None
        self.seen = 0

    def set_params(self, params):
        """Read `verbose` and `epochs` from `params` and pick the bar target."""
        self.verbose = params['verbose']
        self.epochs = params['epochs']
        # An epoch-level bar knows its target up front; otherwise the target
        # is inferred when the bar is first finalized.
        epoch_level = 0 < self.verbose <= 2
        self.target = params['epochs'] if epoch_level else None

    def on_train_begin(self, logs=None):
        self._reset_progbar()

    def on_test_begin(self, logs=None):
        self._reset_progbar()
        self._maybe_init_progbar()

    def on_predict_begin(self, logs=None):
        self._reset_progbar()
        self._maybe_init_progbar()

    def on_epoch_begin(self, epoch, logs=None):
        self._maybe_init_progbar()
        show_header = self.verbose > 2 and self.epochs > 1
        if show_header:
            print('Epoch %d/%d' % (epoch + 1, self.epochs))

    def on_train_batch_end(self, batch, logs=None):
        self._batch_update_progbar(batch, logs)

    def on_test_batch_end(self, batch, logs=None):
        self._batch_update_progbar(batch, logs)

    def on_predict_batch_end(self, batch, logs=None):
        # Prediction results are deliberately not forwarded to the bar.
        self._batch_update_progbar(batch, None)

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        if self.verbose > 2:
            self._finalize_progbar(logs)
            return
        if self.verbose > 0:
            self.progbar.update(epoch + 1, logs)

    def on_test_end(self, logs=None):
        self._finalize_progbar(logs)

    def on_predict_end(self, logs=None):
        self._finalize_progbar(logs)

    def _reset_progbar(self):
        """Drop the current bar; a no-op unless `verbose > 2`."""
        if self.verbose > 2:
            self.seen, self.progbar = 0, None

    def _maybe_init_progbar(self):
        """Create the bar if there is none yet."""
        if self.progbar is not None:
            return
        # Levels above 2 map onto the bar's own verbosity by subtracting 2.
        if self.verbose > 2:
            bar_verbose = self.verbose - 2
        else:
            bar_verbose = self.verbose
        self.progbar = Progbar(target=self.target, verbose=bar_verbose)

    def _batch_update_progbar(self, batch, logs=None):
        """Updates the progbar."""
        logs = logs or {}
        self._maybe_init_progbar()
        self.seen = batch

        # Only the batch-level bar (verbose > 2) is redrawn per batch.
        if self.verbose > 2:
            self.progbar.update(self.seen, logs, finalize=False)

    def _finalize_progbar(self, logs):
        """Push the bar to its target, then reset it."""
        logs = logs or {}
        if self.target is None:
            # Target unknown so far: adopt the last batch index seen.
            seen = self.seen
            self.progbar.target = seen
            self.target = seen
        self.progbar.update(self.target, logs, finalize=True)
        self._reset_progbar()

    def __str__(self) -> str:
        name = self.__class__.__name__
        return f"{name}(epochs={self.epochs}, verbose={self.verbose})"

    __repr__ = __str__
예제 #7
0
    def train(self, train_data, val_data=None, **kwargs):
        """Train the built model, optionally validating after every epoch.

        All options are read from `self.cfg.train` after `kwargs` has been
        merged into it. Callbacks other than `History` are configured by
        `setup_callbacks`.

        Parameters:
        ----------
        train_data: a `Sequence`, or any input accepted by `self.train_sequence`
            Kept in `self.cache.train_data` when `cfg.cache_train_data` is set.
        val_data: a `Sequence`, any input accepted by `self.test_sequence`,
            or None
            When given, validation runs after each epoch and its metrics are
            logged with a `val_` prefix. Kept in `self.cache.val_data` when
            `cfg.cache_val_data` is set. (default :obj: `None`)
        kwargs: overrides merged into `self.cfg.train`.

        Verbosity (`cfg.verbose`):
        ----------
        0: silent; 1-2: one progress bar over all epochs; >2: a fresh
        progress bar per epoch, created with `verbose - 2`.

        Return:
        ----------
        The `History` callback that recorded the per-epoch logs.

        Raise:
        ----------
        RuntimeError: if `self.model` is None (the model has not been built).
        """
        cache = self.cache
        cfg = self.cfg.train
        cfg.merge_from_dict(kwargs)
        ckpt_cfg = cfg.ModelCheckpoint
        pb_cfg = cfg.Progbar

        model = self.model
        if model is None:
            raise RuntimeError(
                'You must compile your model before training/testing/predicting. Use `trainer.build()`.'
            )

        if not isinstance(train_data, Sequence):
            train_data = self.train_sequence(train_data)

        if cfg.cache_train_data:
            cache.train_data = train_data

        validation = val_data is not None
        if validation:
            if not isinstance(val_data, Sequence):
                val_data = self.test_sequence(val_data)
            if cfg.cache_val_data:
                cache.val_data = val_data

        # Setup callbacks
        callbacks = callbacks_module.CallbackList()
        history = History()
        callbacks.append(history)
        cfg, callbacks = setup_callbacks(cfg, callbacks, validation)
        callbacks.set_model(model)
        model.stop_training = False

        verbose = cfg.verbose
        if verbose:
            if verbose <= 2:
                # One bar for the whole run, advanced once per epoch.
                progbar = Progbar(target=cfg.epochs,
                                  width=pb_cfg.width,
                                  verbose=verbose)
            print("Training...")

        logs = gf.BunchDict()
        callbacks.on_train_begin()
        try:
            for epoch in range(cfg.epochs):
                if verbose > 2:
                    # One bar per epoch, sized by the number of batches.
                    progbar = Progbar(target=len(train_data),
                                      width=pb_cfg.width,
                                      verbose=verbose - 2)

                callbacks.on_epoch_begin(epoch)
                callbacks.on_train_batch_begin(0)
                train_logs = self.train_step(train_data)
                train_data.on_epoch_end()
                logs.update(train_logs)

                if validation:
                    valid_logs = self.test_step(val_data)
                    logs.update({("val_" + k): v for k, v in valid_logs.items()})
                    val_data.on_epoch_end()

                callbacks.on_train_batch_end(len(train_data), logs)
                callbacks.on_epoch_end(epoch, logs)

                if verbose > 2:
                    # The epoch count lives on the config; there is no local
                    # `epochs` variable in this method.
                    print(f"Epoch {epoch+1}/{cfg.epochs}")
                    progbar.update(len(train_data), logs.items())
                elif verbose:
                    progbar.update(epoch + 1, logs.items())

                if model.stop_training:
                    print(f"Early Stopping at Epoch {epoch}", file=sys.stderr)
                    break

            callbacks.on_train_end()
            if ckpt_cfg.enabled:
                # Restore what the checkpoint callback saved.
                if ckpt_cfg.save_weights_only:
                    model.load_weights(ckpt_cfg.path)
                else:
                    self.model = model.load(ckpt_cfg.path)

        finally:
            # to avoid unexpected termination of the model
            if ckpt_cfg.enabled and ckpt_cfg.remove_weights:
                self.remove_weights()

        return history
예제 #8
0
    def fit(self, train_data, val_data=None, **kwargs):
        """Train the built model, optionally validating after every epoch.

        All options are read from `self.cfg.fit` after `kwargs` has been
        merged into it. Progress is reported either through a `Progbar`
        (`cfg.verbose`) or through a logger (`cfg.Logger.enabled`); the two
        cannot be combined.

        Parameters:
        ----------
        train_data: a `Sequence`, `DataLoader`, `Dataset`, or any input
            accepted by `self.train_loader`
            Kept in `self.cache.train_data` when `cfg.cache_train_data` is set.
        val_data: same kinds as `train_data` (wrapped with `self.test_loader`
            when needed), or None
            When given, validation runs after each epoch and its metrics are
            logged with a `val_` prefix. (default :obj: `None`)
        kwargs: overrides merged into `self.cfg.fit`.

        Return:
        ----------
        The `History` callback that recorded the per-epoch logs.

        Raise:
        ----------
        RuntimeError: if `self.model` is None (the model has not been built).
        AssertionError: if both `cfg.verbose` and `cfg.Logger.enabled` are set.
        """

        cache = self.cache
        cfg = self.cfg.fit
        cfg.merge_from_dict(kwargs)
        ckpt_cfg = cfg.ModelCheckpoint
        # NOTE(review): `es_cfg` is not referenced again in this method.
        es_cfg = cfg.EarlyStopping
        pb_cfg = cfg.Progbar
        log_cfg = cfg.Logger

        # `logger` is only bound when logging is enabled; every later use is
        # guarded by the same `log_cfg.enabled` flag.
        if log_cfg.enabled:
            log_cfg.name = log_cfg.name or self.name
            logger = gg.utils.setup_logger(output=log_cfg.filepath,
                                           name=log_cfg.name)

        model = self.model
        if model is None:
            raise RuntimeError(
                'You must compile your model before training/testing/predicting. Use `trainer.build()`.'
            )

        # Raw inputs are wrapped; loaders/datasets/sequences are used as given.
        if not isinstance(train_data, (Sequence, DataLoader, Dataset)):
            train_data = self.train_loader(train_data)

        if cfg.cache_train_data:
            cache.train_data = train_data

        validation = val_data is not None
        if validation:
            if not isinstance(val_data, (Sequence, DataLoader, Dataset)):
                val_data = self.test_loader(val_data)
            if cfg.cache_val_data:
                cache.val_data = val_data

        # Setup callbacks
        callbacks = callbacks_module.CallbackList()
        history = History()
        callbacks.append(history)
        # NOTE(review): `cfg` is rebound here while `ckpt_cfg`/`pb_cfg`/`log_cfg`
        # were taken from the earlier object -- presumably the same config is
        # returned; confirm against `setup_callbacks`.
        cfg, callbacks = setup_callbacks(cfg, callbacks, validation)
        callbacks.set_model(model)
        model.stop_training = False

        verbose = cfg.verbose
        # NOTE(review): this check is an `assert`, so it is skipped when
        # Python runs with -O.
        assert not (
            verbose and log_cfg.enabled
        ), "Progbar and Logger cannot be used together! You must set `verbose=0` when Logger is enabled."

        if verbose:
            if verbose <= 2:
                # One bar for the whole run, advanced once per epoch.
                progbar = Progbar(target=cfg.epochs,
                                  width=pb_cfg.width,
                                  verbose=verbose)
            print("Training...")
        elif log_cfg.enabled:
            logger.info("Training...")

        logs = gf.BunchDict()
        callbacks.on_train_begin()
        try:
            for epoch in range(cfg.epochs):
                if verbose > 2:
                    # One bar per epoch, sized by the number of batches.
                    progbar = Progbar(target=len(train_data),
                                      width=pb_cfg.width,
                                      verbose=verbose - 2)

                callbacks.on_epoch_begin(epoch)
                callbacks.on_train_batch_begin(0)
                train_logs = self.train_step(train_data)
                # Not every accepted data type defines an epoch-end hook.
                if hasattr(train_data, 'on_epoch_end'):
                    train_data.on_epoch_end()
                logs.update(train_logs)

                if validation:
                    valid_logs = self.test_step(val_data)
                    # Validation metrics share `logs` under a `val_` prefix.
                    logs.update({("val_" + k): v
                                 for k, v in valid_logs.items()})
                    if hasattr(val_data, 'on_epoch_end'):
                        val_data.on_epoch_end()

                callbacks.on_train_batch_end(len(train_data), logs)
                callbacks.on_epoch_end(epoch, logs)

                # Report progress: per-epoch bar, run-level bar, or logger.
                if verbose > 2:
                    print(f"Epoch {epoch+1}/{cfg.epochs}")
                    progbar.update(len(train_data), logs.items())
                elif verbose:
                    progbar.update(epoch + 1, logs.items())
                elif log_cfg.enabled:
                    logger.info(
                        f"Epoch {epoch+1}/{cfg.epochs}\n{gg.utils.create_table(logs)}"
                    )

                # `stop_training` is reset to False before the loop; a
                # callback is expected to flip it to end training early.
                if model.stop_training:
                    if log_cfg.enabled:
                        logger.info(f"Early Stopping at Epoch {epoch}")
                    else:
                        print(f"Early Stopping at Epoch {epoch}",
                              file=sys.stderr)
                    break

            callbacks.on_train_end()
            if ckpt_cfg.enabled:
                # Restore what was checkpointed: weights only, or the whole model.
                if ckpt_cfg.save_weights_only:
                    model.load_weights(ckpt_cfg.path)
                else:
                    self.model = model.load(ckpt_cfg.path)
        finally:
            # to avoid unexpected termination of the model
            if ckpt_cfg.enabled and ckpt_cfg.remove_weights:
                self.remove_weights()

        return history