def __init__(self, params):
    """Validates `params` and sets up tokenization and input packing.

    Args:
      params: Data config object. This constructor reads `tfds_name`,
        `input_path`, `vocab_file`, `preprocessing_hub_module_url`,
        `seq_length`, `left_text_fields`, `right_text_fields` and, only when
        no hub module URL is configured, `lower_case`.

    Raises:
      ValueError: If `tfds_name` and `input_path` are both set or both unset,
        or if `vocab_file` and `preprocessing_hub_module_url` are both set or
        both unset.
    """
    # Exactly one data source must be configured.
    uses_tfds = bool(params.tfds_name)
    uses_input_path = bool(params.input_path)
    if uses_tfds == uses_input_path:
        raise ValueError(
            'Must specify either `tfds_name` and `tfds_split` or `input_path`.')

    # Exactly one tokenization source must be configured.
    uses_vocab_file = bool(params.vocab_file)
    uses_hub_module = bool(params.preprocessing_hub_module_url)
    if uses_vocab_file == uses_hub_module:
        raise ValueError(
            'Must specify exactly one of vocab_file (with matching '
            'lower_case flag) or preprocessing_hub_module_url.')

    self._params = params
    seq_length = params.seq_length
    self._seq_length = seq_length
    self._left_text_fields = params.left_text_fields
    self._right_text_fields = params.right_text_fields

    if not uses_hub_module:
        # Local WordPiece vocab: the packer takes its special tokens from the
        # tokenizer built here.
        tokenizer = layers.BertTokenizer(
            vocab_file=params.vocab_file, lower_case=params.lower_case)
        self._tokenizer = tokenizer
        self._pack_inputs = layers.BertPackInputs(
            seq_length=seq_length,
            special_tokens_dict=tokenizer.get_special_tokens_dict())
    else:
        # Hub preprocessing module: use its `tokenize` and
        # `bert_pack_inputs` callables, with seq_length bound up front.
        hub_module = hub.load(params.preprocessing_hub_module_url)
        self._tokenizer = hub_module.tokenize
        self._pack_inputs = functools.partial(
            hub_module.bert_pack_inputs, seq_length=seq_length)
def create_preprocessing(*,
                         vocab_file: Optional[str] = None,
                         sp_model_file: Optional[str] = None,
                         do_lower_case: bool,
                         tokenize_with_offsets: bool,
                         default_seq_length: int) -> tf.keras.Model:
    """Builds a Keras preprocessing Model for the given tokenizer settings.

    The returned Model has extra subobjects attached so that it can be saved
    as a SavedModel implementing the Preprocessor API for text embeddings
    with Transformer encoders, described at
    https://www.tensorflow.org/hub/common_saved_model_apis/text.

    Args:
      vocab_file: Path to a WordPiece vocab file, or None.
      sp_model_file: Path to a SentencePiece model file, or None. Exactly one
        of `vocab_file` and `sp_model_file` must be set; whichever is given
        selects the tokenizer type (BertTokenizer or SentencepieceTokenizer).
      do_lower_case: Whether the tokenizer lower-cases its input.
      tokenize_with_offsets: Whether to tokenize with offsets and attach a
        `.tokenize_with_offsets` subobject.
      default_seq_length: Sequence length of the root callable's output. It
        is also the default sequence length of the `bert_pack_inputs`
        subobject.

    Returns:
      A tf.keras.Model whose root call maps string sentences to packed model
      inputs. It has `tokenize` and `bert_pack_inputs` subobjects attached,
      plus `tokenize_with_offsets` when requested.

    Raises:
      ValueError: If both or neither of `vocab_file` and `sp_model_file` are
        set.
    """
    if bool(vocab_file) == bool(sp_model_file):
        raise ValueError("Must set exactly one of vocab_file, sp_model_file")

    if vocab_file:
        tokenizer = layers.BertTokenizer(
            vocab_file=vocab_file,
            lower_case=do_lower_case,
            tokenize_with_offsets=tokenize_with_offsets)
    else:
        # Diacritics are stripped to follow the ALBERT model.
        tokenizer = layers.SentencepieceTokenizer(
            model_file_path=sp_model_file,
            lower_case=do_lower_case,
            strip_diacritics=True,
            tokenize_with_offsets=tokenize_with_offsets)

    # Root callable: one-shot preprocessing for single-sentence inputs.
    sentences = tf.keras.layers.Input(
        shape=(), dtype=tf.string, name="sentences")
    tokenizer_output = tokenizer(sentences)
    if tokenize_with_offsets:
        tokens, start_offsets, limit_offsets = tokenizer_output
    else:
        tokens = tokenizer_output
    packer = layers.BertPackInputs(
        seq_length=default_seq_length,
        special_tokens_dict=tokenizer.get_special_tokens_dict())
    preprocessing = tf.keras.Model(sentences, packer(tokens))

    # The individual steps are exposed as named subobjects for more general
    # preprocessing. To be saveable, each must be a Model in its own right.
    tokenize_model = tf.keras.Model(sentences, tokens)
    preprocessing.tokenize = tokenize_model
    # Saveable equivalent of tokenizer.get_special_tokens_dict().
    special_tokens_getter = tf.train.Checkpoint()
    tokenize_model.get_special_tokens_dict = special_tokens_getter
    special_tokens_getter.__call__ = tf.function(
        lambda: tokenizer.get_special_tokens_dict(),  # pylint: disable=[unnecessary-lambda]
        input_signature=[])

    if tokenize_with_offsets:
        offsets_model = tf.keras.Model(
            sentences, [tokens, start_offsets, limit_offsets])
        preprocessing.tokenize_with_offsets = offsets_model
        # Shares the same special-tokens getter as `tokenize`.
        offsets_model.get_special_tokens_dict = special_tokens_getter

    # Conceptually this would be tf.keras.Model(tokens, model_inputs), but
    # technicalities require a wrapper (see its comments). The wrapper also
    # lets callers override seq_length when calling it.
    preprocessing.bert_pack_inputs = BertPackInputsSavedModelWrapper(packer)
    return preprocessing