def build(self, logger, **kwargs):
    """Assemble the model, optimizer, loss and metrics from ``self.config``.

    Builds the transform config, constructs the model, locks vocabularies,
    then creates optimizer/loss/metrics and compiles the model. If the model
    has not been built yet, it is built either by a forward pass on
    ``self.sample_data`` or from the transform's output shapes.

    :param logger: logger forwarded to ``build_metrics``.
    :param kwargs: may carry ``training``, ``loss`` and ``metrics`` overrides.
    :return: tuple of ``(model, optimizer, loss, metrics)``.
    """
    self.transform.build_config()
    self.model = self.build_model(**merge_dict(self.config, training=kwargs.get('training', None),
                                               loss=kwargs.get('loss', None)))
    # Vocabs must not grow once the model's embedding sizes are fixed.
    self.transform.lock_vocabs()
    optimizer = self.build_optimizer(**self.config)
    # If the config carries no 'loss' key, pass an explicit loss=None so
    # build_loss can fall back to its default.
    loss = self.build_loss(
        **self.config if 'loss' in self.config else dict(list(self.config.items()) + [('loss', None)]))
    # allow for different metrics per component; default to accuracy
    metrics = self.build_metrics(**merge_dict(self.config, metrics=kwargs.get('metrics', 'accuracy'),
                                              logger=logger, overwrite=True))
    # Normalize a single Metric instance to a list; other non-list values
    # (e.g. a plain string) are passed through unchanged.
    if not isinstance(metrics, list):
        if isinstance(metrics, tf.keras.metrics.Metric):
            metrics = [metrics]
    if not self.model.built:
        sample_inputs = self.sample_data
        if sample_inputs is not None:
            # A forward pass builds the model implicitly.
            self.model(sample_inputs)
        else:
            # NOTE(review): output_shapes[0] is presumably the X (input) shape
            # spec of the transform — confirm against build_transform.
            if len(self.transform.output_shapes[0]) == 1 and self.transform.output_shapes[0][0] is None:
                x_shape = self.transform.output_shapes[0]
            else:
                x_shape = list(self.transform.output_shapes[0])
                for i, shape in enumerate(x_shape):
                    x_shape[i] = [None] + shape  # batch + X.shape
            self.model.build(input_shape=x_shape)
    self.compile_model(optimizer, loss, metrics)
    return self.model, optimizer, loss, metrics
def fit(self, trn_data, dev_data, save_dir, batch_size, epochs, run_eagerly=False, logger=None, verbose=True,
        **kwargs):
    """Train this component on ``trn_data`` and validate on ``dev_data``.

    Captures the hyper-parameters, builds transform/vocabs/model, runs the
    training loop, persists config/vocabs/weights under ``save_dir`` and
    finally restores the best checkpoint (if one was saved).

    :param trn_data: training data (path or dataset accepted by build_vocab).
    :param dev_data: development data used for validation.
    :param save_dir: output directory; a temp dir is created when falsy.
    :param batch_size: number of samples per batch.
    :param epochs: number of training epochs.
    :param run_eagerly: whether to run the model eagerly (captured into config).
    :param logger: optional logger; one is created under ``save_dir`` if absent.
    :param verbose: controls the created logger's level.
    :return: the ``tf.keras.callbacks.History`` callback, or ``None`` if
        training produced none.
    """
    self._capture_config(locals())
    self.transform = self.build_transform(**self.config)
    if not save_dir:
        save_dir = tempdir_human()
    if not logger:
        logger = init_logger(name='train', root_dir=save_dir, level=logging.INFO if verbose else logging.WARN)
    logger.info('Hyperparameter:\n' + self.config.to_json())
    num_examples = self.build_vocab(trn_data, logger)
    # assert num_examples, 'You forgot to return the number of training examples in your build_vocab'
    logger.info('Building...')
    # train_steps may be unknown (None) when build_vocab doesn't report a size.
    train_steps_per_epoch = math.ceil(num_examples / batch_size) if num_examples else None
    self.config.train_steps = train_steps_per_epoch * epochs if num_examples else None
    model, optimizer, loss, metrics = self.build(**merge_dict(self.config, logger=logger, training=True))
    logger.info('Model built:\n' + summary_of_model(self.model))
    self.save_config(save_dir)
    self.save_vocabs(save_dir)
    self.save_meta(save_dir)
    trn_data = self.build_train_dataset(trn_data, batch_size, num_examples)
    dev_data = self.build_valid_dataset(dev_data, batch_size)
    callbacks = self.build_callbacks(save_dir, logger, **self.config)
    # need to know #batches, otherwise progbar crashes
    dev_steps = math.ceil(size_of_dataset(dev_data) / batch_size)
    checkpoint = get_callback_by_class(callbacks, tf.keras.callbacks.ModelCheckpoint)
    timer = Timer()
    try:
        history = self.train_loop(**merge_dict(self.config, trn_data=trn_data, dev_data=dev_data, epochs=epochs,
                                               num_examples=num_examples,
                                               train_steps_per_epoch=train_steps_per_epoch, dev_steps=dev_steps,
                                               callbacks=callbacks, logger=logger, model=model, optimizer=optimizer,
                                               loss=loss, metrics=metrics, overwrite=True))
    except KeyboardInterrupt:
        print()
        # np.inf (not the NumPy<2.0-only np.Inf alias): an untouched
        # ModelCheckpoint still has its initial +/-inf best value.
        if not checkpoint or checkpoint.best in (np.inf, -np.inf):
            self.save_weights(save_dir)
            logger.info('Aborted with model saved')
        else:
            logger.info(f'Aborted with model saved with best {checkpoint.monitor} = {checkpoint.best:.4f}')
    # noinspection PyTypeChecker
    # Fixed: annotate with the class itself, not an instantiated History().
    history: tf.keras.callbacks.History = get_callback_by_class(callbacks, tf.keras.callbacks.History)
    delta_time = timer.stop()
    best_epoch_ago = 0
    if history and hasattr(history, 'epoch'):
        trained_epoch = len(history.epoch)
        logger.info('Trained {} epochs in {}, each epoch takes {}'.format(
            trained_epoch, delta_time, delta_time / trained_epoch if trained_epoch else delta_time))
        io_util.save_json(history.history, io_util.path_join(save_dir, 'history.json'), cls=io_util.NumpyEncoder)
        # Fixed: guard checkpoint before dereferencing .monitor — no
        # ModelCheckpoint callback means checkpoint is None here.
        monitor_history: List = history.history.get(checkpoint.monitor, None) if checkpoint else None
        if monitor_history:
            best_epoch_ago = len(monitor_history) - monitor_history.index(checkpoint.best)
        if checkpoint and monitor_history and checkpoint.best != monitor_history[-1]:
            logger.info(f'Restored the best model saved with best '
                        f'{checkpoint.monitor} = {checkpoint.best:.4f} '
                        f'saved {best_epoch_ago} epochs ago')
            self.load_weights(save_dir)  # restore best model
    return history
def load(self, save_dir: str, logger=hanlp.utils.log_util.logger, **kwargs):
    """Restore a previously saved component from ``save_dir``.

    Records the original load path in ``self.meta``, resolves the resource
    to a local directory, then loads config, vocabs, weights and meta,
    rebuilding the model (in inference mode) in between.

    :param save_dir: identifier or path of the saved component.
    :param logger: logger passed through to :meth:`build`.
    :param kwargs: extra options forwarded to :meth:`build` and
        :meth:`load_weights`.
    """
    self.meta['load_path'] = save_dir
    # Resolve remote identifiers/URLs to a concrete local directory.
    local_dir = get_resource(save_dir)
    self.load_config(local_dir)
    self.load_vocabs(local_dir)
    self.build(**merge_dict(self.config, training=False, logger=logger, **kwargs, overwrite=True, inplace=True))
    self.load_weights(local_dir, **kwargs)
    self.load_meta(local_dir)
def build_callbacks(self, save_dir, logger, metrics, **kwargs):
    """Create the training callbacks: the superclass set plus a progress bar.

    Every callback is primed with shared params (verbosity, epochs) and the
    current model so it is ready before the training loop starts.

    :param save_dir: directory forwarded to the superclass callbacks.
    :param logger: logger merged into the callback configuration.
    :param metrics: metrics used by the progress bar; a tuple is converted
        to a list first.
    :return: the complete list of configured callbacks.
    """
    merged = merge_dict(self.config, overwrite=True, logger=logger, metrics=metrics, **kwargs)
    callbacks = super().build_callbacks(save_dir, **merged)
    if isinstance(metrics, tuple):
        metrics = list(metrics)
    callbacks.append(self.build_progbar(metrics))
    shared_params = {'verbose': 1, 'epochs': self.config.epochs}
    for cb in callbacks:
        cb.set_params(shared_params)
        cb.set_model(self.model)
    return callbacks