Example #1
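    # Assumed context (not shown in this snippet): `os` and `logging` are imported at
    # module level, `ModelCheckpoint` and `EarlyStopping` come from `keras.callbacks`,
    # and `SPMMetric` and `SWA` are custom callbacks defined elsewhere in the project.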
    def prepare_callback(self, callback_list, valid_data=None, valid_labels=None,
                         checkpoint_dir=None, model_name=None, swa_model=None):
        """

        Args:
            callback_list: list of str, each item indicate the callback to apply during training.
                       For example, 'earlystopping' means using 'EarlyStopping' callback.
            valid_data: list of tokenized (in char level) texts for evaluation
            valid_labels: labels string of valid data
            checkpoint_dir: str, directory to save spm model, must be provided when using
                            `ModelCheckpoint` or `SWA` callback.
            model_name: str, prefix of spm model's weights file must be provided when using
                        `ModelCheckpoint` or `SWA` callback.
                        For example, if checkpoint_dir is 'ckpt' and model_name is 'model', the
                        weights of spm model saved by `ModelCheckpoint` callback will be
                        'ckpt/model.hdf5' and by `SWA` callback will be 'ckpt/model_swa.hdf5'

        Returns: a list of `keras.callbacks.Callback` instances

        """
        assert not isinstance(callback_list, str)
        callback_list = callback_list or []
        callbacks = []
        if valid_data is not None and valid_labels is not None:
            callbacks.append(SPMMetric(self.preprocessor, valid_data, valid_labels))
            add_metric = True
        else:
            add_metric = False

        if 'modelcheckpoint' in callback_list:
            if not add_metric:
                logging.warning('Using `ModelCheckpoint` without validation data is not '
                                'recommended! We will use `loss` (of training data) as monitor.')

            assert checkpoint_dir is not None, \
                '"checkpoint_dir" must be provided when using "ModelCheckpoint" callback'
            assert model_name is not None, \
                '"model_name" must be provided when using "ModelCheckpoint" callback'
            callbacks.append(ModelCheckpoint(filepath=os.path.join(checkpoint_dir,
                                                                   f'{model_name}.hdf5'),
                                             monitor='val_f1' if add_metric else 'loss',
                                             save_best_only=True,
                                             save_weights_only=True,
                                             mode='max' if add_metric else 'min',
                                             verbose=1))
            logging.info('ModelCheckpoint Callback added')

        if 'earlystopping' in callback_list:
            if not add_metric:
                logging.warning('Using `EarlyStopping` without validation data is not '
                                'recommended! We will use `loss` (of training data) as monitor.')
            callbacks.append(EarlyStopping(monitor='val_f1' if add_metric else 'loss',
                                           mode='max' if add_metric else 'min',
                                           patience=5,
                                           verbose=1))
            logging.info('EarlyStopping Callback added')

        if 'swa' in callback_list:
            assert checkpoint_dir is not None, \
                '"checkpoint_dir" must be provided when using "SWA" callback'
            assert model_name is not None, \
                '"model_name" must be provided when using "SWA" callback'
            assert swa_model is not None, \
                '"swa_model" must be provided when using "SWA" callback'
            callbacks.append(SWA(swa_model=swa_model, checkpoint_dir=checkpoint_dir,
                                 model_name=model_name, swa_start=5))
            logging.info('SWA Callback added')

        return callbacks
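
A minimal usage sketch for the method above. The `trainer` instance, its `model` attribute, and the data variables are assumptions for illustration; only the `prepare_callback` signature comes from the snippet:

    # Hypothetical usage: `trainer` stands in for an instance of the class that
    # defines prepare_callback; train_x/train_y are placeholder training data.
    callbacks = trainer.prepare_callback(
        callback_list=['modelcheckpoint', 'earlystopping'],
        valid_data=valid_texts,     # char-level tokenized texts
        valid_labels=valid_labels,
        checkpoint_dir='ckpt',      # ModelCheckpoint writes 'ckpt/model.hdf5'
        model_name='model')
    trainer.model.fit(train_x, train_y, epochs=50, callbacks=callbacks)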
Example #2
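    # Assumed context (not shown in this snippet): `os`, `logging`, and `tensorflow as tf`
    # are imported at module level, `List` and `Optional` come from `typing`, and
    # `NERMetric` and `SWA` are custom callbacks defined elsewhere in the project.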
    def prepare_callback(self,
                         callback_list: List[str],
                         valid_data: Optional[List[List[str]]] = None,
                         valid_labels: Optional[List[List[str]]] = None,
                         checkpoint_dir: Optional[str] = None,
                         model_name: Optional[str] = None,
                         swa_model: Optional[tf.keras.models.Model] = None) \
            -> List[tf.keras.callbacks.Callback]:
        """Prepare the callbacks to be applied during training.

        Args:
            callback_list: List of str. Each item indicates a callback to be applied during
                training. Currently, we support 'modelcheckpoint' for the `ModelCheckpoint`
                callback, 'earlystopping' for the `EarlyStopping` callback, and 'swa' for
                the `SWA` callback.
            valid_data: Optional List of List of str, can be None. List of tokenized (in char
                level) texts for evaluation, like ``[['我', '在', '上', '海', '上', '学'], ...]``.
            valid_labels: Optional List of List of str, can be None. The labels of valid_data,
                usually in BIO or BIOES format, like
                ``[['O', 'O', 'B-LOC', 'I-LOC', 'O', 'O'], ...]``.
                When valid_data and valid_labels are both provided, we will automatically add
                `NERMetric` callback for evaluation during training.
            checkpoint_dir: Optional str, can be None. Directory to save the ner model. It
                must be provided when using the `ModelCheckpoint` or `SWA` callback, since
                these callbacks need to save the ner model after training.
            model_name: Optional str, can be None. Prefix of the ner model's weights file.
                It must be provided when using the `ModelCheckpoint` or `SWA` callback, since
                these callbacks need to save the ner model after training. For example, if
                checkpoint_dir is 'ckpt' and model_name is 'model', the weights of the ner
                model saved by the `ModelCheckpoint` callback will be 'ckpt/model.hdf5' and
                by the `SWA` callback will be 'ckpt/model_swa.hdf5'.
            swa_model: Instance of `tf.keras.models.Model`. The ner model which is used in `SWA`
                callback to keep track of weight averaging during training. It has the same
                architecture as self.model. Only pass it when using `SWA` callback.

        Returns: List of `keras.callbacks.Callback` instances

        """
        assert not isinstance(callback_list, str)
        callback_list = callback_list or []
        callbacks = []
        if valid_data is not None and valid_labels is not None:
            callbacks.append(
                NERMetric(self.preprocessor, valid_data, valid_labels))
            add_metric = True
        else:
            add_metric = False

        if 'modelcheckpoint' in callback_list:
            if not add_metric:
                logging.warning(
                    'Using `ModelCheckpoint` without validation data is not recommended! '
                    'We will use `loss` (of training data) as monitor.')

            assert checkpoint_dir is not None, \
                '`checkpoint_dir` must be provided when using "ModelCheckpoint" callback'
            assert model_name is not None, \
                '`model_name` must be provided when using "ModelCheckpoint" callback'
            callbacks.append(
                tf.keras.callbacks.ModelCheckpoint(
                    filepath=os.path.join(checkpoint_dir,
                                          f'{model_name}.hdf5'),
                    monitor='val_f1' if add_metric else 'loss',
                    save_best_only=True,
                    save_weights_only=True,
                    mode='max' if add_metric else 'min',
                    verbose=1))
            logging.info('ModelCheckpoint Callback added')

        if 'earlystopping' in callback_list:
            if not add_metric:
                logging.warning(
                    'Using `EarlyStopping` without validation data is not '
                    'recommended! We will use `loss` (of training data) as monitor.'
                )
            callbacks.append(
                tf.keras.callbacks.EarlyStopping(
                    monitor='val_f1' if add_metric else 'loss',
                    mode='max' if add_metric else 'min',
                    patience=5,
                    verbose=1))
            logging.info('EarlyStopping Callback added')

        if 'swa' in callback_list:
            assert checkpoint_dir is not None, \
                '`checkpoint_dir` must be provided when using "SWA" callback'
            assert model_name is not None, \
                '`model_name` must be provided when using "SWA" callback'
            assert swa_model is not None, \
                '`swa_model` must be provided when using "SWA" callback'
            callbacks.append(
                SWA(swa_model=swa_model,
                    checkpoint_dir=checkpoint_dir,
                    model_name=model_name,
                    swa_start=5))
            logging.info('SWA Callback added')

        return callbacks
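
The `SWA` callback used in both examples is project-specific and not shown. Below is a minimal sketch of what a stochastic weight averaging callback with this interface could look like, assuming it averages weights from `swa_start` onward and follows the file naming described in the docstring; this is an illustration, not the project's actual implementation:

    import os

    import tensorflow as tf


    class SWASketch(tf.keras.callbacks.Callback):
        """Hedged sketch of an SWA callback; not the source implementation."""

        def __init__(self, swa_model, checkpoint_dir, model_name, swa_start=5):
            super().__init__()
            self.swa_model = swa_model        # same architecture as the trained model
            self.checkpoint_dir = checkpoint_dir
            self.model_name = model_name
            self.swa_start = swa_start        # first (0-based) epoch to average from
            self.n_averaged = 0

        def on_epoch_end(self, epoch, logs=None):
            if epoch < self.swa_start:
                return
            # Maintain a running average of the trained model's weights.
            self.n_averaged += 1
            if self.n_averaged == 1:
                self.swa_model.set_weights(self.model.get_weights())
            else:
                averaged = [(old * (self.n_averaged - 1) + new) / self.n_averaged
                            for old, new in zip(self.swa_model.get_weights(),
                                                self.model.get_weights())]
                self.swa_model.set_weights(averaged)

        def on_train_end(self, logs=None):
            # File naming convention from the docstring: '<model_name>_swa.hdf5'.
            self.swa_model.save_weights(
                os.path.join(self.checkpoint_dir, f'{self.model_name}_swa.hdf5'))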