Example #1
# NOTE: imports reconstructed from usage in this example; standalone
# Keras 2.x / TF 1.x APIs are assumed.
import logging
import os
import warnings

import numpy as np
import tensorflow as tf
from keras import initializers, losses
from keras.layers import Dense, Dropout, Input
from keras.models import Model
from keras.utils import multi_gpu_model, to_categorical
from sklearn.model_selection import StratifiedShuffleSplit

# Project-local modules (import paths are not shown in this excerpt):
# BertConfig, BertModel, AdamWeightDecayOpt, SampleSequence,
# StepPreTrainModelCheckpoint


class text_classifier(object):
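    """Text classifier built on a pre-trained BERT encoder.

    Wraps the compiled Keras model (or its `multi_gpu_model` replica when
    `multi_gpu` is set) behind a small `fit`/`predict` interface.
    """
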
    def __init__(self,
                 bert_config,
                 pretrain_model,
                 batch_size,
                 seq_length,
                 optimizer,
                 num_classes,
                 use_token_type=True,
                 mask=True,
                 max_predictions_per_seq=20,
                 multi_gpu=None):
        if not isinstance(bert_config, BertConfig):
            raise ValueError(
                "`bert_config` must be an instance of `BertConfig`")
        if multi_gpu:
            if not tf.test.is_gpu_available():
                raise ValueError("GPU is not available.")

        self.config = bert_config
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.use_token_type = use_token_type
        self.max_predictions_per_seq = max_predictions_per_seq
        self.mask = mask
        self.num_classes = num_classes

        if multi_gpu:
            with tf.device('/cpu:0'):
                model = self._build_model(pretrain_model)
                model.compile(optimizer=optimizer,
                              loss=losses.categorical_crossentropy)
            parallel_model = multi_gpu_model(model=model, gpus=multi_gpu)
            parallel_model.compile(optimizer=optimizer,
                                   loss=losses.categorical_crossentropy)
        else:
            model = self._build_model(pretrain_model)
            model.compile(optimizer=optimizer,
                          loss=losses.categorical_crossentropy)

        self.estimator = model
        if multi_gpu:
            self.estimator = parallel_model

    def fit(self,
            x,
            y,
            epochs,
            shuffle=True,
            callbacks=None,
            validation_split=0.,
            validation_data=None,
            class_weight=None,
            sample_weight=None,
            **kwargs):
        self.estimator.fit(x=x,
                           y=y,
                           batch_size=self.batch_size,
                           epochs=epochs,
                           shuffle=shuffle,
                           callbacks=callbacks,
                           validation_split=validation_split,
                           validation_data=validation_data,
                           class_weight=class_weight,
                           sample_weight=sample_weight,
                           **kwargs)

    def predict(self, x, batch_size=None, verbose=0, steps=None):
        result = self.estimator.predict(x=x,
                                        batch_size=batch_size,
                                        verbose=verbose,
                                        steps=steps)
        return result

    def _build_model(self, pretrain_model):
        input_ids = Input(shape=(self.seq_length, ))
        input_mask = Input(shape=(self.seq_length, ))
        inputs = [input_ids, input_mask]
        if self.use_token_type:
            input_token_type_ids = Input(shape=(self.seq_length, ))
            inputs.append(input_token_type_ids)

        self.bert = BertModel(
            self.config,
            batch_size=self.batch_size,
            seq_length=self.seq_length,
            max_predictions_per_seq=self.max_predictions_per_seq,
            use_token_type=self.use_token_type,
            mask=self.mask)
        self.bert_encoder = self.bert.get_bert_encoder()
        self.bert_encoder.load_weights(pretrain_model)
        pooled_output = self.bert_encoder(inputs)
        pooled_output = Dropout(self.config.hidden_dropout_prob)(pooled_output)
        pred = Dense(units=self.num_classes,
                     activation='softmax',
                     kernel_initializer=initializers.truncated_normal(
                         stddev=self.config.initializer_range))(pooled_output)
        model = Model(inputs=inputs, outputs=pred)
        return model
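

# --- A minimal usage sketch for `text_classifier` (not part of the original source). ---
# The file names and arrays below are hypothetical placeholders; `x` must be the
# list [tokens_ids, tokens_mask, segment_ids] of (n, seq_length) arrays expected
# by the encoder, and `y` one-hot labels of shape (n, num_classes).
#
# config = BertConfig.from_json_file('bert_config.json')
# clf = text_classifier(bert_config=config,
#                       pretrain_model='bert_encoder.h5',
#                       batch_size=32,
#                       seq_length=128,
#                       optimizer='adam',
#                       num_classes=2)
# clf.fit(x=[tokens_ids, tokens_mask, segment_ids], y=labels, epochs=3)
# probas = clf.predict(x=[tokens_ids, tokens_mask, segment_ids])

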
def bert_pretraining(train_data_path,
                     bert_config_file,
                     save_path,
                     batch_size=32,
                     epochs=2,
                     seq_length=128,
                     max_predictions_per_seq=20,
                     lr=5e-5,
                     num_warmup_steps=10000,
                     checkpoints_interval_steps=1000,
                     weight_decay_rate=0.01,
                     validation_ratio=0.1,
                     max_num_val=10000,
                     multi_gpu=0,
                     val_batch_size=None,
                     pretraining_model_name='bert_pretraining.h5',
                     encoder_model_name='bert_encoder.h5',
                     random_state=None):
    '''Masked LM / next-sentence prediction pre-training for BERT.

    # Args
        train_data_path: Path to the pre-training data (a NumPy `.npz` archive).
        bert_config_file: The config json file corresponding to the pre-trained BERT model.
            This specifies the model architecture.
        save_path: Directory in which to save checkpoints.
        batch_size: Integer.  Number of samples per gradient update.
        epochs: Integer. Number of epochs to train the model.
        seq_length: The maximum total input sequence length after tokenization.
            Sequences longer than this will be truncated, and sequences shorter
            than this will be padded. Must match data generation.
        max_predictions_per_seq: Integer. Maximum number of masked LM predictions per sequence.
        lr: float >= 0. Learning rate.
        num_warmup_steps: Integer. Number of warm up steps.
        checkpoints_interval_steps: Integer. Interval (in steps) between checkpoints. Checkpointing only starts once the model has warmed up.
        weight_decay_rate: float. value of weight decay rate.
        validation_ratio:  Float between 0 and 1.
            Fraction of the training data to be used as validation data.
            The model will set apart this fraction of the training data,
            will not train on it, and will evaluate
            the loss and any model metrics.
        max_num_val: Integer. Maximum number of validation samples.
            When multi_gpu > 0, the validation data is evaluated on the CPU,
            so capping this value keeps the validation step efficient.
        multi_gpu: Integer. When multi_gpu > 0, the CPU is used to merge the
            model replicas trained on the GPUs.
        val_batch_size: Integer. Number of samples per batch during validation.
            If `val_batch_size` is None, it defaults to `batch_size`.
        pretraining_model_name: Name of the saved pre-training model file.
        encoder_model_name: Name of the saved encoder model file.
        random_state : int, RandomState instance or None, optional (default=None)
            If int, random_state is the seed used by the random number generator;
            If RandomState instance, random_state is the random number generator;
            If None, the random number generator is the RandomState instance used
            by `np.random`.
    '''
    if multi_gpu > 0:
        if not tf.test.is_gpu_available():
            raise ValueError("GPU is not available. Set `multi_gpu` to be 0.")
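    # `train_data_path` is expected to point to a NumPy `.npz` archive
    # holding the pre-training arrays unpacked below.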
    pre_training_data = np.load(train_data_path)
    tokens_ids = pre_training_data['tokens_ids']
    tokens_mask = pre_training_data['tokens_mask']
    segment_ids = pre_training_data['segment_ids']
    is_random_next = pre_training_data['is_random_next']
    masked_lm_positions = pre_training_data['masked_lm_positions']
    masked_lm_label = pre_training_data['masked_lm_labels']

    num_train_samples = int(len(tokens_ids) * (1 - validation_ratio))
    num_train_steps = int(np.ceil(num_train_samples / batch_size)) * epochs

    logging.info('train steps: {}'.format(num_train_steps))
    logging.info('train samples: {}'.format(num_train_samples))
    if num_train_steps < num_warmup_steps + checkpoints_interval_steps:
        raise ValueError(
            "number of train steps must be larger than the sum of"
            " `num_warmup_steps` and `checkpoints_interval_steps`."
            " Enlarge your train data or reduce `batch_size`.")
    warmup_ratio = num_warmup_steps / num_train_steps
    if warmup_ratio > 0.02:
        warnings.warn(
            "model performance is usually better when the number of warmup"
            " steps is about 0.01~0.02 of the total train steps.",
            UserWarning)

    config = BertConfig.from_json_file(bert_config_file)

    num_val = int(len(tokens_ids) * validation_ratio)
    if num_val > max_num_val:
        validation_ratio = max_num_val / len(tokens_ids)
    # split the data into train and validation sets, stratified on `is_random_next`;
    # with n_splits=1 the loop below runs exactly once
    sss = StratifiedShuffleSplit(n_splits=1,
                                 test_size=validation_ratio,
                                 random_state=random_state)
    for train_index, test_index in sss.split(tokens_ids, is_random_next):
        train_tokens_ids, test_tokens_ids = tokens_ids[
            train_index], tokens_ids[test_index]
        train_tokens_mask, test_tokens_mask = tokens_mask[
            train_index], tokens_mask[test_index]
        train_segment_ids, test_segment_ids = segment_ids[
            train_index], segment_ids[test_index]
        train_is_random_next, test_is_random_next = is_random_next[
            train_index], is_random_next[test_index]
        train_masked_lm_positions, test_masked_lm_positions = masked_lm_positions[
            train_index], masked_lm_positions[test_index]
        train_masked_lm_label, test_masked_lm_label = masked_lm_label[
            train_index], masked_lm_label[test_index]
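        # one-hot encode the validation masked-LM labels over the vocabulary;
        # the training labels are presumably expanded on the fly by SampleSequence,
        # which is given `vocab_size` and `max_predictions_per_seq`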
        test_masked_lm_label = test_masked_lm_label.reshape((-1))
        test_masked_lm_label = to_categorical(test_masked_lm_label,
                                              num_classes=config.vocab_size)
        test_masked_lm_label = test_masked_lm_label.reshape(
            (-1, max_predictions_per_seq, config.vocab_size))

    logging.info("build pretraining nnet...")
    adam = AdamWeightDecayOpt(
        lr=lr,
        num_train_steps=num_train_steps,
        num_warmup_steps=num_warmup_steps,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-6,
        weight_decay_rate=weight_decay_rate,
        exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
    bert = BertModel(config,
                     batch_size=batch_size,
                     seq_length=seq_length,
                     max_predictions_per_seq=max_predictions_per_seq,
                     use_token_type=True,
                     embeddings_matrix=None,
                     mask=True)
    if multi_gpu:
        # To avoid OOM errors, build the template model on the CPU so that
        # its weights are hosted in CPU memory.
        with tf.device('/cpu:0'):
            pretraining_model = bert.get_pretraining_model()
            pretraining_model.compile(optimizer=adam,
                                      loss=losses.categorical_crossentropy,
                                      metrics=['acc'],
                                      loss_weights=[0.5, 0.5])
        parallel_pretraining_model = multi_gpu_model(model=pretraining_model,
                                                     gpus=multi_gpu)
        parallel_pretraining_model.compile(
            optimizer=adam,
            loss=losses.categorical_crossentropy,
            metrics=['acc'],
            loss_weights=[0.5, 0.5])
    else:
        pretraining_model = bert.get_pretraining_model()
        pretraining_model.compile(optimizer=adam,
                                  loss=losses.categorical_crossentropy,
                                  metrics=['acc'],
                                  loss_weights=[0.5, 0.5])

    logging.info('training pretraining nnet for {} epochs'.format(epochs))
    train_sample_generator = SampleSequence(
        x=[
            train_tokens_ids, train_tokens_mask, train_segment_ids,
            train_masked_lm_positions
        ],
        y=[train_masked_lm_label, train_is_random_next],
        batch_size=batch_size,
        vocab_size=config.vocab_size,
        max_predictions_per_seq=max_predictions_per_seq)

    checkpoint_model = None
    if multi_gpu:
        checkpoint_model = pretraining_model

    if not os.path.exists(save_path):
        os.mkdir(save_path)

    checkpoint = StepPreTrainModelCheckpoint(
        filepath="%s/%s" % (save_path, pretraining_model_name),
        start_step=num_warmup_steps,
        period=checkpoints_interval_steps,
        save_best_only=True,
        verbose=1,
        val_batch_size=val_batch_size,
        # when using multi_gpu_model, pass the original (template) model here
        model=checkpoint_model)

    estimator = pretraining_model
    if multi_gpu:
        estimator = parallel_pretraining_model

    estimator.fit_generator(
        generator=train_sample_generator,
        epochs=epochs,
        callbacks=[checkpoint],
        shuffle=False,
        validation_data=([
            test_tokens_ids, test_tokens_mask, test_segment_ids,
            test_masked_lm_positions
        ], [test_masked_lm_label, test_is_random_next]),
    )

    pretraining_model.load_weights("%s/%s" %
                                   (save_path, pretraining_model_name))
    bert_model = bert.get_bert_encoder()
    bert_model.save_weights("%s/%s" % (save_path, encoder_model_name))
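

# --- A minimal invocation sketch for `bert_pretraining` (not part of the original source). ---
# The paths below are hypothetical placeholders; the `.npz` archive must contain the
# arrays this function expects (tokens_ids, tokens_mask, segment_ids, is_random_next,
# masked_lm_positions, masked_lm_labels).
#
# bert_pretraining(train_data_path='pretraining_data.npz',
#                  bert_config_file='bert_config.json',
#                  save_path='./checkpoints',
#                  batch_size=32,
#                  epochs=2,
#                  seq_length=128,
#                  max_predictions_per_seq=20,
#                  multi_gpu=0)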