import os
import logging
from typing import List, Optional

import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

# Project-local metric and weight-averaging callbacks; the exact import path
# is assumed here.
from .callbacks import SPMMetric, NERMetric, SWA


def prepare_callback(self,
                     callback_list,
                     valid_data=None,
                     valid_labels=None,
                     checkpoint_dir=None,
                     model_name=None,
                     swa_model=None):
    """Prepare the callbacks to be applied during training.

    Args:
        callback_list: list of str, each item indicates a callback to apply
            during training. For example, 'earlystopping' means using the
            `EarlyStopping` callback.
        valid_data: list of tokenized (char-level) texts for evaluation.
        valid_labels: label strings of the valid data.
        checkpoint_dir: str, directory in which to save the spm model. Must
            be provided when using the `ModelCheckpoint` or `SWA` callback.
        model_name: str, prefix of the spm model's weights file. Must be
            provided when using the `ModelCheckpoint` or `SWA` callback.
            For example, if checkpoint_dir is 'ckpt' and model_name is
            'model', the weights of the spm model saved by the
            `ModelCheckpoint` callback will be 'ckpt/model.hdf5' and by the
            `SWA` callback will be 'ckpt/model_swa.hdf5'.

    Returns:
        A list of `keras.callbacks.Callback` instances.
    """
    assert not isinstance(callback_list, str)
    callback_list = callback_list or []
    callbacks = []

    # The evaluation metric is only available when validation data is given.
    if valid_data is not None and valid_labels is not None:
        callbacks.append(SPMMetric(self.preprocessor, valid_data, valid_labels))
        add_metric = True
    else:
        add_metric = False

    if 'modelcheckpoint' in callback_list:
        if not add_metric:
            logging.warning('Using `ModelCheckpoint` without validation data is not '
                            'recommended! We will use `loss` (on training data) as monitor.')
        assert checkpoint_dir is not None, \
            '`checkpoint_dir` must be provided when using `ModelCheckpoint` callback'
        assert model_name is not None, \
            '`model_name` must be provided when using `ModelCheckpoint` callback'
        callbacks.append(ModelCheckpoint(
            filepath=os.path.join(checkpoint_dir, f'{model_name}.hdf5'),
            monitor='val_f1' if add_metric else 'loss',
            save_best_only=True,
            save_weights_only=True,
            mode='max' if add_metric else 'min',
            verbose=1))
        logging.info('ModelCheckpoint Callback added')

    if 'earlystopping' in callback_list:
        if not add_metric:
            logging.warning('Using `EarlyStopping` without validation data is not '
                            'recommended! We will use `loss` (on training data) as monitor.')
        callbacks.append(EarlyStopping(
            monitor='val_f1' if add_metric else 'loss',
            mode='max' if add_metric else 'min',
            patience=5,
            verbose=1))
        logging.info('EarlyStopping Callback added')

    if 'swa' in callback_list:
        assert checkpoint_dir is not None, \
            '`checkpoint_dir` must be provided when using `SWA` callback'
        assert model_name is not None, \
            '`model_name` must be provided when using `SWA` callback'
        assert swa_model is not None, \
            '`swa_model` must be provided when using `SWA` callback'
        callbacks.append(SWA(swa_model=swa_model,
                             checkpoint_dir=checkpoint_dir,
                             model_name=model_name,
                             swa_start=5))
        logging.info('SWA Callback added')

    return callbacks
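# Usage sketch (illustrative only): a trainer exposing `prepare_callback`
# would typically wire the returned callbacks into `fit`. `SPMTrainer`,
# `x_valid` and `y_valid` are assumed names, not part of the code above.
#
#   trainer = SPMTrainer(model, preprocessor)
#   callbacks = trainer.prepare_callback(
#       callback_list=['modelcheckpoint', 'earlystopping'],
#       valid_data=x_valid,        # char-tokenized texts
#       valid_labels=y_valid,      # matching label strings
#       checkpoint_dir='ckpt',
#       model_name='spm_model')    # best weights land in 'ckpt/spm_model.hdf5'
#   trainer.model.fit(train_x, train_y, epochs=50, callbacks=callbacks)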
def prepare_callback(self,
                     callback_list: List[str],
                     valid_data: Optional[List[List[str]]] = None,
                     valid_labels: Optional[List[List[str]]] = None,
                     checkpoint_dir: Optional[str] = None,
                     model_name: Optional[str] = None,
                     swa_model: Optional[tf.keras.models.Model] = None) \
        -> List[tf.keras.callbacks.Callback]:
    """Prepare the callbacks to be applied during training.

    Args:
        callback_list: List of str or `keras.callbacks.Callback` instances.
            Each item indicates a callback to be applied during training.
            Currently we support 'modelcheckpoint' for the `ModelCheckpoint`
            callback, 'earlystopping' for the `EarlyStopping` callback and
            'swa' for the `SWA` callback.
        valid_data: Optional List of List of str, can be None. List of
            tokenized (char-level) texts for evaluation, like
            ``[['我', '在', '上', '海', '上', '学'], ...]``.
        valid_labels: Optional List of List of str, can be None. The labels
            of valid_data, usually in BIO or BIOES format, like
            ``[['O', 'O', 'B-LOC', 'I-LOC', 'O', 'O'], ...]``. When
            valid_data and valid_labels are both provided, we will
            automatically add the `NERMetric` callback for evaluation during
            training.
        checkpoint_dir: Optional str, can be None. Directory in which to
            save the ner model. It must be provided when using the
            `ModelCheckpoint` or `SWA` callback, since these callbacks need
            to save the ner model after training.
        model_name: Optional str, can be None. Prefix of the ner model's
            weights file. It must be provided when using the
            `ModelCheckpoint` or `SWA` callback, since these callbacks need
            to save the ner model after training. For example, if
            checkpoint_dir is 'ckpt' and model_name is 'model', the weights
            of the ner model saved by the `ModelCheckpoint` callback will be
            'ckpt/model.hdf5' and by the `SWA` callback will be
            'ckpt/model_swa.hdf5'.
        swa_model: Instance of `tf.keras.models.Model`. The ner model used
            by the `SWA` callback to keep track of the weight average during
            training. It has the same architecture as self.model. Only pass
            it when using the `SWA` callback.

    Returns:
        List of `keras.callbacks.Callback` instances.
    """
    assert not isinstance(callback_list, str)
    callback_list = callback_list or []
    callbacks = []

    # The evaluation metric is only available when validation data is given.
    if valid_data is not None and valid_labels is not None:
        callbacks.append(
            NERMetric(self.preprocessor, valid_data, valid_labels))
        add_metric = True
    else:
        add_metric = False

    if 'modelcheckpoint' in callback_list:
        if not add_metric:
            logging.warning(
                'Using `ModelCheckpoint` without validation data is not '
                'recommended! We will use `loss` (on training data) as monitor.')
        assert checkpoint_dir is not None, \
            '`checkpoint_dir` must be provided when using `ModelCheckpoint` callback'
        assert model_name is not None, \
            '`model_name` must be provided when using `ModelCheckpoint` callback'
        callbacks.append(
            tf.keras.callbacks.ModelCheckpoint(
                filepath=os.path.join(checkpoint_dir, f'{model_name}.hdf5'),
                monitor='val_f1' if add_metric else 'loss',
                save_best_only=True,
                save_weights_only=True,
                mode='max' if add_metric else 'min',
                verbose=1))
        logging.info('ModelCheckpoint Callback added')

    if 'earlystopping' in callback_list:
        if not add_metric:
            logging.warning(
                'Using `EarlyStopping` without validation data is not '
                'recommended! We will use `loss` (on training data) as monitor.')
        callbacks.append(
            tf.keras.callbacks.EarlyStopping(
                monitor='val_f1' if add_metric else 'loss',
                mode='max' if add_metric else 'min',
                patience=5,
                verbose=1))
        logging.info('EarlyStopping Callback added')

    if 'swa' in callback_list:
        assert checkpoint_dir is not None, \
            '`checkpoint_dir` must be provided when using `SWA` callback'
        assert model_name is not None, \
            '`model_name` must be provided when using `SWA` callback'
        assert swa_model is not None, \
            '`swa_model` must be provided when using `SWA` callback'
        callbacks.append(
            SWA(swa_model=swa_model,
                checkpoint_dir=checkpoint_dir,
                model_name=model_name,
                swa_start=5))
        logging.info('SWA Callback added')

    return callbacks
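# SWA usage sketch (illustrative only): the `SWA` callback needs a second
# model with the same architecture to accumulate the running weight average.
# `build_model` and `trainer` are assumed names; `tf.keras.models.clone_model`
# is the standard way to get an architecture-only copy with fresh weights.
#
#   model = build_model()
#   swa_model = tf.keras.models.clone_model(model)  # same graph, new weights
#   callbacks = trainer.prepare_callback(
#       callback_list=['swa'],
#       checkpoint_dir='ckpt',
#       model_name='ner_model',  # averaged weights: 'ckpt/ner_model_swa.hdf5'
#       swa_model=swa_model)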