class text_classifier(object):
    """Sequence classifier built on top of a pre-trained BERT encoder.

    The encoder output for the model inputs goes through `Dropout` and a
    softmax `Dense` layer with `num_classes` units; the model is compiled
    with categorical cross-entropy.

    # Args
        bert_config: `BertConfig` instance describing the encoder.
        pretrain_model: weights file loaded into the BERT encoder with
            `load_weights`.
        batch_size: Integer. Batch size used by `fit`; also passed to
            `BertModel`.
        seq_length: Integer. Length of every input sequence.
        optimizer: optimizer passed to `compile`.
        num_classes: Integer. Number of output classes.
        use_token_type: Boolean. If True, the model takes a third input
            (token type ids); also forwarded to `BertModel`.
        mask: forwarded to `BertModel`.
        max_predictions_per_seq: Integer, forwarded to `BertModel`.
        multi_gpu: Integer or None. When truthy, the model is built on the
            CPU and replicated on `multi_gpu` GPUs with `multi_gpu_model`.

    # Attributes
        estimator: compiled model used by `fit` and `predict` (the
            replicated multi-GPU model when `multi_gpu` is set).
        model: compiled single-device (template) model. With `multi_gpu`,
            the Keras docs recommend saving/loading weights through this
            template model rather than through the replicated one.

    # Raises
        ValueError: if `bert_config` is not a `BertConfig`, or if
            `multi_gpu` is set but no GPU is available.
    """

    def __init__(self, bert_config, pretrain_model, batch_size, seq_length,
                 optimizer, num_classes, use_token_type=True, mask=True,
                 max_predictions_per_seq=20, multi_gpu=None):
        if not isinstance(bert_config, BertConfig):
            raise ValueError(
                "`bert_config` must be a instance of `BertConfig`")
        # `is_gpu_available` must be *called*: the bare function object is
        # always truthy, so without the call this check could never fire.
        if multi_gpu and not tf.test.is_gpu_available():
            raise ValueError("GPU is not available.")

        self.config = bert_config
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.use_token_type = use_token_type
        self.max_predictions_per_seq = max_predictions_per_seq
        self.mask = mask
        self.num_classes = num_classes

        if multi_gpu:
            # Build the template model on the CPU; multi_gpu_model places
            # the replicas on the GPUs.
            with tf.device('/cpu:0'):
                model = self._build_model(pretrain_model)
                model.compile(optimizer=optimizer,
                              loss=losses.categorical_crossentropy)
            parallel_model = multi_gpu_model(model=model, gpus=multi_gpu)
            parallel_model.compile(optimizer=optimizer,
                                   loss=losses.categorical_crossentropy)
            # Train/predict with the replicated model (previously it was
            # built and then discarded, so multi-GPU was never used).
            self.estimator = parallel_model
        else:
            model = self._build_model(pretrain_model)
            model.compile(optimizer=optimizer,
                          loss=losses.categorical_crossentropy)
            self.estimator = model
        self.model = model

    def fit(self, x, y, epochs, shuffle=True, callbacks=None,
            validation_split=0., validation_data=None, class_weight=None,
            sample_weight=None, **kwargs):
        """Train `self.estimator` with the batch size given at construction.

        All arguments are forwarded to Keras `Model.fit`.

        # Returns
            The object returned by `Model.fit` (a Keras `History`).
        """
        return self.estimator.fit(x=x,
                                  y=y,
                                  batch_size=self.batch_size,
                                  epochs=epochs,
                                  shuffle=shuffle,
                                  callbacks=callbacks,
                                  validation_split=validation_split,
                                  validation_data=validation_data,
                                  class_weight=class_weight,
                                  sample_weight=sample_weight,
                                  **kwargs)

    def predict(self, x, batch_size=None, verbose=0, steps=None):
        """Return `self.estimator.predict(...)` for `x`.

        Arguments are forwarded unchanged to Keras `Model.predict`.
        """
        result = self.estimator.predict(x=x,
                                        batch_size=batch_size,
                                        verbose=verbose,
                                        steps=steps)
        return result

    def _build_model(self, pretrain_model):
        """Build the (uncompiled) classification model.

        Inputs are `[input_ids, input_mask]`, plus `input_token_type_ids`
        when `use_token_type` is True, each of shape `(seq_length,)`.
        The BERT encoder weights are loaded from `pretrain_model`.
        """
        input_ids = Input(shape=(self.seq_length, ))
        input_mask = Input(shape=(self.seq_length, ))
        inputs = [input_ids, input_mask]
        if self.use_token_type:
            input_token_type_ids = Input(shape=(self.seq_length, ))
            inputs.append(input_token_type_ids)

        self.bert = BertModel(
            self.config,
            batch_size=self.batch_size,
            seq_length=self.seq_length,
            max_predictions_per_seq=self.max_predictions_per_seq,
            use_token_type=self.use_token_type,
            mask=self.mask)
        self.bert_encoder = self.bert.get_bert_encoder()
        self.bert_encoder.load_weights(pretrain_model)

        pooled_output = self.bert_encoder(inputs)
        pooled_output = Dropout(self.config.hidden_dropout_prob)(pooled_output)
        pred = Dense(units=self.num_classes,
                     activation='softmax',
                     kernel_initializer=initializers.truncated_normal(
                         stddev=self.config.initializer_range))(pooled_output)
        model = Model(inputs=inputs, outputs=pred)
        return model
def bert_pretraining(train_data_path,
                     bert_config_file,
                     save_path,
                     batch_size=32,
                     epochs=2,
                     seq_length=128,
                     max_predictions_per_seq=20,
                     lr=5e-5,
                     num_warmup_steps=10000,
                     checkpoints_interval_steps=1000,
                     weight_decay_rate=0.01,
                     validation_ratio=0.1,
                     max_num_val=10000,
                     multi_gpu=0,
                     val_batch_size=None,
                     pretraining_model_name='bert_pretraining.h5',
                     encoder_model_name='bert_encoder.h5',
                     random_state=None):
    '''Masked LM / next sentence prediction pre-training for BERT.

    After training, the best checkpoint is reloaded into the pre-training
    model and the encoder weights are saved to
    `save_path/encoder_model_name`.

    # Args
        train_data_path: path of the `.npz` train data. It must contain the
            arrays `tokens_ids`, `tokens_mask`, `segment_ids`,
            `is_random_next`, `masked_lm_positions` and `masked_lm_labels`.
        bert_config_file: The config json file corresponding to the
            pre-trained BERT model. This specifies the model architecture.
        save_path: dir to save checkpoints. Created if it does not exist.
        batch_size: Integer. Number of samples per gradient update.
        epochs: Integer. Number of epochs to train the model.
        seq_length: The maximum total input sequence length after
            tokenization. Sequences longer than this will be truncated, and
            sequences shorter than this will be padded. Must match data
            generation.
        max_predictions_per_seq: Integer. Maximum number of masked LM
            predictions per sequence.
        lr: float >= 0. Learning rate.
        num_warmup_steps: Integer. Number of warm up steps.
        checkpoints_interval_steps: Integer. Interval of checkpoints. Only
            enabled after the model is warmed up.
        weight_decay_rate: float. Value of weight decay rate.
        validation_ratio: Float between 0 and 1. Fraction of the training
            data to be used as validation data. The model will set apart
            this fraction of the training data, will not train on it, and
            will evaluate the loss and any model metrics on it.
        max_num_val: Integer. Max number of validation samples; the
            validation fraction is reduced so that at most this many samples
            are used. When multi_gpu > 0, the model will use the CPU to
            evaluate validation data, so controlling this argument can
            benefit the efficiency of the validation process.
        multi_gpu: Integer. When multi_gpu > 0, the model is replicated on
            that many GPUs and the CPU is used to merge their results.
        val_batch_size: Integer. Number of samples used in a validation
            step. If `val_batch_size` is None, it will be equal to
            `batch_size`.
        pretraining_model_name: name of pretraining model file.
        encoder_model_name: name of encoder model file.
        random_state : int, RandomState instance or None, optional
            (default=None). If int, random_state is the seed used by the
            random number generator; If RandomState instance, random_state
            is the random number generator; If None, the random number
            generator is the RandomState instance used by `np.random`.

    # Raises
        ValueError: if `multi_gpu > 0` but no GPU is available, or if the
            number of train steps is smaller than
            `num_warmup_steps + checkpoints_interval_steps`.
    '''
    # `is_gpu_available` must be *called*; the bare function object is
    # always truthy, so the check could never fire.
    if multi_gpu > 0 and not tf.test.is_gpu_available():
        raise ValueError("GPU is not available. Set `multi_gpu` to be 0.")

    # Copy every array out of the archive, then close it: np.load keeps the
    # .npz file handle open until the NpzFile is closed.
    with np.load(train_data_path) as pre_training_data:
        tokens_ids = pre_training_data['tokens_ids']
        tokens_mask = pre_training_data['tokens_mask']
        segment_ids = pre_training_data['segment_ids']
        is_random_next = pre_training_data['is_random_next']
        masked_lm_positions = pre_training_data['masked_lm_positions']
        masked_lm_label = pre_training_data['masked_lm_labels']

    num_samples = len(tokens_ids)
    # Cap the validation set *before* counting train steps; otherwise the
    # warmup check and the learning-rate schedule are sized for fewer steps
    # than are actually trained.
    if int(num_samples * validation_ratio) > max_num_val:
        validation_ratio = max_num_val / num_samples
    num_train_samples = int(num_samples * (1 - validation_ratio))
    num_train_steps = int(np.ceil(num_train_samples / batch_size)) * epochs
    logging.info('train steps: {}'.format(num_train_steps))
    logging.info('train samples: {}'.format(num_train_samples))

    if num_train_steps < num_warmup_steps + checkpoints_interval_steps:
        raise ValueError(
            "number of train steps ({}) must be at least the sum of"
            " `num_warmup_steps` ({}) and `checkpoints_interval_steps` ({})."
            " Enlarge your train data or reduce batch_size.".format(
                num_train_steps, num_warmup_steps,
                checkpoints_interval_steps))

    warmup_ratio = num_warmup_steps / num_train_steps
    if warmup_ratio > 0.02:
        warnings.warn(
            "model performance may be suitable when warmup steps is 0.01~0.02 of train steps.",
            UserWarning)

    config = BertConfig.from_json_file(bert_config_file)

    # Split into train / validation, stratified on the next-sentence label.
    sss = StratifiedShuffleSplit(n_splits=1,
                                 test_size=validation_ratio,
                                 random_state=random_state)
    train_index, test_index = next(sss.split(tokens_ids, is_random_next))
    train_tokens_ids, test_tokens_ids = (tokens_ids[train_index],
                                         tokens_ids[test_index])
    train_tokens_mask, test_tokens_mask = (tokens_mask[train_index],
                                           tokens_mask[test_index])
    train_segment_ids, test_segment_ids = (segment_ids[train_index],
                                           segment_ids[test_index])
    train_is_random_next, test_is_random_next = (is_random_next[train_index],
                                                 is_random_next[test_index])
    train_masked_lm_positions, test_masked_lm_positions = (
        masked_lm_positions[train_index], masked_lm_positions[test_index])
    train_masked_lm_label, test_masked_lm_label = (
        masked_lm_label[train_index], masked_lm_label[test_index])

    # One-hot encode the validation masked-LM labels to
    # (num_val, max_predictions_per_seq, vocab_size). The training labels
    # are handed raw to SampleSequence, which receives vocab_size and
    # presumably encodes them per batch (TODO confirm).
    test_masked_lm_label = to_categorical(test_masked_lm_label.reshape(-1),
                                          num_classes=config.vocab_size)
    test_masked_lm_label = test_masked_lm_label.reshape(
        (-1, max_predictions_per_seq, config.vocab_size))

    logging.info("build pretraining nnet...")
    adam = AdamWeightDecayOpt(
        lr=lr,
        num_train_steps=num_train_steps,
        num_warmup_steps=num_warmup_steps,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-6,
        weight_decay_rate=weight_decay_rate,
        exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])

    bert = BertModel(config,
                     batch_size=batch_size,
                     seq_length=seq_length,
                     max_predictions_per_seq=max_predictions_per_seq,
                     use_token_type=True,
                     embeddings_matrix=None,
                     mask=True)

    compile_kwargs = dict(optimizer=adam,
                          loss=losses.categorical_crossentropy,
                          metrics=['acc'],
                          loss_weights=[0.5, 0.5])
    if multi_gpu:
        # To avoid OOM errors, build the template model on the CPU; the
        # replicas created by multi_gpu_model run on the GPUs.
        with tf.device('/cpu:0'):
            pretraining_model = bert.get_pretraining_model()
            pretraining_model.compile(**compile_kwargs)
        parallel_pretraining_model = multi_gpu_model(model=pretraining_model,
                                                     gpus=multi_gpu)
        parallel_pretraining_model.compile(**compile_kwargs)
    else:
        pretraining_model = bert.get_pretraining_model()
        pretraining_model.compile(**compile_kwargs)

    logging.info('training pretraining nnet for {} epochs'.format(epochs))
    train_sample_generator = SampleSequence(
        x=[
            train_tokens_ids, train_tokens_mask, train_segment_ids,
            train_masked_lm_positions
        ],
        y=[train_masked_lm_label, train_is_random_next],
        batch_size=batch_size,
        vocab_size=config.vocab_size,
        max_predictions_per_seq=max_predictions_per_seq)

    if val_batch_size is None:
        # Documented default: validate with the training batch size.
        val_batch_size = batch_size

    os.makedirs(save_path, exist_ok=True)
    checkpoint_path = os.path.join(save_path, pretraining_model_name)
    checkpoint = StepPreTrainModelCheckpoint(
        filepath=checkpoint_path,
        start_step=num_warmup_steps,
        period=checkpoints_interval_steps,
        save_best_only=True,
        verbose=1,
        val_batch_size=val_batch_size,
        # With multi_gpu_model, checkpoint the original (template) model.
        model=pretraining_model if multi_gpu else None)

    estimator = parallel_pretraining_model if multi_gpu else pretraining_model
    estimator.fit_generator(
        generator=train_sample_generator,
        epochs=epochs,
        callbacks=[checkpoint],
        shuffle=False,
        validation_data=([
            test_tokens_ids, test_tokens_mask, test_segment_ids,
            test_masked_lm_positions
        ], [test_masked_lm_label, test_is_random_next]),
    )

    # Reload the best checkpoint, then export only the encoder weights.
    pretraining_model.load_weights(checkpoint_path)
    bert_model = bert.get_bert_encoder()
    bert_model.save_weights(os.path.join(save_path, encoder_model_name))