Example #1
0
 def build(self, logger, **kwargs):
     """Build model, optimizer, loss and metrics from ``self.config``.

     Args:
         logger: Logger forwarded to metric construction.
         **kwargs: Optional overrides; ``training``, ``loss`` and ``metrics``
             are forwarded to the respective builders.

     Returns:
         A 4-tuple ``(model, optimizer, loss, metrics)``; the model is also
         stored on ``self.model`` and compiled before returning.
     """
     self.transform.build_config()
     self.model = self.build_model(
         **merge_dict(self.config,
                      training=kwargs.get('training', None),
                      loss=kwargs.get('loss', None)))
     self.transform.lock_vocabs()
     optimizer = self.build_optimizer(**self.config)
     # Guarantee a 'loss' key (defaulting to None) so build_loss can rely on it,
     # without mutating self.config itself.
     loss_config = dict(self.config)
     loss_config.setdefault('loss', None)
     loss = self.build_loss(**loss_config)
     # allow for different metrics per component; default to accuracy
     metrics = self.build_metrics(
         **merge_dict(self.config,
                      metrics=kwargs.get('metrics', 'accuracy'),
                      logger=logger,
                      overwrite=True))
     # Normalize a single Metric instance into the list form compile expects.
     if not isinstance(metrics, list) and isinstance(metrics, tf.keras.metrics.Metric):
         metrics = [metrics]
     if not self.model.built:
         # Build the model eagerly (via a forward pass or explicit shapes) so
         # weights and summaries are available before training starts.
         sample_inputs = self.sample_data
         if sample_inputs is not None:
             self.model(sample_inputs)
         else:
             x_shape = self.transform.output_shapes[0]
             if not (len(x_shape) == 1 and x_shape[0] is None):
                 # Prepend the batch dimension to every feature shape.
                 x_shape = [[None] + shape for shape in x_shape]
             self.model.build(input_shape=x_shape)
     self.compile_model(optimizer, loss, metrics)
     return self.model, optimizer, loss, metrics
Example #2
0
 def fit(self, trn_data, dev_data, save_dir, batch_size, epochs, run_eagerly=False, logger=None, verbose=True,
         **kwargs):
     """Train on ``trn_data`` and validate on ``dev_data``.

     Args:
         trn_data: Training data in a form understood by the transform.
         dev_data: Development data used for validation.
         save_dir: Directory for config, vocabs, weights and history; a
             human-readable temp dir is created when falsy.
         batch_size: Number of samples per batch.
         epochs: Number of training epochs.
         run_eagerly: Captured into config; whether TF should run eagerly.
         logger: Optional logger; one is created under ``save_dir`` if absent.
         verbose: Level of the created logger (INFO when True, WARN otherwise).
         **kwargs: Extra hyperparameters captured into ``self.config``.

     Returns:
         The training ``History``; recovered from the callbacks when training
         is interrupted with Ctrl-C.
     """
     self._capture_config(locals())
     self.transform = self.build_transform(**self.config)
     if not save_dir:
         save_dir = tempdir_human()
     if not logger:
         logger = init_logger(name='train', root_dir=save_dir, level=logging.INFO if verbose else logging.WARN)
     logger.info('Hyperparameter:\n' + self.config.to_json())
     num_examples = self.build_vocab(trn_data, logger)
     # assert num_examples, 'You forgot to return the number of training examples in your build_vocab'
     logger.info('Building...')
     train_steps_per_epoch = math.ceil(num_examples / batch_size) if num_examples else None
     self.config.train_steps = train_steps_per_epoch * epochs if num_examples else None
     model, optimizer, loss, metrics = self.build(**merge_dict(self.config, logger=logger, training=True))
     logger.info('Model built:\n' + summary_of_model(self.model))
     self.save_config(save_dir)
     self.save_vocabs(save_dir)
     self.save_meta(save_dir)
     trn_data = self.build_train_dataset(trn_data, batch_size, num_examples)
     dev_data = self.build_valid_dataset(dev_data, batch_size)
     callbacks = self.build_callbacks(save_dir, logger, **self.config)
     # need to know #batches, otherwise progbar crashes
     dev_steps = math.ceil(size_of_dataset(dev_data) / batch_size)
     checkpoint = get_callback_by_class(callbacks, tf.keras.callbacks.ModelCheckpoint)
     timer = Timer()
     try:
         history = self.train_loop(**merge_dict(self.config, trn_data=trn_data, dev_data=dev_data, epochs=epochs,
                                                num_examples=num_examples,
                                                train_steps_per_epoch=train_steps_per_epoch, dev_steps=dev_steps,
                                                callbacks=callbacks, logger=logger, model=model, optimizer=optimizer,
                                                loss=loss,
                                                metrics=metrics, overwrite=True))
     except KeyboardInterrupt:
         print()
         # np.inf (lowercase) — np.Inf was removed in NumPy 2.0.
         if not checkpoint or checkpoint.best in (np.inf, -np.inf):
             self.save_weights(save_dir)
             logger.info('Aborted with model saved')
         else:
             logger.info(f'Aborted with model saved with best {checkpoint.monitor} = {checkpoint.best:.4f}')
         # noinspection PyTypeChecker
         history: tf.keras.callbacks.History = get_callback_by_class(callbacks, tf.keras.callbacks.History)
     delta_time = timer.stop()
     best_epoch_ago = 0
     if history and hasattr(history, 'epoch'):
         trained_epoch = len(history.epoch)
         logger.info('Trained {} epochs in {}, each epoch takes {}'.
                     format(trained_epoch, delta_time, delta_time / trained_epoch if trained_epoch else delta_time))
         io_util.save_json(history.history, io_util.path_join(save_dir, 'history.json'), cls=io_util.NumpyEncoder)
         # checkpoint may be None when no ModelCheckpoint callback was built
         # (the interrupt branch above relies on that) — guard before
         # dereferencing its attributes.
         monitor_history: List = history.history.get(checkpoint.monitor, None) if checkpoint else None
         if monitor_history:
             best_epoch_ago = len(monitor_history) - monitor_history.index(checkpoint.best)
         if checkpoint and monitor_history and checkpoint.best != monitor_history[-1]:
             logger.info(f'Restored the best model saved with best '
                         f'{checkpoint.monitor} = {checkpoint.best:.4f} '
                         f'saved {best_epoch_ago} epochs ago')
             self.load_weights(save_dir)  # restore best model
     return history
Example #3
0
 def load(self, save_dir: str, logger=hanlp.utils.log_util.logger, **kwargs):
     """Restore a saved component: config, vocabs, model weights and meta.

     Args:
         save_dir: Identifier or path of the saved component; recorded in
             ``self.meta`` before being resolved via ``get_resource``.
         logger: Logger handed through to :meth:`build`.
         **kwargs: Extra options forwarded to :meth:`build` and
             :meth:`load_weights`.
     """
     # Keep the original identifier around before resolving it to a local path.
     self.meta['load_path'] = save_dir
     resolved_dir = get_resource(save_dir)
     self.load_config(resolved_dir)
     self.load_vocabs(resolved_dir)
     build_config = merge_dict(self.config, training=False, logger=logger, **kwargs, overwrite=True, inplace=True)
     self.build(**build_config)
     self.load_weights(resolved_dir, **kwargs)
     self.load_meta(resolved_dir)
Example #4
0
 def build_callbacks(self, save_dir, logger, metrics, **kwargs):
     """Assemble training callbacks, adding a progress bar to the base set.

     Args:
         save_dir: Directory forwarded to the parent implementation.
         logger: Logger merged into the config for the parent implementation.
         metrics: Metric(s) shown by the progress bar; a tuple is converted
             to a list first.
         **kwargs: Extra options merged into ``self.config``.

     Returns:
         The list of callbacks, each configured and bound to ``self.model``.
     """
     merged = merge_dict(self.config, overwrite=True, logger=logger, metrics=metrics, **kwargs)
     callbacks = super().build_callbacks(save_dir, **merged)
     progbar_metrics = list(metrics) if isinstance(metrics, tuple) else metrics
     callbacks.append(self.build_progbar(progbar_metrics))
     shared_params = {'verbose': 1, 'epochs': self.config.epochs}
     for callback in callbacks:
         callback.set_params(shared_params)
         callback.set_model(self.model)
     return callbacks