def train_sequence(self, index):
    """Build the GraphSAGE mini-batch training sequence for `index` nodes.

    Parameters:
    ----------
        index: Numpy array-like, `list` or Integer scalar
            Training node indices; normalized with `T.asintarr`.

    Return:
    ----------
        A `SAGEMiniBatchSequence` over `[feature_inputs, structure_inputs, index]`,
        labelled with `self.graph.labels[index]` and forwarding `self.n_samples`.
    """
    node_index = T.asintarr(index)
    inputs = [self.feature_inputs, self.structure_inputs, node_index]
    return SAGEMiniBatchSequence(inputs,
                                 self.graph.labels[node_index],
                                 n_samples=self.n_samples,
                                 device=self.device)
def train_sequence(self, index):
    """Build a full-batch training sequence for `index` nodes.

    `self.structure_inputs` is unpacked into the input list, so each of its
    elements becomes a separate model input between the features and the
    node indices.

    Return:
    ----------
        A `FullBatchNodeSequence` labelled with `self.graph.labels[index]`.
    """
    node_index = T.asintarr(index)
    inputs = [self.feature_inputs, *self.structure_inputs, node_index]
    return FullBatchNodeSequence(inputs,
                                 self.graph.labels[node_index],
                                 device=self.device)
def train_sequence(self, index):
    """Build a full-batch training sequence from node features only.

    The features of the `index` nodes are selected with `tf.gather` for the
    TensorFlow backend (`self.kind == "T"`) and with plain indexing otherwise;
    no graph structure is passed to the sequence.

    Return:
    ----------
        A `FullBatchNodeSequence` over the selected features, labelled with
        `self.graph.labels[index]`.
    """
    index = T.asintarr(index)
    # Index the labels with the integer array itself instead of a framework
    # tensor, so this does not depend on NumPy implicitly converting the
    # tensor (e.g. one living on an accelerator).
    labels = self.graph.labels[index]
    index = T.astensor(index)
    if self.kind == "T":
        feature_inputs = tf.gather(self.feature_inputs, index)
    else:
        feature_inputs = self.feature_inputs[index]
    sequence = FullBatchNodeSequence(feature_inputs, labels, device=self.device)
    return sequence
def test_sequence(self, index):
    """Build the FastGCN evaluation sequence for `index` nodes.

    The `index` rows of `self.structure_inputs` are paired with the full
    feature matrix. `batch_size` and `rank` are both passed as `None` so
    that the sequence uses the full batch.
    """
    node_index = T.asintarr(index)
    structure_rows = self.structure_inputs[node_index]
    # batch_size=None / rank=None -> use full batch
    return FastGCNBatchSequence([self.feature_inputs, structure_rows],
                                self.graph.labels[node_index],
                                batch_size=None,
                                rank=None,
                                device=self.device)
def train_sequence(self, index):
    """Build the SBVAT training sequence for `index` nodes.

    Forwards `self.neighbors` and `self.n_samples` to `SBVATSampleSequence`
    together with `[feature_inputs, structure_inputs, index]` and the labels
    of the `index` nodes.
    """
    node_index = T.asintarr(index)
    inputs = [self.feature_inputs, self.structure_inputs, node_index]
    return SBVATSampleSequence(inputs,
                               self.graph.labels[node_index],
                               neighbors=self.neighbors,
                               n_samples=self.n_samples,
                               device=self.device)
def train_sequence(self, index):
    """Build the FastGCN training sequence on the subgraph induced by `index`.

    The induced adjacency `adj_matrix[index][:, index]` goes through
    `self.adj_transform`. The matching feature rows are selected with
    `tf.gather`. Both are batched according to `self.batch_size` and
    `self.rank`.
    """
    node_index = T.asintarr(index)
    sub_adj = self.adj_transform(self.graph.adj_matrix[node_index][:, node_index])
    sub_features = tf.gather(self.feature_inputs, node_index)
    return FastGCNBatchSequence([sub_features, sub_adj],
                                self.graph.labels[node_index],
                                batch_size=self.batch_size,
                                rank=self.rank,
                                device=self.device)
def test(self, index, verbose=1):
    """Test the output accuracy for the `index` of nodes or `sequence`.

    Note:
    ----------
        You must compile your model before training/testing/predicting.
        Use `model.build()`.

    Parameters:
    ----------
        index: Numpy array-like, `list`, Integer scalar or `graphgallery.Sequence`
            The index of nodes (or sequence) that will be tested.
        verbose: int
            Verbosity passed to the `Progbar`; when truthy, "Testing..." is
            printed first. (default :obj: `1`)

    Return:
    ----------
        loss: Float scalar
            Output loss of forward propagation.
        accuracy: Float scalar
            Output accuracy of prediction.
    """
    if not self.model:
        raise RuntimeError(
            'You must compile your model before training/testing/predicting. Use `model.build()`.'
        )

    if not isinstance(index, Sequence):
        index = asintarr(index)
        test_data = self.test_sequence(index)
    else:
        test_data = index

    self.idx_test = index

    if verbose:
        print("Testing...")

    progbar = Progbar(target=len(test_data),
                      verbose=verbose,
                      stateful_metrics={"test_acc", 'test_loss', 'time'})

    start = time.perf_counter()
    loss, accuracy = self.test_step(test_data)
    elapsed = time.perf_counter() - start

    progbar.update(len(test_data),
                   [('test_loss', loss), ('test_acc', accuracy), ('time', elapsed)])
    return loss, accuracy
def train_sequence(self, index, batch_size=np.inf):
    """Build a training sequence on a subgraph grown around `index` nodes.

    Starting from the training nodes, `get_indice_graph` is applied
    (presumably adding neighbours of the current node set) until the node
    set has at least `self.k` nodes. The sequence then covers that subgraph.
    Its mask marks which subgraph nodes are training nodes, and only those
    nodes contribute labels.

    Parameters:
    ----------
        index: Numpy array-like, `list` or Integer scalar
            Training node indices.
        batch_size: int or `np.inf`
            Forwarded to the first `get_indice_graph` call. (default `np.inf`)

    Raises:
    ----------
        ValueError
            If a growth round adds no nodes while the subgraph is still smaller
            than `self.k`. The original loop would spin forever in that case.
    """
    index = T.asintarr(index)
    mask = T.indices2mask(index, self.graph.n_nodes)
    index = get_indice_graph(self.structure_inputs, index, batch_size)
    while index.size < self.k:
        previous_size = index.size
        index = get_indice_graph(self.structure_inputs, index)
        # No progress means no later round will make progress either
        # (assuming `get_indice_graph` is deterministic for a fixed input).
        if index.size <= previous_size:
            raise ValueError(
                f"Cannot grow the subgraph around the given nodes to "
                f"k={self.k} nodes; it stopped at {index.size} nodes.")

    structure_inputs = self.structure_inputs[index][:, index]
    feature_inputs = self.feature_inputs[index]
    mask = mask[index]
    labels = self.graph.labels[index[mask]]
    sequence = FullBatchNodeSequence(
        [feature_inputs, structure_inputs, mask], labels, device=self.device)
    return sequence
def predict(self, index=None, return_prob=True):
    """Predict the outputs of the model for the input node index.

    Note:
    ----------
        You must compile your model before training/testing/predicting.
        Use `model.build()`.

    Parameters:
    ----------
        index: Numpy 1D array, optional.
            The indices of nodes to predict. If None, all nodes
            (`0 .. self.graph.n_nodes - 1`) are predicted.
        return_prob: bool.
            If True, `softmax` is applied to the model output; otherwise the
            raw logits are returned.

    Return:
    ----------
        The predictions for the requested nodes. The shape is presumably
        (len(index), n_classes), i.e. (n_nodes, n_classes) when `index`
        is None.
    """
    if not self.model:
        raise RuntimeError(
            'You must compile your model before training/testing/predicting. Use `model.build()`.'
        )

    index = (np.arange(self.graph.n_nodes, dtype=intx())
             if index is None else asintarr(index))
    logit = self.predict_step(self.predict_sequence(index))
    return softmax(logit) if return_prob else logit
def test(self, index):
    """Test the output accuracy for the `index` of nodes or `sequence`.

    Note:
    ----------
        You must compile your model before training/testing/predicting.
        Use `model.build()`.

    Parameters:
    ----------
        index: Numpy array-like, `list`, Integer scalar or `graphgallery.Sequence`
            The index of nodes (or sequence) that will be tested. It is stored
            as `self.idx_test`.

    Return:
    ----------
        loss: Float scalar
            Output loss of forward propagation.
        accuracy: Float scalar
            Output accuracy of prediction.
    """
    # TODO record test logs like self.train()
    if not self.model:
        raise RuntimeError(
            'You must compile your model before training/testing/predicting. Use `model.build()`.'
        )

    if not isinstance(index, Sequence):
        index = asintarr(index)
        test_data = self.test_sequence(index)
    else:
        test_data = index

    self.idx_test = index
    loss, accuracy = self.test_step(test_data)
    return loss, accuracy
def predict(self, index):
    """Predict the model outputs for `index` nodes, cluster by cluster.

    For every cluster in `self.cluster_member` that contains at least one
    requested node, that cluster's cached `batch_x`/`batch_adj` are fed to the
    model together with the positions of the requested nodes inside the
    cluster. The outputs are then scattered back to each node's position in
    `index`. Repeated entries in `index` each receive their node's prediction.

    Parameters:
    ----------
        index: Numpy array-like, `list` or Integer scalar
            Nodes to predict.

    Return:
    ----------
        np.ndarray of shape (len(index), self.graph.n_classes) with dtype
        `self.floatx`. Rows are raw model outputs; no softmax is applied here.
    """
    index = T.asintarr(index)
    # Predict each distinct node once. A plain {node: position} map would keep
    # only the last occurrence of a repeated node and leave earlier rows as
    # zeros; `inverse` restores the caller's original (possibly repeated) order.
    unique_index, inverse = np.unique(index, return_inverse=True)
    mask = T.indices2mask(unique_index, self.graph.n_nodes)
    orders_dict = {idx: order for order, idx in enumerate(unique_index)}

    batch_idx, orders = [], []
    batch_x, batch_adj = [], []
    for cluster in range(self.n_clusters):
        nodes = self.cluster_member[cluster]
        mini_mask = mask[nodes]
        batch_nodes = np.asarray(nodes)[mini_mask]
        if batch_nodes.size == 0:
            continue  # no requested node in this cluster
        batch_x.append(self.batch_x[cluster])
        batch_adj.append(self.batch_adj[cluster])
        batch_idx.append(np.where(mini_mask)[0])
        orders.append([orders_dict[n] for n in batch_nodes])

    batch_data = tuple(zip(batch_x, batch_adj, batch_idx))
    logit = np.zeros((unique_index.size, self.graph.n_classes), dtype=self.floatx)
    batch_data = T.astensors(batch_data, device=self.device)

    model = self.model
    if self.kind == "P":
        model.eval()
        with torch.no_grad():
            for order, inputs in zip(orders, batch_data):
                output = model(inputs).detach().cpu().numpy()
                logit[order] = output
    else:
        with tf.device(self.device):
            for order, inputs in zip(orders, batch_data):
                output = model.predict_on_batch(inputs)
                logit[order] = output

    return logit[inverse]
def train_sequence(self, index):
    """Build the Cluster-GCN mini-batch training sequence for `index` nodes.

    Each cluster that contains at least one training node contributes one
    batch: the cluster's cached `batch_x` and `batch_adj`, the positions of
    its training nodes within the cluster, and those nodes' labels. Clusters
    without training nodes are skipped.
    """
    index = T.asintarr(index)
    is_train = T.indices2mask(index, self.graph.n_nodes)
    all_labels = self.graph.labels

    xs, adjs, positions, targets = [], [], [], []
    for cluster in range(self.n_clusters):
        members = self.cluster_member[cluster]
        in_cluster = is_train[members]
        cluster_labels = all_labels[members][in_cluster]
        if not cluster_labels.size:
            continue  # no training node falls in this cluster
        xs.append(self.batch_x[cluster])
        adjs.append(self.batch_adj[cluster])
        positions.append(np.flatnonzero(in_cluster))
        targets.append(cluster_labels)

    batches = tuple(zip(xs, adjs, positions))
    return MiniBatchSequence(batches, targets, device=self.device)
def train(self, idx_train, idx_val=None, epochs=200, early_stopping=None,
          verbose=0, save_best=True, weight_path=None, as_model=False,
          monitor='val_acc', early_stop_metric='val_loss', callbacks=None,
          **kwargs):
    """Train the model for the input `idx_train` of nodes or `sequence`.

    Note:
    ----------
        You must compile your model before training/testing/predicting.
        Use `model.build()`.

    Parameters:
    ----------
        idx_train: Numpy array-like, `list`, Integer scalar or `graphgallery.Sequence`
            The index of nodes (or sequence) that will be used during training.
        idx_val: Numpy array-like, `list`, Integer scalar or `graphgallery.Sequence`, optional
            The index of nodes (or sequence) that will be used for validation.
            (default :obj: `None`, i.e., do not use validation during training)
        epochs: Positive integer
            The number of epochs of training. (default :obj: `200`)
        early_stopping: Positive integer or None
            The number of early stopping patience during training.
            (default :obj: `None`, i.e., do not use early stopping during training)
        verbose: int in {0, 1, 2, 3, 4}
            'verbose=0': not verbose;
            'verbose=1': Progbar (one line, detailed);
            'verbose=2': Progbar (one line, omitted);
            'verbose=3': Progbar (multi line, detailed);
            'verbose=4': Progbar (multi line, omitted);
            (default :obj: 0)
        save_best: bool
            Whether to save the best weights (accuracy or loss, depending on
            `monitor`) of training or validation (depending on whether
            `idx_val` is given). The best weights are reloaded when training
            ends, including when it ends with an exception.
            (default :bool: `True`)
        weight_path: String or None
            The path of saved weights/model. If given, it is also stored as
            `self.weight_path`. (default :obj: `None`, i.e., `self.weight_path`)
        as_model: bool
            Whether to save the whole model or weights only. If `True`,
            `self.custom_objects` must be specified if you are using custom
            `layer` or `loss` and so on.
        monitor: String
            One of (val_loss, val_acc, loss, acc); the metric used for
            `save_best`. Without validation, a `val_*` value falls back to
            `acc`. (default :obj: `val_acc`)
        early_stop_metric: String
            One of (val_loss, val_acc, loss, acc); the metric used for early
            stopping. Without validation, a `val_*` value falls back to
            `loss`. (default :obj: `val_loss`)
        callbacks: tensorflow.keras.callbacks.
            (default :obj: `None`)
        kwargs: other keyword Parameters.
            `es_verbose` sets the verbosity of the early-stopping callback
            (default 1). Any other keyword is passed to `raise_if_kwargs`,
            which presumably rejects it.

    Return:
    ----------
        A `tf.keras.callbacks.History` object. Its `History.history` attribute is
        a record of training loss values and metrics values
        at successive epochs, as well as validation loss values
        and validation metrics values (if applicable).
    """
    # Take out the one supported extra keyword *before* validating the rest;
    # otherwise passing `es_verbose` would be rejected by `raise_if_kwargs`.
    es_verbose = kwargs.pop('es_verbose', 1)
    raise_if_kwargs(kwargs)
    if not (isinstance(verbose, int) and 0 <= verbose <= 4):
        raise ValueError("'verbose=0': not verbose, "
                         "'verbose=1': Progbar(one line, detailed), "
                         "'verbose=2': Progbar(one line, omitted), "
                         "'verbose=3': Progbar(multi line, detailed), "
                         "'verbose=4': Progbar(multi line, omitted), "
                         f"but got {verbose}")

    model = self.model
    # Check if model has been built
    if model is None:
        raise RuntimeError(
            'You must compile your model before training/testing/predicting. Use `model.build()`.'
        )

    if isinstance(idx_train, Sequence):
        train_data = idx_train
    else:
        idx_train = asintarr(idx_train)
        train_data = self.train_sequence(idx_train)
    self.idx_train = idx_train

    validation = idx_val is not None
    if validation:
        if isinstance(idx_val, Sequence):
            val_data = idx_val
        else:
            idx_val = asintarr(idx_val)
            val_data = self.test_sequence(idx_val)
        self.idx_val = idx_val
    else:
        # `val_*` metrics are never logged without validation data; fall back
        # to training metrics so checkpointing and early stopping still work.
        monitor = 'acc' if monitor[:3] == 'val' else monitor
        early_stop_metric = 'loss' if early_stop_metric[:3] == 'val' else early_stop_metric

    if not isinstance(callbacks, callbacks_module.CallbackList):
        callbacks = callbacks_module.CallbackList(callbacks)

    history = History()
    callbacks.append(history)

    if early_stopping:
        es_callback = EarlyStopping(monitor=early_stop_metric,
                                    patience=early_stopping,
                                    mode='auto',
                                    verbose=es_verbose)
        callbacks.append(es_callback)

    if save_best:
        if not weight_path:
            weight_path = self.weight_path
        else:
            self.weight_path = weight_path

        makedirs_from_filename(weight_path)

        if not weight_path.endswith(POSTFIX):
            weight_path = weight_path + POSTFIX

        mc_callback = ModelCheckpoint(weight_path,
                                      monitor=monitor,
                                      save_best_only=True,
                                      save_weights_only=not as_model,
                                      verbose=0)
        callbacks.append(mc_callback)

    callbacks.set_model(model)
    model.stop_training = False
    callbacks.on_train_begin()

    if verbose:
        stateful_metrics = {"acc", 'loss', 'val_acc', 'val_loss', 'time'}
        if verbose <= 2:
            # One-line modes: a single bar over epochs.
            progbar = Progbar(target=epochs,
                              verbose=verbose,
                              stateful_metrics=stateful_metrics)
        print("Training...")

    begin_time = time.perf_counter()
    try:
        for epoch in range(epochs):
            if verbose > 2:
                # Multi-line modes: a fresh bar per epoch over training steps.
                progbar = Progbar(target=len(train_data),
                                  verbose=verbose - 2,
                                  stateful_metrics=stateful_metrics)

            callbacks.on_epoch_begin(epoch)
            callbacks.on_train_batch_begin(0)
            loss, accuracy = self.train_step(train_data)

            training_logs = {'loss': loss, 'acc': accuracy}
            if validation:
                val_loss, val_accuracy = self.test_step(val_data)
                training_logs.update({
                    'val_loss': val_loss,
                    'val_acc': val_accuracy
                })
                val_data.on_epoch_end()

            callbacks.on_train_batch_end(len(train_data), training_logs)
            callbacks.on_epoch_end(epoch, training_logs)
            train_data.on_epoch_end()

            if verbose:
                time_passed = time.perf_counter() - begin_time
                training_logs.update({'time': time_passed})
                if verbose > 2:
                    print(f"Epoch {epoch+1}/{epochs}")
                    progbar.update(len(train_data), training_logs.items())
                else:
                    progbar.update(epoch + 1, training_logs.items())

            if model.stop_training:
                break
    finally:
        callbacks.on_train_end()
        # to avoid unexpected termination of the model: restore the best
        # checkpoint even if the loop above raised or was interrupted
        if save_best:
            self.load(weight_path, as_model=as_model)
            self.remove_weights()

    return history
def train_v2(self, idx_train, idx_val=None, epochs=200, early_stopping=None,
             verbose=False, save_best=True, weight_path=None, as_model=False,
             monitor='val_acc', early_stop_metric='val_loss', callbacks=None,
             **kwargs):
    """Train the model for the input `idx_train` of nodes or `sequence`.

    Requires tensorflow >= 2.2.0.

    Note:
    ----------
        You must compile your model before training/testing/predicting.
        Use `model.build()`.

    Parameters:
    ----------
        idx_train: Numpy array-like, `list`, Integer scalar or `graphgallery.Sequence`.
            The index of nodes (or sequence) that will be used during training.
        idx_val: Numpy array-like, `list`, Integer scalar or `graphgallery.Sequence`, optional
            The index of nodes (or sequence) that will be used for validation.
            (default :obj: `None`, i.e., do not use validation during training)
        epochs: Positive integer
            The number of epochs of training. (default :obj: `200`)
        early_stopping: Positive integer or None
            The number of early stopping patience during training.
            (default :obj: `None`, i.e., do not use early stopping during training)
        verbose: bool
            Whether to show the training details; forwarded to the
            `CallbackList` progress bar. (default :obj: `False`)
        save_best: bool
            Whether to save the best weights (accuracy or loss, depending on
            `monitor`) of training or validation (depending on whether
            `idx_val` is given). (default :bool: `True`)
        weight_path: String or None
            The path of saved weights/model.
            (default :obj: `None`, i.e., `self.weight_path`)
        as_model: bool
            Whether to save the whole model or weights only. If `True`,
            `self.custom_objects` must be specified if you are using
            customized `layer` or `loss` and so on.
        monitor: String
            One of (val_loss, val_acc, loss, acc); the metric used for
            `save_best`. Without validation, a `val_*` value falls back to
            `acc`. (default :obj: `val_acc`)
        early_stop_metric: String
            One of (val_loss, val_acc, loss, acc); the metric used for early
            stopping. Without validation, a `val_*` value falls back to
            `loss`. (default :obj: `val_loss`)
        callbacks: tensorflow.keras.callbacks.
            (default :obj: `None`)
        kwargs: other keyword Parameters.
            Only `es_verbose` (verbosity of the early-stopping callback,
            default 0) is accepted; anything else raises `TypeError`.

    Return:
    ----------
        A `tf.keras.callbacks.History` object. Its `History.history` attribute is
        a record of training loss values and metrics values
        at successive epochs, as well as validation loss values
        and validation metrics values (if applicable).
    """
    import re  # only needed for the version check below

    # Compare versions numerically: as strings, '2.10.0' >= '2.2.0' is False,
    # which wrongly rejected every TensorFlow release from 2.10 on.
    tf_version = tuple(int(part) for part in re.findall(r'\d+', tf.__version__)[:2])
    if tf_version < (2, 2):
        raise RuntimeError(
            'This method is only work for tensorflow version >= 2.2.0.')

    # Check if model has been built
    if self.model is None:
        raise RuntimeError(
            'You must compile your model before training/testing/predicting. Use `model.build()`.'
        )

    if isinstance(idx_train, Sequence):
        train_data = idx_train
    else:
        idx_train = asintarr(idx_train)
        train_data = self.train_sequence(idx_train)
    self.idx_train = idx_train

    validation = idx_val is not None
    if validation:
        if isinstance(idx_val, Sequence):
            val_data = idx_val
        else:
            idx_val = asintarr(idx_val)
            val_data = self.test_sequence(idx_val)
        self.idx_val = idx_val
    else:
        # `val_*` metrics are never logged without validation data; fall back
        # to training metrics so checkpointing and early stopping still work.
        monitor = 'acc' if monitor[:3] == 'val' else monitor
        early_stop_metric = 'loss' if early_stop_metric[:3] == 'val' else early_stop_metric

    model = self.model

    if not isinstance(callbacks, callbacks_module.CallbackList):
        callbacks = callbacks_module.CallbackList(callbacks,
                                                  add_history=True,
                                                  add_progbar=True,
                                                  verbose=verbose,
                                                  epochs=epochs)

    # Consume `es_verbose` regardless of `early_stopping`, so that passing it
    # without early stopping is not reported as an invalid keyword below.
    es_verbose = kwargs.pop('es_verbose', 0)
    if early_stopping:
        es_callback = EarlyStopping(monitor=early_stop_metric,
                                    patience=early_stopping,
                                    mode='auto',
                                    verbose=es_verbose)
        callbacks.append(es_callback)

    if save_best:
        if not weight_path:
            weight_path = self.weight_path
        makedirs_from_path(weight_path)

        if not weight_path.endswith('.h5'):
            weight_path += '.h5'

        mc_callback = ModelCheckpoint(weight_path,
                                      monitor=monitor,
                                      save_best_only=True,
                                      save_weights_only=not as_model,
                                      verbose=0)
        callbacks.append(mc_callback)

    callbacks.set_model(model)

    # leave it blank for the future
    allowed_kwargs = set()
    unknown_kwargs = set(kwargs.keys()) - allowed_kwargs
    if unknown_kwargs:
        raise TypeError("Invalid keyword argument(s): %s" % (unknown_kwargs, ))

    # A previous run may have ended with `stop_training=True` (e.g. after early
    # stopping); without this reset the loop would break after one epoch.
    model.stop_training = False
    callbacks.on_train_begin()

    for epoch in range(epochs):
        callbacks.on_epoch_begin(epoch)
        callbacks.on_train_batch_begin(0)
        loss, accuracy = self.train_step(train_data)
        train_data.on_epoch_end()

        training_logs = {'loss': loss, 'acc': accuracy}
        callbacks.on_train_batch_end(0, training_logs)

        if validation:
            val_loss, val_accuracy = self.test_step(val_data)
            training_logs.update({
                'val_loss': val_loss,
                'val_acc': val_accuracy
            })
            val_data.on_epoch_end()

        callbacks.on_epoch_end(epoch, training_logs)

        if model.stop_training:
            break

    callbacks.on_train_end()

    if save_best:
        self.load(weight_path, as_model=as_model)
        remove_tf_weights(weight_path)

    return model.history
def test(self, index):
    """Return the accuracy of `self.classifier` on the `index` nodes.

    Predictions are made from `self.embeddings[index]` and scored against
    `self.graph.labels[index]` with `accuracy_score`.
    """
    nodes = asintarr(index)
    ground_truth = self.graph.labels[nodes]
    predictions = self.classifier.predict(self.embeddings[nodes])
    return accuracy_score(ground_truth, predictions)
def predict(self, index):
    """Return `self.classifier.predict_proba` on the embeddings of `index` nodes."""
    nodes = asintarr(index)
    return self.classifier.predict_proba(self.embeddings[nodes])
def train(self, idx_train, idx_val=None, epochs=200, early_stopping=None,
          verbose=0, save_best=True, weight_path=None, as_model=False,
          monitor='val_acc', early_stop_metric='val_loss', callbacks=None,
          **kwargs):
    """Train the model for the input `idx_train` of nodes or `sequence`.

    Note:
    ----------
        You must compile your model before training/testing/predicting.
        Use `model.build()`.

    Parameters:
    ----------
        idx_train: Numpy array-like, `list`, Integer scalar or `graphgallery.Sequence`
            The index of nodes (or sequence) that will be used during training.
        idx_val: Numpy array-like, `list`, Integer scalar or `graphgallery.Sequence`, optional
            The index of nodes (or sequence) that will be used for validation.
            (default :obj: `None`, i.e., do not use validation during training)
        epochs: Positive integer
            The number of epochs of training. (default :obj: `200`)
        early_stopping: Positive integer or None
            The number of early stopping patience during training.
            (default :obj: `None`, i.e., do not use early stopping during training)
        verbose: int in {0, 1, 2}
            'verbose=0': not verbose;
            'verbose=1': tqdm verbose;
            'verbose=2': tensorflow probar verbose;
            (default :obj: 0)
        save_best: bool
            Whether to save the best weights (accuracy or loss, depending on
            `monitor`) of training or validation (depending on whether
            `idx_val` is given). (default :bool: `True`)
        weight_path: String or None
            The path of saved weights/model.
            (default :obj: `None`, i.e., `self.weight_path`)
        as_model: bool
            Whether to save the whole model or weights only. If `True`,
            `self.custom_objects` must be specified if you are using
            customized `layer` or `loss` and so on.
        monitor: String
            One of (val_loss, val_acc, loss, acc); the metric used for
            `save_best`. Without validation, a `val_*` value falls back to
            `acc`. (default :obj: `val_acc`)
        early_stop_metric: String
            One of (val_loss, val_acc, loss, acc); the metric used for early
            stopping. Without validation, a `val_*` value falls back to
            `loss`. (default :obj: `val_loss`)
        callbacks: tensorflow.keras.callbacks.
            (default :obj: `None`)
        kwargs: other keyword Parameters.
            `es_verbose` sets the verbosity of the early-stopping callback
            (default 1). Any other keyword is passed to `raise_if_kwargs`,
            which presumably rejects it.

    Return:
    ----------
        A `tf.keras.callbacks.History` object. Its `History.history` attribute is
        a record of training loss values and metrics values
        at successive epochs, as well as validation loss values
        and validation metrics values (if applicable).
    """
    if verbose not in {0, 1, 2}:
        raise ValueError(
            "'verbose=0': not verbose; 'verbose=1': tqdm verbose; "
            "'verbose=2': tensorflow probar verbose; "
            f"but got {verbose}")

    model = self.model
    # Check if model has been built
    if model is None:
        raise RuntimeError(
            'You must compile your model before training/testing/predicting. Use `model.build()`.'
        )

    # TODO: add metric names in `model`
    metric_names = ['loss', 'acc']
    callback_metrics = metric_names
    model.stop_training = False

    if isinstance(idx_train, Sequence):
        train_data = idx_train
    else:
        idx_train = asintarr(idx_train)
        train_data = self.train_sequence(idx_train)
    self.idx_train = idx_train

    validation = idx_val is not None
    if validation:
        if isinstance(idx_val, Sequence):
            val_data = idx_val
        else:
            idx_val = asintarr(idx_val)
            val_data = self.test_sequence(idx_val)
        self.idx_val = idx_val
        callback_metrics = metric_names + ['val_' + n for n in metric_names]
    else:
        # `val_*` metrics are never logged without validation data; fall back
        # to training metrics so checkpointing and early stopping still work.
        monitor = 'acc' if monitor[:3] == 'val' else monitor
        early_stop_metric = 'loss' if early_stop_metric[:3] == 'val' else early_stop_metric

    if not isinstance(callbacks, callbacks_module.CallbackList):
        callbacks = callbacks_module.CallbackList(callbacks)

    history = tf_History()
    callbacks.append(history)

    if verbose == 2:
        callbacks.append(ProgbarLogger(stateful_metrics=metric_names[1:]))

    # Consume `es_verbose` regardless of `early_stopping`, so that passing it
    # without early stopping is not rejected by `raise_if_kwargs` below.
    es_verbose = kwargs.pop('es_verbose', 1)
    if early_stopping:
        es_callback = EarlyStopping(monitor=early_stop_metric,
                                    patience=early_stopping,
                                    mode='auto',
                                    verbose=es_verbose)
        callbacks.append(es_callback)

    if save_best:
        if not weight_path:
            weight_path = self.weight_path
        makedirs_from_path(weight_path)

        if not weight_path.endswith('.h5'):
            weight_path = weight_path + '.h5'

        mc_callback = ModelCheckpoint(weight_path,
                                      monitor=monitor,
                                      save_best_only=True,
                                      save_weights_only=not as_model,
                                      verbose=0)
        callbacks.append(mc_callback)

    callbacks.set_model(model)

    # TODO: to be improved
    callback_params = {
        'batch_size': None,
        'epochs': epochs,
        'steps': 1,
        'samples': 1,
        'verbose': verbose == 2,
        'do_validation': validation,
        'metrics': callback_metrics,
    }
    callbacks.set_params(callback_params)
    raise_if_kwargs(kwargs)

    callbacks.on_train_begin()

    # Callbacks always see 0-based epoch indices (Keras convention); tqdm only
    # wraps the same range for display. Previously verbose=1 passed 1..epochs.
    epoch_range = range(epochs)
    pbar = tqdm(epoch_range) if verbose == 1 else epoch_range

    for epoch in pbar:
        callbacks.on_epoch_begin(epoch)
        callbacks.on_train_batch_begin(0)
        loss, accuracy = self.train_step(train_data)

        training_logs = {'loss': loss, 'acc': accuracy}
        if validation:
            val_loss, val_accuracy = self.test_step(val_data)
            training_logs.update({
                'val_loss': val_loss,
                'val_acc': val_accuracy
            })
            val_data.on_epoch_end()

        callbacks.on_train_batch_end(0, training_logs)
        callbacks.on_epoch_end(epoch, training_logs)

        if verbose == 1:
            msg = "<" + "".join(f"{key.title()} = {val:.4f} "
                                for key, val in training_logs.items()) + ">"
            pbar.set_description(msg)

        train_data.on_epoch_end()

        if verbose == 2:
            print()

        if model.stop_training:
            break

    if verbose == 1:
        # Close explicitly: breaking out of the loop leaves the bar open.
        pbar.close()

    callbacks.on_train_end()

    if save_best:
        self.load(weight_path, as_model=as_model)
        remove_tf_weights(weight_path)

    return history
def train(self, index):
    """Fit `self.classifier` on the embeddings of the `index` nodes.

    If no embeddings are available yet, `self.get_embeddings()` is called
    first (presumably populating `self.embeddings`).
    """
    # `not self.embeddings` raises "truth value ... is ambiguous" for a NumPy
    # array with more than one element, so test for absence explicitly.
    if self.embeddings is None or len(self.embeddings) == 0:
        self.get_embeddings()
    index = asintarr(index)
    self.classifier.fit(self.embeddings[index], self.graph.labels[index])
def train_v1(self, idx_train, idx_val=None, epochs=200, early_stopping=None,
             verbose=False, save_best=True, weight_path=None, as_model=False,
             monitor='val_acc', early_stop_metric='val_loss'):
    """Train the model for the input `idx_train` of nodes or `sequence`.

    Note:
    ----------
        You must compile your model before training/testing/predicting.
        Use `model.build()`.

    Parameters:
    ----------
        idx_train: Numpy array-like, `list`, Integer scalar or `graphgallery.Sequence`.
            The index of nodes (or sequence) that will be used during training.
        idx_val: Numpy array-like, `list`, Integer scalar or `graphgallery.Sequence`, optional
            The index of nodes (or sequence) that will be used for validation.
            (default :obj: `None`, i.e., do not use validation during training)
        epochs: integer
            The number of epochs of training. (default :obj: `200`)
        early_stopping: integer or None
            The number of early stopping patience during training.
            (default :obj: `None`, i.e., do not use early stopping during training)
        verbose: bool
            Whether to show the training details in a tqdm bar.
            (default :obj: `False`)
        save_best: bool
            Whether to save the best weights (accuracy or loss, depending on
            `monitor`) of training or validation (depending on whether
            `idx_val` is given). (default :bool: `True`)
        weight_path: String or None
            The path of saved weights/model.
            (default :obj: `None`, i.e., `self.weight_path`)
        as_model: bool
            Whether to save the whole model or weights only. If `True`,
            `self.custom_objects` must be specified if you are using
            customized `layer` or `loss` and so on.
        monitor: String
            One of (val_loss, val_acc, loss, acc); the metric used for
            `save_best`. Without validation, a `val_*` value falls back to
            `acc`. (default :obj: `val_acc`)
        early_stop_metric: String
            One of (val_loss, val_acc, loss, acc); the metric used for early
            stopping. Without validation, a `val_*` value falls back to
            `loss`. (default :obj: `val_loss`)

    Return:
    ----------
        history: graphgallery.utils.History
            tensorflow like `history` instance.
    """
    # Check if model has been built
    if self.model is None:
        raise RuntimeError(
            'You must compile your model before training/testing/predicting. Use `model.build()`.'
        )

    if isinstance(idx_train, Sequence):
        train_data = idx_train
    else:
        idx_train = asintarr(idx_train)
        train_data = self.train_sequence(idx_train)
    self.idx_train = idx_train

    validation = idx_val is not None
    if validation:
        if isinstance(idx_val, Sequence):
            val_data = idx_val
        else:
            idx_val = asintarr(idx_val)
            val_data = self.test_sequence(idx_val)
        self.idx_val = idx_val
    else:
        # `val_*` results are never recorded without validation data. The old
        # `if validation is None` fallback was dead code because `validation`
        # is a bool, so apply both fallbacks here.
        monitor = 'acc' if monitor[:3] == 'val' else monitor
        early_stop_metric = 'loss' if early_stop_metric[:3] == 'val' else early_stop_metric

    history = History(monitor_metric=monitor,
                      early_stop_metric=early_stop_metric)

    if not weight_path:
        weight_path = self.weight_path

    if verbose:
        pbar = tqdm(range(1, epochs + 1))
    else:
        pbar = range(1, epochs + 1)

    for epoch in pbar:
        loss, accuracy = self.train_step(train_data)
        train_data.on_epoch_end()

        history.add_results(loss, 'loss')
        history.add_results(accuracy, 'acc')

        if validation:
            val_loss, val_accuracy = self.test_step(val_data)
            val_data.on_epoch_end()

            history.add_results(val_loss, 'val_loss')
            history.add_results(val_accuracy, 'val_acc')

        # record epoch and running times
        history.record_epoch(epoch)

        if save_best and history.save_best:
            self.save(weight_path, as_model=as_model)

        # early stopping
        if early_stopping and history.time_to_early_stopping(early_stopping):
            msg = f'Early stopping with patience {early_stopping}.'
            if verbose:
                pbar.set_description(msg)
                pbar.close()
            break

        if verbose:
            msg = f'loss {loss:.2f}, acc {accuracy:.2%}'
            if validation:
                msg += f', val_loss {val_loss:.2f}, val_acc {val_accuracy:.2%}'
            pbar.set_description(msg)

    if save_best:
        self.load(weight_path, as_model=as_model)
        if self.kind == "T":
            remove_tf_weights(weight_path)
        else:
            remove_torch_weights(weight_path)

    return history