def fit(self, trn_data, dev_data, save_dir, batch_size, epochs, run_eagerly=False, logger=None, verbose=True,
        **kwargs):
    """Build the model, train it, and restore the best checkpoint if one was recorded.

    The method captures its arguments as configuration, builds the transform,
    vocabularies, model and datasets, persists config/vocabs/meta to
    ``save_dir``, then delegates the actual optimisation to ``self.train_loop``.
    A ``KeyboardInterrupt`` raised during training is caught so that the
    post-training bookkeeping (history dump, best-weights restore) still runs.

    Args:
        trn_data: Training data, forwarded to ``build_vocab`` and
            ``build_train_dataset``.
        dev_data: Validation data, forwarded to ``build_valid_dataset``.
        save_dir: Directory for logs, config, vocabs, weights and
            ``history.json``. When falsy, ``tempdir_human()`` supplies one.
        batch_size: Batch size used for both datasets and for step counting.
        epochs: Number of epochs passed to ``train_loop``.
        run_eagerly: Not read directly here; it is part of the arguments
            captured by ``_capture_config`` (presumably consumed via
            ``self.config`` by ``build``/``train_loop`` — confirm there).
        logger: Logger to use. When falsy, one named ``'train'`` is created
            under ``save_dir``.
        verbose: Selects ``logging.INFO`` (True) or ``logging.WARN`` (False)
            for the logger created here.
        **kwargs: Extra hyperparameters; captured alongside the named
            arguments.

    Returns:
        Whatever ``get_callback_by_class`` finds for
        ``tf.keras.callbacks.History`` among the built callbacks.
    """
    # Must stay the first statement: locals() here is exactly the call arguments.
    self._capture_config(locals())
    self.transform = self.build_transform(**self.config)
    if not save_dir:
        save_dir = tempdir_human()
    if not logger:
        logger = init_logger(name='train', root_dir=save_dir, level=logging.INFO if verbose else logging.WARN)
    logger.info('Hyperparameter:\n' + self.config.to_json())
    num_examples = self.build_vocab(trn_data, logger)
    # build_vocab may return a falsy count; step counts are then left as None.
    logger.info('Building...')
    train_steps_per_epoch = math.ceil(num_examples / batch_size) if num_examples else None
    self.config.train_steps = train_steps_per_epoch * epochs if num_examples else None
    model, optimizer, loss, metrics = self.build(**merge_dict(self.config, logger=logger, training=True))
    logger.info('Model built:\n' + summary_of_model(self.model))
    self.save_config(save_dir)
    self.save_vocabs(save_dir)
    self.save_meta(save_dir)
    trn_data = self.build_train_dataset(trn_data, batch_size, num_examples)
    dev_data = self.build_valid_dataset(dev_data, batch_size)
    callbacks = self.build_callbacks(save_dir, logger, **self.config)
    # need to know #batches, otherwise progbar crashes
    dev_steps = math.ceil(size_of_dataset(dev_data) / batch_size)
    checkpoint = get_callback_by_class(callbacks, tf.keras.callbacks.ModelCheckpoint)
    timer = Timer()
    try:
        self.train_loop(**merge_dict(self.config, trn_data=trn_data, dev_data=dev_data, epochs=epochs,
                                     num_examples=num_examples,
                                     train_steps_per_epoch=train_steps_per_epoch, dev_steps=dev_steps,
                                     callbacks=callbacks, logger=logger, model=model, optimizer=optimizer,
                                     loss=loss, metrics=metrics, overwrite=True))
    except KeyboardInterrupt:
        print()
        # np.inf instead of np.Inf: the capitalised alias was removed in NumPy 2.0.
        if not checkpoint or checkpoint.best in (np.inf, -np.inf):
            # No checkpoint has recorded a best value yet, so save the current weights.
            self.save_weights(save_dir)
            logger.info('Aborted with model saved')
        else:
            logger.info(f'Aborted with model saved with best {checkpoint.monitor} = {checkpoint.best:.4f}')

    # The value returned to the caller is the History callback, not train_loop's result.
    # noinspection PyTypeChecker
    history: tf.keras.callbacks.History = get_callback_by_class(callbacks, tf.keras.callbacks.History)
    delta_time = timer.stop()
    best_epoch_ago = 0
    if history and hasattr(history, 'epoch'):
        trained_epoch = len(history.epoch)
        logger.info('Trained {} epochs in {}, each epoch takes {}'.
                    format(trained_epoch, delta_time, delta_time / trained_epoch if trained_epoch else delta_time))
        io_util.save_json(history.history, io_util.path_join(save_dir, 'history.json'), cls=io_util.NumpyEncoder)
        # Without a ModelCheckpoint callback there is no monitored metric and nothing to restore.
        monitor_history: List = history.history.get(checkpoint.monitor, None) if checkpoint else None
        # Only trust `best` when it is a value that was actually recorded; otherwise
        # list.index would raise ValueError (e.g. `best` still at its initial +/-inf).
        if monitor_history and checkpoint.best in monitor_history:
            best_epoch_ago = len(monitor_history) - monitor_history.index(checkpoint.best)
            if checkpoint.best != monitor_history[-1]:
                logger.info(f'Restored the best model saved with best '
                            f'{checkpoint.monitor} = {checkpoint.best:.4f} '
                            f'saved {best_epoch_ago} epochs ago')
                self.load_weights(save_dir)  # restore best model
    return history
def evaluate(self, input_path: str, save_dir=None, output=False, batch_size=128, logger: logging.Logger = None,
             callbacks: List[tf.keras.callbacks.Callback] = None, warm_up=True, verbose=True, **kwargs):
    """Evaluate the model on a file and optionally write its predictions.

    Args:
        input_path: Location of the evaluation file; resolved through
            ``get_resource`` before use.
        save_dir: Directory for the log file and the default prediction
            file. Required when ``output`` is truthy.
        output: ``False`` to skip writing predictions, ``True`` to write to
            ``<save_dir>/<name>.predict<ext>``, or a ``str`` path to write to.
        batch_size: Batch size used to build the dataset and count batches.
        logger: Logger for the report. When falsy and ``save_dir`` is given,
            one is created under ``save_dir`` (opened with ``mode='w'``).
        callbacks: Forwarded to ``evaluate_dataset``.
        warm_up: When true, run one ``predict_on_batch`` before timing starts.
        verbose: Selects ``logging.INFO`` (True) or ``logging.WARN`` (False)
            for the logger created here.
        **kwargs: Forwarded to ``evaluate_dataset``.

    Returns:
        A ``(loss, score, speed)`` tuple, where ``speed`` is samples divided
        by the elapsed seconds of the ``evaluate_dataset`` call.

    Raises:
        AssertionError: If ``output`` is truthy but ``save_dir`` is not.
        RuntimeError: If ``output`` is truthy and neither ``bool`` nor ``str``.
    """
    input_path = get_resource(input_path)
    file_prefix, ext = os.path.splitext(input_path)
    # Fall back to a generic name when the path has no base name.
    name = os.path.basename(file_prefix) or 'evaluate'
    if save_dir and not logger:
        logger = init_logger(name=name, root_dir=save_dir, level=logging.INFO if verbose else logging.WARN,
                             mode='w')
    tst_data = self.transform.file_to_dataset(input_path, batch_size=batch_size)
    samples = self.num_samples_in(tst_data)
    num_batches = math.ceil(samples / batch_size)
    if warm_up:
        # Run a single batch before the timer starts so it is excluded from the speed figure.
        for x, _ in tst_data:
            self.model.predict_on_batch(x)
            break
    if output:
        # Kept as an assert so callers see the same exception type as before.
        assert save_dir, 'Must pass save_dir in order to output'
        if isinstance(output, bool):
            output = os.path.join(save_dir, name) + '.predict' + ext
        elif not isinstance(output, str):
            raise RuntimeError('output ({}) must be of type bool or str'.format(repr(output)))
    timer = Timer()
    eval_outputs = self.evaluate_dataset(tst_data, callbacks, output, num_batches, **kwargs)
    # evaluate_dataset may hand back a different output target; use the one it returns.
    loss, score, output = eval_outputs[0], eval_outputs[1], eval_outputs[2]
    delta_time = timer.stop()
    speed = samples / delta_time.delta_seconds

    if logger:
        # Use the first IOBES_F1_TF metric, if any, to append its full report to the log line.
        f1: IOBES_F1_TF = next((metric for metric in self.model.metrics if isinstance(metric, IOBES_F1_TF)), None)
        extra_report = ''
        if f1:
            _, _, extra_report = f1.state.result(full=True, verbose=False)
            extra_report = ' \n' + extra_report
        logger.info('Evaluation results for {} - '
                    'loss: {:.4f} - {} - speed: {:.2f} sample/sec{}'
                    .format(name + ext, loss,
                            format_scores(score) if isinstance(score, dict) else format_metrics(self.model.metrics),
                            speed, extra_report))
    if output:
        # logger is only created when save_dir is set, so it can still be None here.
        if logger:
            logger.info('Saving output to {}'.format(output))
        with open(output, 'w', encoding='utf-8') as out:
            self.evaluate_output(tst_data, out, num_batches, self.model.metrics)

    return loss, score, speed