Example #1
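These snippets assume import tensorflow as tf and a text_layers module providing BertPackInputs and SentencepieceTokenizer (presumably TF Model Garden's); the assert* helpers come from a tf.test.TestCase.
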
    def test_round_robin_correct_outputs(self):
        bpi = text_layers.BertPackInputs(10,
                                         start_of_sequence_id=1001,
                                         end_of_segment_id=1002,
                                         padding_id=999,
                                         truncator="round_robin")
        # Single input, rank 2.
        bert_inputs = bpi(
            tf.ragged.constant([[11, 12, 13],
                                [21, 22, 23, 24, 25, 26, 27, 28, 29, 30]]))
        self.assertAllEqual(
            bert_inputs["input_word_ids"],
            tf.constant([[1001, 11, 12, 13, 1002, 999, 999, 999, 999, 999],
                         [1001, 21, 22, 23, 24, 25, 26, 27, 28, 1002]]))
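        # input_mask marks real tokens, including the special ones, with 1
        # and padding positions with 0.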
        self.assertAllEqual(
            bert_inputs["input_mask"],
            tf.constant([[1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
                         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]))
        self.assertAllEqual(
            bert_inputs["input_type_ids"],
            tf.constant([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]))

        # Two inputs, rank 3. Truncation does not respect word boundaries.
        bert_inputs = bpi([
            tf.ragged.constant([[[111], [112, 113]],
                                [[121, 122, 123], [124, 125, 126], [127,
                                                                    128]]]),
            tf.ragged.constant([[[211, 212], [213]],
                                [[221, 222], [223, 224, 225], [226, 227,
                                                               228]]])
        ])
        self.assertAllEqual(
            bert_inputs["input_word_ids"],
            tf.constant([[1001, 111, 112, 113, 1002, 211, 212, 213, 1002, 999],
                         [1001, 121, 122, 123, 124, 1002, 221, 222, 223,
                          1002]]))
        self.assertAllEqual(
            bert_inputs["input_mask"],
            tf.constant([[1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
                         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]))
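        # input_type_ids: 0 for the first segment (and padding), 1 for the
        # second segment including its trailing end-of-segment id.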
        self.assertAllEqual(
            bert_inputs["input_type_ids"],
            tf.constant([[0, 0, 0, 0, 0, 1, 1, 1, 1, 0],
                         [0, 0, 0, 0, 0, 0, 1, 1, 1, 1]]))

        # Three inputs are not yet supported by the round_robin truncator.
        with self.assertRaisesRegex(ValueError, "Must pass 1 or 2 inputs"):
            bert_inputs = bpi([
                tf.ragged.constant([[[111], [112, 113]],
                                    [[121, 122, 123], [124, 125, 126],
                                     [127, 128]]]),
                tf.ragged.constant([[[211, 212], [213]],
                                    [[221, 222], [223, 224, 225],
                                     [226, 227, 228]]]),
                tf.ragged.constant([[[311, 312], [313]],
                                    [[321, 322], [323, 324, 325],
                                     [326, 327, 328]]])
            ])
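
For intuition, the round_robin truncator deals the token budget (seq_length minus the special tokens) out one token at a time, cycling across the segments. A minimal pure-Python sketch of that allocation; round_robin_quota is a hypothetical helper name, not the library's implementation:

def round_robin_quota(lengths, budget):
    """Hand out `budget` tokens across segments, one at a time in turn."""
    quota = [0] * len(lengths)
    while budget > 0 and any(q < n for q, n in zip(quota, lengths)):
        for i, n in enumerate(lengths):
            if budget > 0 and quota[i] < n:
                quota[i] += 1
                budget -= 1
    return quota

# Row 2 above: seq_length 10 minus [CLS] and two [SEP]s leaves 7 tokens,
# split 4/3 between the two 8-token segments.
print(round_robin_quota([8, 8], 7))  # [4, 3]
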
Example #2
    def test_special_tokens_dict(self):
        special_tokens_dict = dict(start_of_sequence_id=1001,
                                   end_of_segment_id=1002,
                                   padding_id=999,
                                   extraneous_key=666)
        bpi = text_layers.BertPackInputs(
            10, special_tokens_dict=special_tokens_dict)
        bert_inputs = bpi(
            tf.ragged.constant([[11, 12, 13],
                                [21, 22, 23, 24, 25, 26, 27, 28, 29, 30]]))
        self.assertAllEqual(
            bert_inputs["input_word_ids"],
            tf.constant([[1001, 11, 12, 13, 1002, 999, 999, 999, 999, 999],
                         [1001, 21, 22, 23, 24, 25, 26, 27, 28, 1002]]))
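
The extraneous key is simply ignored: BertPackInputs reads only the ids it recognizes. That tolerance is what lets Example #3 below pass a tokenizer's get_special_tokens_dict() through without filtering it first.
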
Example #3
    def input_fn():
      with tf.init_scope():
        self.assertFalse(tf.executing_eagerly())
      # Build a preprocessing Model.
      sentences = tf.keras.layers.Input(shape=[], dtype=tf.string)
      sentencepiece_tokenizer = text_layers.SentencepieceTokenizer(
          model_file_path=self._spm_path, lower_case=True, nbest_size=0)
      special_tokens_dict = sentencepiece_tokenizer.get_special_tokens_dict()
      for k, v in special_tokens_dict.items():
        self.assertIsInstance(v, int, "Unexpected type for {}".format(k))
      tokens = sentencepiece_tokenizer(sentences)
      packed_inputs = text_layers.BertPackInputs(
          4, special_tokens_dict=special_tokens_dict)(tokens)
      preprocessing = tf.keras.Model(sentences, packed_inputs)
      # Map the dataset.
      ds = tf.data.Dataset.from_tensors(
          (tf.constant(["abc", "DEF"]), tf.constant([0, 1])))
      ds = ds.map(lambda features, labels: (preprocessing(features), labels))
      return ds
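
The init_scope assertion checks that input_fn is being traced into a graph, as happens for example when an input function is handed to the old tf.estimator-style training loop. To poke at the mapped dataset eagerly instead, drop that assertion; a sketch:

ds = input_fn()  # with the init_scope assertion removed
for features, labels in ds:
    print(features["input_word_ids"].shape)  # (2, 4): batch of 2, seq_length 4
    print(labels.numpy())                    # [0 1]
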
Example #4
    def test_waterfall_correct_outputs(self):
        bpi = text_layers.BertPackInputs(10,
                                         start_of_sequence_id=1001,
                                         end_of_segment_id=1002,
                                         padding_id=999,
                                         truncator="waterfall")
        # Single input, rank 2.
        bert_inputs = bpi(
            tf.ragged.constant([[11, 12, 13],
                                [21, 22, 23, 24, 25, 26, 27, 28, 29, 30]]))
        self.assertAllEqual(
            bert_inputs["input_word_ids"],
            tf.constant([[1001, 11, 12, 13, 1002, 999, 999, 999, 999, 999],
                         [1001, 21, 22, 23, 24, 25, 26, 27, 28, 1002]]))
        self.assertAllEqual(
            bert_inputs["input_mask"],
            tf.constant([[1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
                         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]))
        self.assertAllEqual(
            bert_inputs["input_type_ids"],
            tf.constant([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]))

        # Two inputs, rank 3. Truncation does not respect word boundaries.
        bert_inputs = bpi([
            tf.ragged.constant([[[111], [112, 113]],
                                [[121, 122, 123], [124, 125, 126], [127,
                                                                    128]]]),
            tf.ragged.constant([[[211, 212], [213]],
                                [[221, 222], [223, 224, 225], [226, 227,
                                                               228]]])
        ])
        self.assertAllEqual(
            bert_inputs["input_word_ids"],
            tf.constant([[1001, 111, 112, 113, 1002, 211, 212, 213, 1002, 999],
                         [1001, 121, 122, 123, 124, 125, 126, 127, 1002,
                          1002]]))
        self.assertAllEqual(
            bert_inputs["input_mask"],
            tf.constant([[1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
                         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]))
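        # Row 2: waterfall gives the whole 7-token budget to segment 0,
        # so segment 1 contributes only its end-of-segment id (type 1).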
        self.assertAllEqual(
            bert_inputs["input_type_ids"],
            tf.constant([[0, 0, 0, 0, 0, 1, 1, 1, 1, 0],
                         [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]]))

        # Three inputs, rank 3. Truncation does not respect word boundaries.
        bert_inputs = bpi([
            tf.ragged.constant([[[111], [112, 113]],
                                [[121, 122, 123], [124, 125, 126], [127,
                                                                    128]]]),
            tf.ragged.constant([[[211], [212]],
                                [[221, 222], [223, 224, 225], [226, 227,
                                                               228]]]),
            tf.ragged.constant([[[311, 312], [313]],
                                [[321, 322], [323, 324, 325], [326, 327]]])
        ])
        self.assertAllEqual(
            bert_inputs["input_word_ids"],
            tf.constant(
                [[1001, 111, 112, 113, 1002, 211, 212, 1002, 311, 1002],
                 [1001, 121, 122, 123, 124, 125, 126, 1002, 1002, 1002]]))
        self.assertAllEqual(
            bert_inputs["input_mask"],
            tf.constant([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]))
        self.assertAllEqual(
            bert_inputs["input_type_ids"],
            tf.constant([[0, 0, 0, 0, 0, 1, 1, 1, 2, 2],
                         [0, 0, 0, 0, 0, 0, 0, 0, 1, 2]]))
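
By contrast, the waterfall truncator lets each segment drain the remaining budget before the next one gets a turn; the matching sketch, in the same hypothetical style as the round_robin one above:

def waterfall_quota(lengths, budget):
    """Spend `budget` on segments left to right, each taking all it can."""
    quota = []
    for n in lengths:
        take = min(n, budget)
        quota.append(take)
        budget -= take
    return quota

# Three-input row 2 above: seq_length 10 minus [CLS] and three [SEP]s
# leaves 6 tokens, all of which go to the first 8-token segment.
print(waterfall_quota([8, 8, 7], 6))  # [6, 0, 0]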