def test_round_robin_correct_outputs(self): bpi = text_layers.BertPackInputs(10, start_of_sequence_id=1001, end_of_segment_id=1002, padding_id=999, truncator="round_robin") # Single input, rank 2. bert_inputs = bpi( tf.ragged.constant([[11, 12, 13], [21, 22, 23, 24, 25, 26, 27, 28, 29, 30]])) self.assertAllEqual( bert_inputs["input_word_ids"], tf.constant([[1001, 11, 12, 13, 1002, 999, 999, 999, 999, 999], [1001, 21, 22, 23, 24, 25, 26, 27, 28, 1002]])) self.assertAllEqual( bert_inputs["input_mask"], tf.constant([[1, 1, 1, 1, 1, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])) self.assertAllEqual( bert_inputs["input_type_ids"], tf.constant([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])) # Two inputs, rank 3. Truncation does not respect word boundaries. bert_inputs = bpi([ tf.ragged.constant([[[111], [112, 113]], [[121, 122, 123], [124, 125, 126], [127, 128]]]), tf.ragged.constant([[[211, 212], [213]], [[221, 222], [223, 224, 225], [226, 227, 228]]]) ]) self.assertAllEqual( bert_inputs["input_word_ids"], tf.constant([[1001, 111, 112, 113, 1002, 211, 212, 213, 1002, 999], [1001, 121, 122, 123, 124, 1002, 221, 222, 223, 1002]])) self.assertAllEqual( bert_inputs["input_mask"], tf.constant([[1, 1, 1, 1, 1, 1, 1, 1, 1, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])) self.assertAllEqual( bert_inputs["input_type_ids"], tf.constant([[0, 0, 0, 0, 0, 1, 1, 1, 1, 0], [0, 0, 0, 0, 0, 0, 1, 1, 1, 1]])) # Three inputs has not been supported for round_robin so far. with self.assertRaisesRegex(ValueError, "Must pass 1 or 2 inputs"): bert_inputs = bpi([ tf.ragged.constant([[[111], [112, 113]], [[121, 122, 123], [124, 125, 126], [127, 128]]]), tf.ragged.constant([[[211, 212], [213]], [[221, 222], [223, 224, 225], [226, 227, 228]]]), tf.ragged.constant([[[311, 312], [313]], [[321, 322], [323, 324, 325], [326, 327, 328]]]) ])
def test_special_tokens_dict(self):
    """Special token ids can be supplied as a dict; unknown keys are ignored."""
    tokens_dict = {
        "start_of_sequence_id": 1001,
        "end_of_segment_id": 1002,
        "padding_id": 999,
        "extraneous_key": 666,  # Must be tolerated without error.
    }
    packer = text_layers.BertPackInputs(10, special_tokens_dict=tokens_dict)
    packed = packer(
        tf.ragged.constant([[11, 12, 13],
                            [21, 22, 23, 24, 25, 26, 27, 28, 29, 30]]))
    # Same expectations as with explicitly passed ids: row 1 padded,
    # row 2 truncated so [CLS]/[SEP] fit within seq_length 10.
    self.assertAllEqual(
        packed["input_word_ids"],
        tf.constant([[1001, 11, 12, 13, 1002, 999, 999, 999, 999, 999],
                     [1001, 21, 22, 23, 24, 25, 26, 27, 28, 1002]]))
def input_fn():
    # We should be building the graph here, not running eagerly.
    with tf.init_scope():
        self.assertFalse(tf.executing_eagerly())
    # Build a preprocessing Model.
    sentences = tf.keras.layers.Input(shape=[], dtype=tf.string)
    tokenizer = text_layers.SentencepieceTokenizer(
        model_file_path=self._spm_path, lower_case=True, nbest_size=0)
    special_tokens = tokenizer.get_special_tokens_dict()
    # Every special token id must already be a plain Python int here.
    for key, value in special_tokens.items():
        self.assertIsInstance(value, int, "Unexpected type for {}".format(key))
    packed = text_layers.BertPackInputs(
        4, special_tokens_dict=special_tokens)(tokenizer(sentences))
    preprocessing = tf.keras.Model(sentences, packed)
    # Map the dataset.
    ds = tf.data.Dataset.from_tensors(
        (tf.constant(["abc", "DEF"]), tf.constant([0, 1])))
    return ds.map(lambda features, labels: (preprocessing(features), labels))
def test_waterfall_correct_outputs(self):
    """Checks BertPackInputs with the "waterfall" truncator.

    Waterfall truncation fills the seq_length budget (10, including special
    tokens) from the first segment first, then spills the remaining budget
    into the following segments in order.
    """
    bpi = text_layers.BertPackInputs(10,
                                     start_of_sequence_id=1001,
                                     end_of_segment_id=1002,
                                     padding_id=999,
                                     truncator="waterfall")
    # Single input, rank 2.
    bert_inputs = bpi(
        tf.ragged.constant([[11, 12, 13],
                            [21, 22, 23, 24, 25, 26, 27, 28, 29, 30]]))
    # With one segment, waterfall behaves like plain truncation: row 1 is
    # padded, row 2 truncated so [CLS]=1001 and [SEP]=1002 fit.
    self.assertAllEqual(
        bert_inputs["input_word_ids"],
        tf.constant([[1001, 11, 12, 13, 1002, 999, 999, 999, 999, 999],
                     [1001, 21, 22, 23, 24, 25, 26, 27, 28, 1002]]))
    self.assertAllEqual(
        bert_inputs["input_mask"],
        tf.constant([[1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
                     [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]))
    self.assertAllEqual(
        bert_inputs["input_type_ids"],
        tf.constant([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]))

    # Two inputs, rank 3. Truncation does not respect word boundaries.
    bert_inputs = bpi([
        tf.ragged.constant([[[111], [112, 113]],
                            [[121, 122, 123], [124, 125, 126], [127, 128]]]),
        tf.ragged.constant([[[211, 212], [213]],
                            [[221, 222], [223, 224, 225], [226, 227, 228]]])
    ])
    # Row 2: segment A consumes 7 of the 7 available token slots, leaving
    # segment B fully truncated (only its [SEP] with type id 1 remains).
    self.assertAllEqual(
        bert_inputs["input_word_ids"],
        tf.constant([[1001, 111, 112, 113, 1002, 211, 212, 213, 1002, 999],
                     [1001, 121, 122, 123, 124, 125, 126, 127, 1002, 1002]]))
    self.assertAllEqual(
        bert_inputs["input_mask"],
        tf.constant([[1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
                     [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]))
    self.assertAllEqual(
        bert_inputs["input_type_ids"],
        tf.constant([[0, 0, 0, 0, 0, 1, 1, 1, 1, 0],
                     [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]]))

    # Three inputs, rank 3. Truncation does not respect word boundaries.
    bert_inputs = bpi([
        tf.ragged.constant([[[111], [112, 113]],
                            [[121, 122, 123], [124, 125, 126], [127, 128]]]),
        tf.ragged.constant([[[211], [212]],
                            [[221, 222], [223, 224, 225], [226, 227, 228]]]),
        tf.ragged.constant([[[311, 312], [313]],
                            [[321, 322], [323, 324, 325], [326, 327]]])
    ])
    # Budget flows left to right across the three segments; later segments
    # may keep only a partial prefix (row 1) or nothing but [SEP] (row 2).
    self.assertAllEqual(
        bert_inputs["input_word_ids"],
        tf.constant(
            [[1001, 111, 112, 113, 1002, 211, 212, 1002, 311, 1002],
             [1001, 121, 122, 123, 124, 125, 126, 1002, 1002, 1002]]))
    self.assertAllEqual(
        bert_inputs["input_mask"],
        tf.constant([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]))
    # Each segment (and its trailing [SEP]) gets type id 0, 1, 2 in order.
    self.assertAllEqual(
        bert_inputs["input_type_ids"],
        tf.constant([[0, 0, 0, 0, 0, 1, 1, 1, 2, 2],
                     [0, 0, 0, 0, 0, 0, 0, 0, 1, 2]]))