def bert_pack_inputs(inputs: Union[tf.RaggedTensor, List[tf.RaggedTensor]],
                     seq_length: Union[int, tf.Tensor],
                     start_of_sequence_id: Union[int, tf.Tensor],
                     end_of_segment_id: Union[int, tf.Tensor],
                     padding_id: Union[int, tf.Tensor],
                     truncator="round_robin"):
  """Freestanding equivalent of the BertPackInputs layer."""
  _check_if_tf_text_installed()
  # Sanitize inputs.
  if not isinstance(inputs, (list, tuple)):
    inputs = [inputs]
  if not inputs:
    raise ValueError("At least one input is required for packing")
  input_ranks = [rt.shape.rank for rt in inputs]
  if None in input_ranks or len(set(input_ranks)) > 1:
    raise ValueError("All inputs for packing must have the same known rank, "
                     "found ranks " +
                     ",".join(str(rank) for rank in input_ranks))
  # Flatten inputs to [batch_size, (tokens)].
  if input_ranks[0] > 2:
    inputs = [rt.merge_dims(1, -1) for rt in inputs]
  # In case inputs weren't truncated (as they should have been),
  # fall back to some ad-hoc truncation.
  num_special_tokens = len(inputs) + 1
  if truncator == "round_robin":
    trimmed_segments = text.RoundRobinTrimmer(
        seq_length - num_special_tokens).trim(inputs)
  elif truncator == "waterfall":
    trimmed_segments = text.WaterfallTrimmer(
        seq_length - num_special_tokens).trim(inputs)
  else:
    raise ValueError("Unsupported truncator: %s" % truncator)
  # Combine segments.
  segments_combined, segment_ids = text.combine_segments(
      trimmed_segments,
      start_of_sequence_id=start_of_sequence_id,
      end_of_segment_id=end_of_segment_id)
  # Pad to dense Tensors.
  input_word_ids, _ = text.pad_model_inputs(segments_combined, seq_length,
                                            pad_value=padding_id)
  input_type_ids, input_mask = text.pad_model_inputs(segment_ids, seq_length,
                                                     pad_value=0)
  # Work around broken shape inference.
  output_shape = tf.stack([
      inputs[0].nrows(out_type=tf.int32),  # batch_size
      tf.cast(seq_length, dtype=tf.int32)
  ])

  def _reshape(t):
    return tf.reshape(t, output_shape)

  # Assemble nest of input tensors as expected by BERT TransformerEncoder.
  return dict(input_word_ids=_reshape(input_word_ids),
              input_mask=_reshape(input_mask),
              input_type_ids=_reshape(input_type_ids))
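
# Hedged usage sketch (added for illustration; not part of the original
# module). It packs two pre-tokenized ragged segments into fixed-length BERT
# inputs. All token ids and special-token ids below are made-up example
# values that merely mirror the MLM test further down.
import tensorflow as tf
import tensorflow_text as text

segment_a = tf.ragged.constant([[5], [7, 8], [11, 12]])
segment_b = tf.ragged.constant([[6], [9, 10], [13, 14]])
packed = bert_pack_inputs([segment_a, segment_b],
                          seq_length=10,
                          start_of_sequence_id=2,
                          end_of_segment_id=3,
                          padding_id=0)
# Each value is a dense [3, 10] int Tensor ready for a BERT encoder, e.g.
# packed["input_word_ids"][0] == [2, 5, 3, 6, 3, 0, 0, 0, 0, 0].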
def call(self, inputs: tf.Tensor | list[tf.Tensor]) -> list[tf.Tensor]:
    inputs = tf.nest.flatten(inputs)
    # Note: the trimming budget does not reserve room for [CLS]/[SEP], so
    # combined outputs may exceed max_length by the number of special tokens.
    trimmer = tftext.WaterfallTrimmer(self.max_length)
    if len(inputs) == 1:
        if self.return_offset:
            (token_ids, starts,
             ends) = self.tokenizer.tokenize_with_offsets(inputs[0])
            starts = starts.merge_dims(-2, -1)
            ends = ends.merge_dims(-2, -1)
        else:
            token_ids = self.tokenizer.tokenize(inputs[0])
        flatten_ids = token_ids.merge_dims(-2, -1)
        ids, type_ids = tftext.combine_segments(
            trimmer.trim([flatten_ids]),
            start_of_sequence_id=self.cls_id,
            end_of_segment_id=self.sep_id,
        )
        tensors: list[tf.Tensor] = [ids.to_tensor(), type_ids.to_tensor()]
        if self.return_offset:
            tensors.append(starts.to_tensor())
            tensors.append(ends.to_tensor())
        return tensors
    elif len(inputs) == 2:
        query, context = inputs
        query_token_ids = self.tokenizer.tokenize(query)
        if self.return_offset:
            # TODO: Whether it is reasonable for sentence-pair inputs to
            # output start/end offsets only for the second sentence depends
            # on the task design, e.g. reading comprehension.
            (
                context_token_ids,
                starts,
                ends,
            ) = self.tokenizer.tokenize_with_offsets(context)
            starts = starts.merge_dims(-2, -1)
            ends = ends.merge_dims(-2, -1)
        else:
            context_token_ids = self.tokenizer.tokenize(context)
        query_flatten_ids = query_token_ids.merge_dims(-2, -1)
        context_flatten_ids = context_token_ids.merge_dims(-2, -1)
        token_ids, type_ids = tftext.combine_segments(
            trimmer.trim([query_flatten_ids, context_flatten_ids]),
            start_of_sequence_id=self.cls_id,
            end_of_segment_id=self.sep_id,
        )
        tensors: list[tf.Tensor] = [
            token_ids.to_tensor(),
            type_ids.to_tensor()
        ]
        if self.return_offset:
            tensors.append(starts.to_tensor())
            tensors.append(ends.to_tensor())
        return tensors
    else:
        raise ValueError(
            f"The length of inputs must be 1 or 2, but got {len(inputs)}.")
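
# Hedged usage sketch (added for illustration; the enclosing class is not
# shown in this excerpt). It assumes a hypothetical Keras layer, here called
# `BertTokenizeLayer`, that provides the attributes `call` relies on:
# `tokenizer` (e.g. a tftext.BertTokenizer), `cls_id`, `sep_id`,
# `max_length`, and `return_offset`.
#
#   layer = BertTokenizeLayer(...)  # hypothetical constructor
#   # Single-sentence input -> [token_ids, type_ids].
#   token_ids, type_ids = layer(tf.constant(["a nice movie"]))
#   # Sentence-pair input takes the query/context branch.
#   token_ids, type_ids = layer(
#       [tf.constant(["who chased the fox?"]),
#        tf.constant(["the quick brown fox jumps over the lazy dog"])])
#   # With return_offset=True, two extra tensors (starts, ends) are returned.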
def test_preprocessing_for_mlm(self, use_bert):
  """Combines both SavedModel types and TF.text helpers for MLM."""
  # Create the preprocessing SavedModel with a [MASK] token.
  non_special_tokens = [
      "hello", "world", "nice", "movie", "great", "actors", "quick", "fox",
      "lazy", "dog"
  ]
  preprocess = tf.saved_model.load(
      self._do_export(
          non_special_tokens,
          do_lower_case=True,
          tokenize_with_offsets=use_bert,  # TODO(b/181866850): drop this.
          experimental_disable_assert=True,  # TODO(b/175369555): drop this.
          add_mask_token=True,
          use_sp_model=not use_bert))
  vocab_size = len(non_special_tokens) + (5 if use_bert else 7)

  # Create the encoder SavedModel with an .mlm subobject.
  hidden_size = 16
  num_hidden_layers = 2
  bert_config, encoder_config = _get_bert_config_or_encoder_config(
      use_bert, hidden_size, num_hidden_layers, vocab_size)
  _, pretrainer = export_tfhub_lib._create_model(
      bert_config=bert_config, encoder_config=encoder_config, with_mlm=True)
  model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
  checkpoint = tf.train.Checkpoint(**pretrainer.checkpoint_items)
  checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
  model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)
  vocab_file, sp_model_file = _get_vocab_or_sp_model_dummy(  # Not used below.
      self.get_temp_dir(), use_sp_model=not use_bert)
  encoder_export_path = os.path.join(self.get_temp_dir(), "encoder_export")
  export_tfhub_lib.export_model(
      export_path=encoder_export_path,
      bert_config=bert_config,
      encoder_config=encoder_config,
      model_checkpoint_path=model_checkpoint_path,
      with_mlm=True,
      vocab_file=vocab_file,
      sp_model_file=sp_model_file,
      do_lower_case=True)
  encoder = tf.saved_model.load(encoder_export_path)

  # Get special tokens from the vocab (and vocab size).
  special_tokens_dict = preprocess.tokenize.get_special_tokens_dict()
  self.assertEqual(int(special_tokens_dict["vocab_size"]), vocab_size)
  padding_id = int(special_tokens_dict["padding_id"])
  self.assertEqual(padding_id, 0)
  start_of_sequence_id = int(special_tokens_dict["start_of_sequence_id"])
  self.assertEqual(start_of_sequence_id, 2)
  end_of_segment_id = int(special_tokens_dict["end_of_segment_id"])
  self.assertEqual(end_of_segment_id, 3)
  mask_id = int(special_tokens_dict["mask_id"])
  self.assertEqual(mask_id, 4)

  # A batch of 3 segment pairs.
  raw_segments = [
      tf.constant(["hello", "nice movie", "quick fox"]),
      tf.constant(["world", "great actors", "lazy dog"])
  ]
  batch_size = 3

  # Misc hyperparameters.
  seq_length = 10
  max_selections_per_seq = 2

  # Tokenize inputs.
  tokenized_segments = [preprocess.tokenize(s) for s in raw_segments]
  # Trim inputs to eventually fit seq_length.
  num_special_tokens = len(raw_segments) + 1
  trimmed_segments = text.WaterfallTrimmer(
      seq_length - num_special_tokens).trim(tokenized_segments)
  # Combine input segments into one input sequence.
  input_ids, segment_ids = text.combine_segments(
      trimmed_segments,
      start_of_sequence_id=start_of_sequence_id,
      end_of_segment_id=end_of_segment_id)
  # Apply random masking controlled by policy objects.
  (masked_input_ids, masked_lm_positions,
   masked_ids) = text.mask_language_model(
       input_ids=input_ids,
       item_selector=text.RandomItemSelector(
           max_selections_per_seq,
           selection_rate=0.5,  # Adjusted for the short test examples.
           unselectable_ids=[start_of_sequence_id, end_of_segment_id]),
       mask_values_chooser=text.MaskValuesChooser(
           vocab_size=vocab_size,
           mask_token=mask_id,
           # Always put [MASK] to have a predictable result.
           mask_token_rate=1.0,
           random_token_rate=0.0))
  # Pad to fixed-length Transformer encoder inputs.
  input_word_ids, _ = text.pad_model_inputs(
      masked_input_ids, seq_length, pad_value=padding_id)
  input_type_ids, input_mask = text.pad_model_inputs(
      segment_ids, seq_length, pad_value=0)
  masked_lm_positions, _ = text.pad_model_inputs(
      masked_lm_positions, max_selections_per_seq, pad_value=0)
  masked_lm_positions = tf.cast(masked_lm_positions, tf.int32)
  num_predictions = int(tf.shape(masked_lm_positions)[1])

  # Test transformer inputs.
  self.assertEqual(num_predictions, max_selections_per_seq)
  expected_word_ids = np.array([
      # [CLS] hello [SEP] world [SEP]
      [2, 5, 3, 6, 3, 0, 0, 0, 0, 0],
      # [CLS] nice movie [SEP] great actors [SEP]
      [2, 7, 8, 3, 9, 10, 3, 0, 0, 0],
      # [CLS] quick fox [SEP] lazy dog [SEP]
      [2, 11, 12, 3, 13, 14, 3, 0, 0, 0]
  ])
  for i in range(batch_size):
    for j in range(num_predictions):
      k = int(masked_lm_positions[i, j])
      if k != 0:
        expected_word_ids[i, k] = 4  # [MASK]
  self.assertAllEqual(input_word_ids, expected_word_ids)

  # Call the MLM head of the Transformer encoder.
  mlm_inputs = dict(
      input_word_ids=input_word_ids,
      input_mask=input_mask,
      input_type_ids=input_type_ids,
      masked_lm_positions=masked_lm_positions,
  )
  mlm_outputs = encoder.mlm(mlm_inputs)
  self.assertEqual(mlm_outputs["pooled_output"].shape,
                   (batch_size, hidden_size))
  self.assertEqual(mlm_outputs["sequence_output"].shape,
                   (batch_size, seq_length, hidden_size))
  self.assertEqual(mlm_outputs["mlm_logits"].shape,
                   (batch_size, num_predictions, vocab_size))
  self.assertLen(mlm_outputs["encoder_outputs"], num_hidden_layers)

  # A real trainer would now compute the loss of mlm_logits
  # trying to predict the masked_ids.
  del masked_ids  # Unused.
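
# Hedged sketch (added for illustration; not part of the original test): one
# way a trainer could turn the MLM logits and the ragged masked_ids into a
# scalar loss, masking out padded prediction slots. The helper name
# `masked_lm_loss` is made up for this example; argument shapes follow the
# test above.
def masked_lm_loss(mlm_logits: tf.Tensor,
                   masked_ids: tf.RaggedTensor,
                   masked_lm_positions: tf.Tensor) -> tf.Tensor:
  # Padded slots were filled with position 0 ([CLS]), which is never
  # selected for masking, so position 0 marks padding.
  weights = tf.cast(masked_lm_positions != 0, tf.float32)
  labels = masked_ids.to_tensor(shape=masked_lm_positions.shape)
  # Per-slot cross-entropy over the vocab, shape [batch, num_predictions].
  per_example = tf.keras.losses.sparse_categorical_crossentropy(
      labels, mlm_logits, from_logits=True)
  # Weighted mean over the real (non-padded) prediction slots.
  return (tf.reduce_sum(per_example * weights) /
          tf.maximum(tf.reduce_sum(weights), 1.0))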