Example #1
 def train_sequence(self, index):
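     # build a sampled mini-batch sequence (GraphSAGE-style) over the requested nodes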
     index = T.asintarr(index)
     labels = self.graph.labels[index]
     sequence = SAGEMiniBatchSequence(
         [self.feature_inputs, self.structure_inputs, index], labels,
         n_samples=self.n_samples, device=self.device)
     return sequence
Example #2
 def train_sequence(self, index):
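     # full-batch training: features, structure inputs and node indices are packed together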
     index = T.asintarr(index)
     labels = self.graph.labels[index]
     sequence = FullBatchNodeSequence(
         [self.feature_inputs, *self.structure_inputs, index],
         labels,
         device=self.device)
     return sequence
Example #3
 def train_sequence(self, index):
     index = T.astensor(T.asintarr(index))
     labels = self.graph.labels[index]
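     # self.kind == "T" selects the TensorFlow backend (tf.gather); other
     # backends (e.g., PyTorch) use plain advanced indexing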
     
     if self.kind == "T":
         feature_inputs = tf.gather(self.feature_inputs, index)
     else:
         feature_inputs = self.feature_inputs[index]
     sequence = FullBatchNodeSequence(feature_inputs, labels, device=self.device)
     return sequence
Example #4
    def test_sequence(self, index):
        index = T.asintarr(index)
        labels = self.graph.labels[index]
        structure_inputs = self.structure_inputs[index]
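        # only the rows of the adjacency are sliced here; batch_size=None and
        # rank=None below make the sequence evaluate as a single full batch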

        sequence = FastGCNBatchSequence(
            [self.feature_inputs, structure_inputs],
            labels,
            batch_size=None,
            rank=None,
            device=self.device)  # use full batch
        return sequence
Example #5
    def train_sequence(self, index):
        index = T.asintarr(index)
        labels = self.graph.labels[index]
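        # SBVAT: besides the usual inputs, the sequence carries the sampled
        # `neighbors` used by the (sample-based) virtual adversarial objective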

        sequence = SBVATSampleSequence(
            [self.feature_inputs, self.structure_inputs, index],
            labels,
            neighbors=self.neighbors,
            n_samples=self.n_samples,
            device=self.device)

        return sequence
Example #6
    def train_sequence(self, index):
        index = T.asintarr(index)
        labels = self.graph.labels[index]
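        # slice out the subgraph induced by `index`, then re-normalize it with
        # the configured adjacency transform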
        adj_matrix = self.graph.adj_matrix[index][:, index]
        adj_matrix = self.adj_transform(adj_matrix)

        feature_inputs = tf.gather(self.feature_inputs, index)
        sequence = FastGCNBatchSequence([feature_inputs, adj_matrix],
                                        labels,
                                        batch_size=self.batch_size,
                                        rank=self.rank,
                                        device=self.device)
        return sequence
Example #7
    def test(self, index, verbose=1):
        """
            Test the output accuracy for the `index` of nodes or `sequence`.

        Note:
        ----------
        You must compile your model before training/testing/predicting.
        Use `model.build()`.

        Parameters:
        ----------
        index: Numpy array-like, `list`, Integer scalar or `graphgallery.Sequence`
            The index of nodes (or sequence) that will be tested.


        Return:
        ----------
        loss: Float scalar
            Output loss of forward propagation.
        accuracy: Float scalar
            Output accuracy of prediction.
        """

        if not self.model:
            raise RuntimeError(
                'You must compile your model before training/testing/predicting. Use `model.build()`.'
            )

        if isinstance(index, Sequence):
            test_data = index
        else:
            index = asintarr(index)
            test_data = self.test_sequence(index)
            self.idx_test = index

        if verbose:
            print("Testing...")

        stateful_metrics = {"test_acc", 'test_loss', 'time'}
        progbar = Progbar(target=len(test_data),
                          verbose=verbose,
                          stateful_metrics=stateful_metrics)
        begin_time = time.perf_counter()
        loss, accuracy = self.test_step(test_data)
        time_passed = time.perf_counter() - begin_time
        progbar.update(len(test_data), [('test_loss', loss),
                                        ('test_acc', accuracy),
                                        ('time', time_passed)])
        return loss, accuracy
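
A minimal call sketch for the method above; `model` and `idx_test` are hypothetical placeholders, not part of the snippet:

    loss, acc = model.test(idx_test, verbose=1)
    print(f"test loss {loss:.4f}, test acc {acc:.2%}")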
Example #8
    def train_sequence(self, index, batch_size=np.inf):
        index = T.asintarr(index)
        mask = T.indices2mask(index, self.graph.n_nodes)
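        # `get_indice_graph` augments the batch with graph neighbors; keep
        # expanding until the induced subgraph covers at least `self.k` nodes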
        index = get_indice_graph(self.structure_inputs, index, batch_size)
        while index.size < self.k:
            index = get_indice_graph(self.structure_inputs, index)

        structure_inputs = self.structure_inputs[index][:, index]
        feature_inputs = self.feature_inputs[index]
        mask = mask[index]
        labels = self.graph.labels[index[mask]]

        sequence = FullBatchNodeSequence(
            [feature_inputs, structure_inputs, mask],
            labels,
            device=self.device)
        return sequence
Example #9
    def predict(self, index=None, return_prob=True):
        """
        Predict the output probability for the input node index.


        Note:
        ----------
        You must compile your model before training/testing/predicting. Use `model.build()`.

        Parameters:
        ----------
        index: Numpy 1D array, optional.
            The indices of nodes to predict.
            If None, predict for all nodes.

        return_prob: bool.
            Whether to return the probability of prediction.

        Return:
        ----------
        The predicted probability of each class for each node,
            shape (n_nodes, n_classes).

        """

        if not self.model:
            raise RuntimeError(
                'You must compile your model before training/testing/predicting. Use `model.build()`.'
            )

        if index is None:
            index = np.arange(self.graph.n_nodes, dtype=intx())
        else:
            index = asintarr(index)
        sequence = self.predict_sequence(index)
        logit = self.predict_step(sequence)
        if return_prob:
            logit = softmax(logit)
        return logit
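
A short usage sketch, assuming a built model (`model` is a hypothetical placeholder); NumPy turns the returned probabilities into hard labels:

    import numpy as np

    probs = model.predict()            # all nodes, shape (n_nodes, n_classes)
    preds = np.argmax(probs, axis=1)   # predicted class per node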
Example #10
    def test(self, index):
        """
            Test the output accuracy for the `index` of nodes or `sequence`.

        Note:
        ----------
        You must compile your model before training/testing/predicting.
        Use `model.build()`.

        Parameters:
        ----------
        index: Numpy array-like, `list`, Integer scalar or `graphgallery.Sequence`
            The index of nodes (or sequence) that will be tested.


        Return:
        ----------
        loss: Float scalar
            Output loss of forward propagation.
        accuracy: Float scalar
            Output accuracy of prediction.
        """

        # TODO record test logs like self.train()
        if not self.model:
            raise RuntimeError(
                'You must compile your model before training/testing/predicting. Use `model.build()`.'
            )

        if isinstance(index, Sequence):
            test_data = index
        else:
            index = asintarr(index)
            test_data = self.test_sequence(index)
            self.idx_test = index

        loss, accuracy = self.test_step(test_data)

        return loss, accuracy
Example #11
    def predict(self, index):
        index = T.asintarr(index)
        mask = T.indices2mask(index, self.graph.n_nodes)

        orders_dict = {idx: order for order, idx in enumerate(index)}
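        # `orders_dict` maps each queried node to its position in `index`, so the
        # per-cluster outputs below can be scattered back into the caller's order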
        batch_idx, orders = [], []
        batch_x, batch_adj = [], []
        for cluster in range(self.n_clusters):
            nodes = self.cluster_member[cluster]
            mini_mask = mask[nodes]
            batch_nodes = np.asarray(nodes)[mini_mask]
            if batch_nodes.size == 0:
                continue
            batch_x.append(self.batch_x[cluster])
            batch_adj.append(self.batch_adj[cluster])
            batch_idx.append(np.where(mini_mask)[0])
            orders.append([orders_dict[n] for n in batch_nodes])

        batch_data = tuple(zip(batch_x, batch_adj, batch_idx))

        logit = np.zeros((index.size, self.graph.n_classes), dtype=self.floatx)
        batch_data = T.astensors(batch_data, device=self.device)

        model = self.model
        if self.kind == "P":
            model.eval()
            with torch.no_grad():
                for order, inputs in zip(orders, batch_data):
                    output = model(inputs).detach().cpu().numpy()
                    logit[order] = output
        else:
            with tf.device(self.device):
                for order, inputs in zip(orders, batch_data):
                    output = model.predict_on_batch(inputs)
                    logit[order] = output

        return logit
Example #12
    def train_sequence(self, index):
        index = T.asintarr(index)
        mask = T.indices2mask(index, self.graph.n_nodes)
        labels = self.graph.labels

        batch_idx, batch_labels = [], []
        batch_x, batch_adj = [], []
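        # iterate over the pre-computed clusters (Cluster-GCN-style), keeping
        # only those that contain nodes from `index`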
        for cluster in range(self.n_clusters):
            nodes = self.cluster_member[cluster]
            mini_mask = mask[nodes]
            mini_labels = labels[nodes][mini_mask]
            if mini_labels.size == 0:
                continue
            batch_x.append(self.batch_x[cluster])
            batch_adj.append(self.batch_adj[cluster])
            batch_idx.append(np.where(mini_mask)[0])
            batch_labels.append(mini_labels)

        batch_data = tuple(zip(batch_x, batch_adj, batch_idx))

        sequence = MiniBatchSequence(batch_data,
                                     batch_labels,
                                     device=self.device)
        return sequence
Example #13
    def train(self,
              idx_train,
              idx_val=None,
              epochs=200,
              early_stopping=None,
              verbose=0,
              save_best=True,
              weight_path=None,
              as_model=False,
              monitor='val_acc',
              early_stop_metric='val_loss',
              callbacks=None,
              **kwargs):
        """Train the model for the input `idx_train` of nodes or `sequence`.

        Note:
        ----------
        You must compile your model before training/testing/predicting. Use `model.build()`.

        Parameters:
        ----------
        idx_train: Numpy array-like, `list`, Integer scalar or `graphgallery.Sequence`
            The index of nodes (or sequence) that will be used during training.
        idx_val: Numpy array-like, `list`, Integer scalar or
            `graphgallery.Sequence`, optional
            The index of nodes (or sequence) that will be used for validation.
            (default :obj: `None`, i.e., do not use validation during training)
        epochs: Positive integer
            The number of epochs of training. (default :obj: `200`)
        early_stopping: Positive integer or None
            The number of early stopping patience during training. (default :obj: `None`,
            i.e., do not use early stopping during training)
        verbose: int in {0, 1, 2, 3, 4}
                'verbose=0': not verbose; 
                'verbose=1': Progbar (one line, detailed); 
                'verbose=2': Progbar (one line, omitted); 
                'verbose=3': Progbar (multi line, detailed); 
                'verbose=4': Progbar (multi line, omitted); 
            (default :obj: 0)
        save_best: bool
            Whether to save the best weights (accuracy or loss, depending on `monitor`)
            from training or validation (depending on whether `idx_val` is given).
            (default :bool: `True`)
        weight_path: String or None
            The path of saved weights/model. (default :obj: `None`, i.e.,
            `./log/{self.name}_weights`)
        as_model: bool
            Whether to save the whole model or weights only. If `True`, `self.custom_objects`
            must be specified if you are using a custom `layer`, `loss` and so on.
        monitor: String
            One of (val_loss, val_acc, loss, acc), it determines which metric will be
            used for `save_best`. (default :obj: `val_acc`)
        early_stop_metric: String
            One of (val_loss, val_acc, loss, acc), it determines which metric will be
            used for early stopping. (default :obj: `val_loss`)
        callbacks: tensorflow.keras.callbacks. (default :obj: `None`)
        kwargs: other keyword Parameters.

        Return:
        ----------
        A `tf.keras.callbacks.History` object. Its `History.history` attribute is
            a record of training loss values and metrics values
            at successive epochs, as well as validation loss values
            and validation metrics values (if applicable).

        """
        raise_if_kwargs(kwargs)
        if not (isinstance(verbose, int) and 0 <= verbose <= 4):
            raise ValueError("'verbose=0': not verbose"
                             "'verbose=1': Progbar(one line, detailed), "
                             "'verbose=2': Progbar(one line, omitted), "
                             "'verbose=3': Progbar(multi line, detailed), "
                             "'verbose=4': Progbar(multi line, omitted), "
                             f"but got {verbose}")
        model = self.model
        # Check if model has been built
        if model is None:
            raise RuntimeError(
                'You must compile your model before training/testing/predicting. Use `model.build()`.'
            )

        if isinstance(idx_train, Sequence):
            train_data = idx_train
        else:
            idx_train = asintarr(idx_train)
            train_data = self.train_sequence(idx_train)
            self.idx_train = idx_train

        validation = idx_val is not None

        if validation:
            if isinstance(idx_val, Sequence):
                val_data = idx_val
            else:
                idx_val = asintarr(idx_val)
                val_data = self.test_sequence(idx_val)
                self.idx_val = idx_val
        else:
            monitor = 'acc' if monitor[:3] == 'val' else monitor

        if not isinstance(callbacks, callbacks_module.CallbackList):
            callbacks = callbacks_module.CallbackList(callbacks)

        history = History()
        callbacks.append(history)

        if early_stopping:
            es_callback = EarlyStopping(monitor=early_stop_metric,
                                        patience=early_stopping,
                                        mode='auto',
                                        verbose=kwargs.pop('es_verbose', 1))
            callbacks.append(es_callback)

        if save_best:
            if not weight_path:
                weight_path = self.weight_path
            else:
                self.weight_path = weight_path

            makedirs_from_filename(weight_path)

            if not weight_path.endswith(POSTFIX):
                weight_path = weight_path + POSTFIX

            mc_callback = ModelCheckpoint(weight_path,
                                          monitor=monitor,
                                          save_best_only=True,
                                          save_weights_only=not as_model,
                                          verbose=0)
            callbacks.append(mc_callback)

        callbacks.set_model(model)
        model.stop_training = False
        callbacks.on_train_begin()

        if verbose:
            stateful_metrics = {"acc", 'loss', 'val_acc', 'val_loss', 'time'}
            if verbose <= 2:
                progbar = Progbar(target=epochs,
                                  verbose=verbose,
                                  stateful_metrics=stateful_metrics)
            print("Training...")

        begin_time = time.perf_counter()
        try:
            for epoch in range(epochs):
                if verbose > 2:
                    progbar = Progbar(target=len(train_data),
                                      verbose=verbose - 2,
                                      stateful_metrics=stateful_metrics)

                callbacks.on_epoch_begin(epoch)
                callbacks.on_train_batch_begin(0)
                loss, accuracy = self.train_step(train_data)

                training_logs = {'loss': loss, 'acc': accuracy}
                if validation:
                    val_loss, val_accuracy = self.test_step(val_data)
                    training_logs.update({
                        'val_loss': val_loss,
                        'val_acc': val_accuracy
                    })
                    val_data.on_epoch_end()

                callbacks.on_train_batch_end(len(train_data), training_logs)
                callbacks.on_epoch_end(epoch, training_logs)

                train_data.on_epoch_end()

                if verbose:
                    time_passed = time.perf_counter() - begin_time
                    training_logs.update({'time': time_passed})
                    if verbose > 2:
                        print(f"Epoch {epoch+1}/{epochs}")
                        progbar.update(len(train_data), training_logs.items())
                    else:
                        progbar.update(epoch + 1, training_logs.items())

                if model.stop_training:
                    break

        finally:
            callbacks.on_train_end()
            # to avoid unexpected termination of the model
            if save_best:
                self.load(weight_path, as_model=as_model)
                self.remove_weights()

        return history
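
A hypothetical end-to-end call for the method above (`model`, `idx_train` and `idx_val` are placeholders):

    history = model.train(idx_train, idx_val,
                          epochs=100, early_stopping=10,
                          verbose=1, save_best=True)
    print(max(history.history['val_acc']))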
Example #14
    def train_v2(self,
                 idx_train,
                 idx_val=None,
                 epochs=200,
                 early_stopping=None,
                 verbose=False,
                 save_best=True,
                 weight_path=None,
                 as_model=False,
                 monitor='val_acc',
                 early_stop_metric='val_loss',
                 callbacks=None,
                 **kwargs):
        """
            Train the model for the input `idx_train` of nodes or `sequence`.

        Note:
        ----------
        You must compile your model before training/testing/predicting. Use `model.build()`.

        Parameters:
        ----------
        idx_train: Numpy array-like, `list`, Integer scalar or
            `graphgallery.Sequence`.
            The index of nodes (or sequence) that will be used during training.
        idx_val: Numpy array-like, `list`, Integer scalar or
            `graphgallery.Sequence`, optional
            The index of nodes (or sequence) that will be used for validation.
            (default :obj: `None`, i.e., do not use validation during training)
        epochs: Positive integer
            The number of epochs of training. (default :obj: `200`)
        early_stopping: Positive integer or None
            The number of early stopping patience during training. (default :obj: `None`,
            i.e., do not use early stopping during training)
        verbose: bool
            Whether to show the training details. (default :obj: `False`)
        save_best: bool
            Whether to save the best weights (accuracy or loss, depending on `monitor`)
            from training or validation (depending on whether `idx_val` is given).
            (default :bool: `True`)
        weight_path: String or None
            The path of saved weights/model. (default :obj: `None`, i.e.,
            `./log/{self.name}_weights`)
        as_model: bool
            Whether to save the whole model or weights only. If `True`, `self.custom_objects`
            must be specified if you are using a custom `layer`, `loss` and so on.
        monitor: String
            One of (val_loss, val_acc, loss, acc), it determines which metric will be
            used for `save_best`. (default :obj: `val_acc`)
        early_stop_metric: String
            One of (val_loss, val_acc, loss, acc), it determines which metric will be
            used for early stopping. (default :obj: `val_loss`)
        callbacks: tensorflow.keras.callbacks. (default :obj: `None`)
        kwargs: other keyword Parameters.

        Return:
        ----------
        A `tf.keras.callbacks.History` object. Its `History.history` attribute is
            a record of training loss values and metrics values
            at successive epochs, as well as validation loss values
            and validation metrics values (if applicable).
        """

        # compare version components numerically; a plain string comparison
        # would mis-order e.g. '2.10.0' and '2.2.0'
        if tuple(map(int, tf.__version__.split('.')[:2])) < (2, 2):
            raise RuntimeError(
                'This method only works with TensorFlow >= 2.2.0.')

        # Check if model has been built
        if self.model is None:
            raise RuntimeError(
                'You must compile your model before training/testing/predicting. Use `model.build()`.'
            )

        if isinstance(idx_train, Sequence):
            train_data = idx_train
        else:
            idx_train = asintarr(idx_train)
            train_data = self.train_sequence(idx_train)
            self.idx_train = idx_train

        validation = idx_val is not None

        if validation:
            if isinstance(idx_val, Sequence):
                val_data = idx_val
            else:
                idx_val = asintarr(idx_val)
                val_data = self.test_sequence(idx_val)
                self.idx_val = idx_val
        else:
            monitor = 'acc' if monitor[:3] == 'val' else monitor

        model = self.model
        if not isinstance(callbacks, callbacks_module.CallbackList):
            callbacks = callbacks_module.CallbackList(callbacks,
                                                      add_history=True,
                                                      add_progbar=True,
                                                      verbose=verbose,
                                                      epochs=epochs)
        if early_stopping:
            es_callback = EarlyStopping(monitor=early_stop_metric,
                                        patience=early_stopping,
                                        mode='auto',
                                        verbose=kwargs.pop('es_verbose', 0))
            callbacks.append(es_callback)

        if save_best:
            if not weight_path:
                weight_path = self.weight_path

            makedirs_from_path(weight_path)

            if not weight_path.endswith('.h5'):
                weight_path += '.h5'

            mc_callback = ModelCheckpoint(weight_path,
                                          monitor=monitor,
                                          save_best_only=True,
                                          save_weights_only=not as_model,
                                          verbose=0)
            callbacks.append(mc_callback)
        callbacks.set_model(model)

        # leave it blank for the future
        allowed_kwargs = set()
        unknown_kwargs = set(kwargs.keys()) - allowed_kwargs
        if unknown_kwargs:
            raise TypeError("Invalid keyword argument(s): %s" %
                            (unknown_kwargs, ))

        callbacks.on_train_begin()

        for epoch in range(epochs):
            callbacks.on_epoch_begin(epoch)

            callbacks.on_train_batch_begin(0)
            loss, accuracy = self.train_step(train_data)
            train_data.on_epoch_end()

            training_logs = {'loss': loss, 'acc': accuracy}
            callbacks.on_train_batch_end(0, training_logs)

            if validation:

                val_loss, val_accuracy = self.test_step(val_data)
                training_logs.update({
                    'val_loss': val_loss,
                    'val_acc': val_accuracy
                })
                val_data.on_epoch_end()

            callbacks.on_epoch_end(epoch, training_logs)

            if model.stop_training:
                break

        callbacks.on_train_end()

        if save_best:
            self.load(weight_path, as_model=as_model)
            remove_tf_weights(weight_path)

        return model.history
Example #15
 def test(self, index):
     index = asintarr(index)
     y_true = self.graph.labels[index]
     y_pred = self.classifier.predict(self.embeddings[index])
     accuracy = accuracy_score(y_true, y_pred)
     return accuracy
Example #16
 def predict(self, index):
     index = asintarr(index)
     logit = self.classifier.predict_proba(self.embeddings[index])
     return logit
Example #17
    def train(self,
              idx_train,
              idx_val=None,
              epochs=200,
              early_stopping=None,
              verbose=0,
              save_best=True,
              weight_path=None,
              as_model=False,
              monitor='val_acc',
              early_stop_metric='val_loss',
              callbacks=None,
              **kwargs):
        """Train the model for the input `idx_train` of nodes or `sequence`.

        Note:
        ----------
        You must compile your model before training/testing/predicting. Use `model.build()`.

        Parameters:
        ----------
        idx_train: Numpy array-like, `list`, Integer scalar or `graphgallery.Sequence`
            The index of nodes (or sequence) that will be used during training.
        idx_val: Numpy array-like, `list`, Integer scalar or
            `graphgallery.Sequence`, optional
            The index of nodes (or sequence) that will be used for validation.
            (default :obj: `None`, i.e., do not use validation during training)
        epochs: Positive integer
            The number of epochs of training. (default :obj: `200`)
        early_stopping: Positive integer or None
            The number of early stopping patience during training. (default :obj: `None`,
            i.e., do not use early stopping during training)
        verbose: int in {0, 1, 2}
                'verbose=0': not verbose;
                'verbose=1': tqdm verbose;
                'verbose=2': TensorFlow Progbar verbose;
            (default :obj: 0)
        save_best: bool
            Whether to save the best weights (accuracy or loss, depending on `monitor`)
            from training or validation (depending on whether `idx_val` is given).
            (default :bool: `True`)
        weight_path: String or None
            The path of saved weights/model. (default :obj: `None`, i.e.,
            `./log/{self.name}_weights`)
        as_model: bool
            Whether to save the whole model or weights only. If `True`, `self.custom_objects`
            must be specified if you are using a custom `layer`, `loss` and so on.
        monitor: String
            One of (val_loss, val_acc, loss, acc), it determines which metric will be
            used for `save_best`. (default :obj: `val_acc`)
        early_stop_metric: String
            One of (val_loss, val_acc, loss, acc), it determines which metric will be
            used for early stopping. (default :obj: `val_loss`)
        callbacks: tensorflow.keras.callbacks. (default :obj: `None`)
        kwargs: other keyword Parameters.

        Return:
        ----------
        A `tf.keras.callbacks.History` object. Its `History.history` attribute is
            a record of training loss values and metrics values
            at successive epochs, as well as validation loss values
            and validation metrics values (if applicable).

        """
        if verbose not in {0, 1, 2}:
            raise ValueError(
                "'verbose=0': not verbose; 'verbose=1': tqdm verbose; "
                "'verbose=2': TensorFlow Progbar verbose; "
                f"but got {verbose}")
        model = self.model
        # Check if model has been built
        if model is None:
            raise RuntimeError(
                'You must compile your model before training/testing/predicting. Use `model.build()`.'
            )

        # TODO: add metric names in `model`
        metric_names = ['loss', 'acc']
        callback_metrics = metric_names
        model.stop_training = False

        if isinstance(idx_train, Sequence):
            train_data = idx_train
        else:
            idx_train = asintarr(idx_train)
            train_data = self.train_sequence(idx_train)
            self.idx_train = idx_train

        validation = idx_val is not None

        if validation:
            if isinstance(idx_val, Sequence):
                val_data = idx_val
            else:
                idx_val = asintarr(idx_val)
                val_data = self.test_sequence(idx_val)
                self.idx_val = idx_val
            callback_metrics = copy.copy(metric_names)
            callback_metrics += ['val_' + n for n in metric_names]
        else:
            monitor = 'acc' if monitor[:3] == 'val' else monitor

        if not isinstance(callbacks, callbacks_module.CallbackList):
            callbacks = callbacks_module.CallbackList(callbacks)

        history = tf_History()
        callbacks.append(history)

        if verbose == 2:
            callbacks.append(ProgbarLogger(stateful_metrics=metric_names[1:]))

        if early_stopping:
            es_callback = EarlyStopping(monitor=early_stop_metric,
                                        patience=early_stopping,
                                        mode='auto',
                                        verbose=kwargs.pop('es_verbose', 1))
            callbacks.append(es_callback)

        if save_best:
            if not weight_path:
                weight_path = self.weight_path

            makedirs_from_path(weight_path)

            if not weight_path.endswith('.h5'):
                weight_path = weight_path + '.h5'

            mc_callback = ModelCheckpoint(weight_path,
                                          monitor=monitor,
                                          save_best_only=True,
                                          save_weights_only=not as_model,
                                          verbose=0)
            callbacks.append(mc_callback)

        callbacks.set_model(model)
        # TODO: to be improved
        callback_params = {
            'batch_size': None,
            'epochs': epochs,
            'steps': 1,
            'samples': 1,
            'verbose': verbose == 2,
            'do_validation': validation,
            'metrics': callback_metrics,
        }
        callbacks.set_params(callback_params)
        raise_if_kwargs(kwargs)

        callbacks.on_train_begin()

        if verbose == 1:
            pbar = tqdm(range(1, epochs + 1))
        else:
            pbar = range(epochs)

        for epoch in pbar:
            callbacks.on_epoch_begin(epoch)

            callbacks.on_train_batch_begin(0)
            loss, accuracy = self.train_step(train_data)

            training_logs = {'loss': loss, 'acc': accuracy}

            if validation:
                val_loss, val_accuracy = self.test_step(val_data)
                training_logs.update({
                    'val_loss': val_loss,
                    'val_acc': val_accuracy
                })
                val_data.on_epoch_end()
            callbacks.on_train_batch_end(0, training_logs)
            callbacks.on_epoch_end(epoch, training_logs)

            if verbose == 1:
                msg = "<"
                for key, val in training_logs.items():
                    msg += f"{key.title()} = {val:.4f} "
                msg += ">"
                pbar.set_description(msg)
            train_data.on_epoch_end()

            if verbose == 2:
                print()

            if model.stop_training:
                break

        callbacks.on_train_end()

        if save_best:
            self.load(weight_path, as_model=as_model)
            remove_tf_weights(weight_path)

        return history
Example #18
    def train(self, index):
        if self.embeddings is None:
            self.get_embeddings()

        index = asintarr(index)
        self.classifier.fit(self.embeddings[index], self.graph.labels[index])
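
Examples #15, #16 and #18 together form the embedding-based workflow: fit a downstream classifier on precomputed node embeddings, then evaluate or predict. A hypothetical sketch (`model`, `idx_train` and `idx_test` are placeholders):

    model.train(idx_train)            # fit the classifier on the embeddings
    acc = model.test(idx_test)        # accuracy on held-out nodes
    probs = model.predict(idx_test)   # class probabilities, one row per queried node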
Example #19
    def train_v1(self,
                 idx_train,
                 idx_val=None,
                 epochs=200,
                 early_stopping=None,
                 verbose=False,
                 save_best=True,
                 weight_path=None,
                 as_model=False,
                 monitor='val_acc',
                 early_stop_metric='val_loss'):
        """Train the model for the input `idx_train` of nodes or `sequence`.

        Note:
        ----------
        You must compile your model before training/testing/predicting. Use `model.build()`.

        Parameters:
        ----------
        idx_train: Numpy array-like, `list`, Integer scalar or
            `graphgallery.Sequence`.
            The index of nodes (or sequence) that will be used during training.
        idx_val: Numpy array-like, `list`, Integer scalar or
            `graphgallery.Sequence`, optional
            The index of nodes (or sequence) that will be used for validation.
            (default :obj: `None`, i.e., do not use validation during training)
        epochs: integer
            The number of epochs of training. (default :obj: `200`)
        early_stopping: integer or None
            The number of early stopping patience during training. (default :obj: `None`,
            i.e., do not use early stopping during training)
        verbose: bool
            Whether to show the training details. (default :obj: `False`)
        save_best: bool
            Whether to save the best weights (accuracy or loss, depending on `monitor`)
            from training or validation (depending on whether `idx_val` is given).
            (default :bool: `True`)
        weight_path: String or None
            The path of saved weights/model. (default :obj: `None`, i.e.,
            `./log/{self.name}_weights`)
        as_model: bool
            Whether to save the whole model or weights only. If `True`, `self.custom_objects`
            must be specified if you are using a custom `layer`, `loss` and so on.
        monitor: String
            One of (val_loss, val_acc, loss, acc), it determines which metric will be
            used for `save_best`. (default :obj: `val_acc`)
        early_stop_metric: String
            One of (val_loss, val_acc, loss, acc), it determines which metric will be
            used for early stopping. (default :obj: `val_loss`)

        Return:
        ----------
        history: graphgallery.utils.History
            A TensorFlow-like `History` instance.
        """

        # Check if model has been built
        if self.model is None:
            raise RuntimeError(
                'You must compile your model before training/testing/predicting. Use `model.build()`.'
            )

        if isinstance(idx_train, Sequence):
            train_data = idx_train
        else:
            idx_train = asintarr(idx_train)
            train_data = self.train_sequence(idx_train)
            self.idx_train = idx_train

        validation = idx_val is not None

        if validation:
            if isinstance(idx_val, Sequence):
                val_data = idx_val
            else:
                idx_val = asintarr(idx_val)
                val_data = self.test_sequence(idx_val)
                self.idx_val = idx_val
        else:
            monitor = 'acc' if monitor[:3] == 'val' else monitor

        history = History(monitor_metric=monitor,
                          early_stop_metric=early_stop_metric)

        if not weight_path:
            weight_path = self.weight_path

        if not validation:
            history.register_monitor_metric('acc')
            history.register_early_stop_metric('loss')

        if verbose:
            pbar = tqdm(range(1, epochs + 1))
        else:
            pbar = range(1, epochs + 1)

        for epoch in pbar:

            loss, accuracy = self.train_step(train_data)
            train_data.on_epoch_end()

            history.add_results(loss, 'loss')
            history.add_results(accuracy, 'acc')

            if validation:

                val_loss, val_accuracy = self.test_step(val_data)
                val_data.on_epoch_end()
                history.add_results(val_loss, 'val_loss')
                history.add_results(val_accuracy, 'val_acc')

            # record epoch and running times
            history.record_epoch(epoch)

            if save_best and history.save_best:
                self.save(weight_path, as_model=as_model)

            # early stopping
            if early_stopping and history.time_to_early_stopping(
                    early_stopping):
                msg = f'Early stopping with patience {early_stopping}.'
                if verbose:
                    pbar.set_description(msg)
                    pbar.close()
                break

            if verbose:
                msg = f'loss {loss:.2f}, acc {accuracy:.2%}'
                if validation:
                    msg += f', val_loss {val_loss:.2f}, val_acc {val_accuracy:.2%}'
                pbar.set_description(msg)

        if save_best:
            self.load(weight_path, as_model=as_model)
            if self.kind == "T":
                remove_tf_weights(weight_path)
            else:
                remove_torch_weights(weight_path)

        return history