Example #1
    def call(self, inputs: tf.Tensor | list[tf.Tensor]) -> list[tf.Tensor]:
        inputs = tf.nest.flatten(inputs)
        trimmer = tftext.WaterfallTrimmer(self.max_length)
        if len(inputs) == 1:
            if self.return_offset:
                (token_ids, starts,
                 ends) = self.tokenizer.tokenize_with_offsets(inputs[0])
                starts = starts.merge_dims(-2, -1)
                ends = ends.merge_dims(-2, -1)
            else:
                token_ids = self.tokenizer.tokenize(inputs[0])
            flatten_ids = token_ids.merge_dims(-2, -1)
            ids, type_ids = tftext.combine_segments(
                trimmer.trim([flatten_ids]),
                start_of_sequence_id=self.cls_id,
                end_of_segment_id=self.sep_id,
            )
            tensors: list[tf.Tensor] = [ids.to_tensor(), type_ids.to_tensor()]
            if self.return_offset:
                tensors.append(starts.to_tensor())
                tensors.append(ends.to_tensor())
            return tensors
        elif len(inputs) == 2:
            query, context = inputs
            query_token_ids = self.tokenizer.tokenize(query)
            if self.return_offset:
                # TODO: For paired inputs, only the second segment's start/end
                #       offsets are returned; whether that is appropriate
                #       depends on the task, e.g. reading comprehension.
                (
                    context_token_ids,
                    starts,
                    ends,
                ) = self.tokenizer.tokenize_with_offsets(context)
                starts = starts.merge_dims(-2, -1)
                ends = ends.merge_dims(-2, -1)
            else:
                context_token_ids = self.tokenizer.tokenize(context)
            query_flatten_ids = query_token_ids.merge_dims(-2, -1)
            context_flatten_ids = context_token_ids.merge_dims(-2, -1)
            token_ids, type_ids = tftext.combine_segments(
                trimmer.trim([query_flatten_ids, context_flatten_ids]),
                start_of_sequence_id=self.cls_id,
                end_of_segment_id=self.sep_id,
            )

            tensors: list[tf.Tensor] = [
                token_ids.to_tensor(),
                type_ids.to_tensor()
            ]
            if self.return_offset:
                tensors.append(starts.to_tensor())
                tensors.append(ends.to_tensor())
            return tensors
        else:
            raise ValueError(
                f"The length of inputs must be 1 or 2, bug get {len(inputs)}.")
Example #2
    def bert_pack_inputs(inputs: Union[tf.RaggedTensor, List[tf.RaggedTensor]],
                         seq_length: Union[int, tf.Tensor],
                         start_of_sequence_id: Union[int, tf.Tensor],
                         end_of_segment_id: Union[int, tf.Tensor],
                         padding_id: Union[int, tf.Tensor],
                         truncator="round_robin"):
        """Freestanding equivalent of the BertPackInputs layer."""
        _check_if_tf_text_installed()
        # Sanitize inputs.
        if not isinstance(inputs, (list, tuple)):
            inputs = [inputs]
        if not inputs:
            raise ValueError("At least one input is required for packing")
        input_ranks = [rt.shape.rank for rt in inputs]
        if None in input_ranks or len(set(input_ranks)) > 1:
            raise ValueError(
                "All inputs for packing must have the same known rank, "
                "found ranks " + ", ".join(str(r) for r in input_ranks))
        # Flatten inputs to [batch_size, (tokens)].
        if input_ranks[0] > 2:
            inputs = [rt.merge_dims(1, -1) for rt in inputs]
        # In case inputs weren't truncated (as they should have been),
        # fall back to some ad-hoc truncation.
        num_special_tokens = len(inputs) + 1
        if truncator == "round_robin":
            trimmed_segments = text.RoundRobinTrimmer(
                seq_length - num_special_tokens).trim(inputs)
        elif truncator == "waterfall":
            trimmed_segments = text.WaterfallTrimmer(
                seq_length - num_special_tokens).trim(inputs)
        else:
            raise ValueError("Unsupported truncator: %s" % truncator)
        # Combine segments.
        segments_combined, segment_ids = text.combine_segments(
            trimmed_segments,
            start_of_sequence_id=start_of_sequence_id,
            end_of_segment_id=end_of_segment_id)
        # Pad to dense Tensors.
        input_word_ids, _ = text.pad_model_inputs(segments_combined,
                                                  seq_length,
                                                  pad_value=padding_id)
        input_type_ids, input_mask = text.pad_model_inputs(segment_ids,
                                                           seq_length,
                                                           pad_value=0)
        # Work around broken shape inference.
        output_shape = tf.stack([
            inputs[0].nrows(out_type=tf.int32),  # batch_size
            tf.cast(seq_length, dtype=tf.int32)
        ])

        def _reshape(t):
            return tf.reshape(t, output_shape)

        # Assemble nest of input tensors as expected by BERT TransformerEncoder.
        return dict(input_word_ids=_reshape(input_word_ids),
                    input_mask=_reshape(input_mask),
                    input_type_ids=_reshape(input_type_ids))
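A hedged usage sketch for the function above, assuming it and its tensorflow_text import are in scope (in the TF Model Garden this logic is exposed as a static method of the BertPackInputs layer). The token ids are illustrative; 101/102 are the usual [CLS]/[SEP] ids of the standard BERT vocabulary.

import tensorflow as tf

segment_a = tf.ragged.constant([[12, 13, 14], [15, 16]])
segment_b = tf.ragged.constant([[21, 22], [23, 24, 25, 26]])

packed = bert_pack_inputs(
    [segment_a, segment_b],
    seq_length=8,
    start_of_sequence_id=101,   # [CLS]
    end_of_segment_id=102,      # [SEP]
    padding_id=0)
# packed["input_word_ids"], packed["input_mask"] and packed["input_type_ids"]
# are dense [2, 8] int tensors ready for a BERT-style encoder.
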
Example #3
    def test_preprocessing_for_mlm(self, use_bert):
        """Combines both SavedModel types and TF.text helpers for MLM."""
        # Create the preprocessing SavedModel with a [MASK] token.
        non_special_tokens = [
            "hello", "world", "nice", "movie", "great", "actors", "quick",
            "fox", "lazy", "dog"
        ]
        preprocess = tf.saved_model.load(
            self._do_export(
                non_special_tokens,
                do_lower_case=True,
                tokenize_with_offsets=use_bert,  # TODO(b/181866850): drop this.
                experimental_disable_assert=True,  # TODO(b/175369555): drop this.
                add_mask_token=True,
                use_sp_model=not use_bert))
        vocab_size = len(non_special_tokens) + (5 if use_bert else 7)

        # Create the encoder SavedModel with an .mlm subobject.
        hidden_size = 16
        num_hidden_layers = 2
        bert_config, encoder_config = _get_bert_config_or_encoder_config(
            use_bert, hidden_size, num_hidden_layers, vocab_size)
        _, pretrainer = export_tfhub_lib._create_model(
            bert_config=bert_config,
            encoder_config=encoder_config,
            with_mlm=True)
        model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
        checkpoint = tf.train.Checkpoint(**pretrainer.checkpoint_items)
        checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
        model_checkpoint_path = tf.train.latest_checkpoint(
            model_checkpoint_dir)
        vocab_file, sp_model_file = _get_vocab_or_sp_model_dummy(  # Not used below.
            self.get_temp_dir(), use_sp_model=not use_bert)
        encoder_export_path = os.path.join(self.get_temp_dir(),
                                           "encoder_export")
        export_tfhub_lib.export_model(
            export_path=encoder_export_path,
            bert_config=bert_config,
            encoder_config=encoder_config,
            model_checkpoint_path=model_checkpoint_path,
            with_mlm=True,
            vocab_file=vocab_file,
            sp_model_file=sp_model_file,
            do_lower_case=True)
        encoder = tf.saved_model.load(encoder_export_path)

        # Get special tokens from the vocab (and vocab size).
        special_tokens_dict = preprocess.tokenize.get_special_tokens_dict()
        self.assertEqual(int(special_tokens_dict["vocab_size"]), vocab_size)
        padding_id = int(special_tokens_dict["padding_id"])
        self.assertEqual(padding_id, 0)
        start_of_sequence_id = int(special_tokens_dict["start_of_sequence_id"])
        self.assertEqual(start_of_sequence_id, 2)
        end_of_segment_id = int(special_tokens_dict["end_of_segment_id"])
        self.assertEqual(end_of_segment_id, 3)
        mask_id = int(special_tokens_dict["mask_id"])
        self.assertEqual(mask_id, 4)

        # A batch of 3 segment pairs.
        raw_segments = [
            tf.constant(["hello", "nice movie", "quick fox"]),
            tf.constant(["world", "great actors", "lazy dog"])
        ]
        batch_size = 3

        # Misc hyperparameters.
        seq_length = 10
        max_selections_per_seq = 2

        # Tokenize inputs.
        tokenized_segments = [preprocess.tokenize(s) for s in raw_segments]
        # Trim inputs to eventually fit seq_length.
        num_special_tokens = len(raw_segments) + 1
        trimmed_segments = text.WaterfallTrimmer(
            seq_length - num_special_tokens).trim(tokenized_segments)
        # Combine input segments into one input sequence.
        input_ids, segment_ids = text.combine_segments(
            trimmed_segments,
            start_of_sequence_id=start_of_sequence_id,
            end_of_segment_id=end_of_segment_id)
        # Apply random masking controlled by policy objects.
        (masked_input_ids, masked_lm_positions,
         masked_ids) = text.mask_language_model(
             input_ids=input_ids,
             item_selector=text.RandomItemSelector(
                 max_selections_per_seq,
                 selection_rate=0.5,  # Adjusted for the short test examples.
                 unselectable_ids=[start_of_sequence_id, end_of_segment_id]),
             mask_values_chooser=text.MaskValuesChooser(
                 vocab_size=vocab_size,
                 mask_token=mask_id,
                 # Always put [MASK] to have a predictable result.
                 mask_token_rate=1.0,
                 random_token_rate=0.0))
        # Pad to fixed-length Transformer encoder inputs.
        input_word_ids, _ = text.pad_model_inputs(masked_input_ids,
                                                  seq_length,
                                                  pad_value=padding_id)
        input_type_ids, input_mask = text.pad_model_inputs(segment_ids,
                                                           seq_length,
                                                           pad_value=0)
        masked_lm_positions, _ = text.pad_model_inputs(masked_lm_positions,
                                                       max_selections_per_seq,
                                                       pad_value=0)
        masked_lm_positions = tf.cast(masked_lm_positions, tf.int32)
        num_predictions = int(tf.shape(masked_lm_positions)[1])

        # Test transformer inputs.
        self.assertEqual(num_predictions, max_selections_per_seq)
        expected_word_ids = np.array([
            # [CLS] hello [SEP] world [SEP]
            [2, 5, 3, 6, 3, 0, 0, 0, 0, 0],
            # [CLS] nice movie [SEP] great actors [SEP]
            [2, 7, 8, 3, 9, 10, 3, 0, 0, 0],
            # [CLS] quick fox [SEP] lazy dog [SEP]
            [2, 11, 12, 3, 13, 14, 3, 0, 0, 0]
        ])
        for i in range(batch_size):
            for j in range(num_predictions):
                k = int(masked_lm_positions[i, j])
                if k != 0:
                    expected_word_ids[i, k] = 4  # [MASK]
        self.assertAllEqual(input_word_ids, expected_word_ids)

        # Call the MLM head of the Transformer encoder.
        mlm_inputs = dict(
            input_word_ids=input_word_ids,
            input_mask=input_mask,
            input_type_ids=input_type_ids,
            masked_lm_positions=masked_lm_positions,
        )
        mlm_outputs = encoder.mlm(mlm_inputs)
        self.assertEqual(mlm_outputs["pooled_output"].shape,
                         (batch_size, hidden_size))
        self.assertEqual(mlm_outputs["sequence_output"].shape,
                         (batch_size, seq_length, hidden_size))
        self.assertEqual(mlm_outputs["mlm_logits"].shape,
                         (batch_size, num_predictions, vocab_size))
        self.assertLen(mlm_outputs["encoder_outputs"], num_hidden_layers)

        # A real trainer would now compute the loss of mlm_logits
        # trying to predict the masked_ids.
        del masked_ids  # Unused.
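The masking in the test above is controlled entirely by the two policy objects passed to mask_language_model. A minimal stand-alone sketch with toy ids (the vocab size, the [MASK] id 4, and the [CLS]/[SEP] ids 2/3 are illustrative):

import tensorflow as tf
import tensorflow_text as text

# One combined [CLS] ... [SEP] ... [SEP] sequence per batch row.
input_ids = tf.ragged.constant([[2, 5, 6, 3, 7, 8, 3],
                                [2, 9, 10, 3, 11, 3]])

masked_ids, masked_positions, original_ids = text.mask_language_model(
    input_ids=input_ids,
    item_selector=text.RandomItemSelector(
        max_selections_per_batch=2,
        selection_rate=0.5,
        unselectable_ids=[2, 3]),     # never mask [CLS]/[SEP]
    mask_values_chooser=text.MaskValuesChooser(
        vocab_size=15,
        mask_token=4,                 # [MASK]
        mask_token_rate=1.0,          # always substitute [MASK] ...
        random_token_rate=0.0))       # ... never a random token
# masked_ids equals input_ids with some tokens replaced by 4;
# masked_positions and original_ids record where and what was masked.
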
Example #4
    def call(
        self,
        text,
        text_pair=None,
        padding=None,
        truncation=None,
        max_length=None,
        pad_to_multiple_of=None,
        return_token_type_ids=None,
        return_attention_mask=None,
    ):
        if padding is None:
            padding = self.padding
        if padding not in ("longest", "max_length"):
            raise ValueError(
                "Padding must be either 'longest' or 'max_length'!")
        if max_length is not None and text_pair is not None:
            # Because we have to instantiate a Trimmer to do it properly
            raise ValueError(
                "max_length cannot be overridden at call time when truncating paired texts!"
            )
        if max_length is None:
            max_length = self.max_length
        if truncation is None:
            truncation = self.truncation
        if pad_to_multiple_of is None:
            pad_to_multiple_of = self.pad_to_multiple_of
        if return_token_type_ids is None:
            return_token_type_ids = self.return_token_type_ids
        if return_attention_mask is None:
            return_attention_mask = self.return_attention_mask
        if not isinstance(text, tf.Tensor):
            text = tf.convert_to_tensor(text)
        if text_pair is not None and not isinstance(text_pair, tf.Tensor):
            text_pair = tf.convert_to_tensor(text_pair)
        if text_pair is not None:
            if text.shape.rank > 1:
                raise ValueError(
                    "text argument should not be multidimensional when a text pair is supplied!"
                )
            if text_pair.shape.rank > 1:
                raise ValueError("text_pair should not be multidimensional!")
        if text.shape.rank == 2:
            text, text_pair = text[:, 0], text[:, 1]
        text = self.unpaired_tokenize(text)
        if text_pair is None:  # Unpaired text
            if truncation:
                text = text[:, :max_length - 2]  # Allow room for special tokens
            input_ids, token_type_ids = combine_segments(
                (text, ),
                start_of_sequence_id=self.cls_token_id,
                end_of_segment_id=self.sep_token_id)
        else:  # Paired text
            text_pair = self.unpaired_tokenize(text_pair)
            if truncation:
                text, text_pair = self.paired_trimmer.trim([text, text_pair])
            input_ids, token_type_ids = combine_segments(
                (text, text_pair),
                start_of_sequence_id=self.cls_token_id,
                end_of_segment_id=self.sep_token_id)
        if padding == "longest":
            pad_length = input_ids.bounding_shape(axis=1)
            if pad_to_multiple_of is not None:
                # No ceiling division in tensorflow, so we negate floordiv instead
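                # (e.g. pad_to_multiple_of=8, pad_length=13: 8 * -(-13 // 8) = 16)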
                pad_length = pad_to_multiple_of * (
                    -tf.math.floordiv(-pad_length, pad_to_multiple_of))
        else:
            pad_length = max_length

        input_ids, attention_mask = pad_model_inputs(
            input_ids, max_seq_length=pad_length, pad_value=self.pad_token_id)
        output = {"input_ids": input_ids}
        if return_attention_mask:
            output["attention_mask"] = attention_mask
        if return_token_type_ids:
            token_type_ids, _ = pad_model_inputs(token_type_ids,
                                                 max_seq_length=pad_length,
                                                 pad_value=self.pad_token_id)
            output["token_type_ids"] = token_type_ids
        return output
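A hedged usage sketch for a call method with this signature, with `tok` as a hypothetical, already-constructed instance of the enclosing in-graph tokenizer layer; which keys appear in the output depends on the layer's return_* defaults.

import tensorflow as tf

# `tok` is a hypothetical instance of the layer above.
single = tok(tf.constant(["hello world", "nice movie"]))
# single["input_ids"] is padded according to the layer's padding setting.

paired = tok(
    tf.constant(["what a great movie", "quick brown fox"]),
    text_pair=tf.constant(["great actors", "lazy dog"]),
    padding="max_length",
    return_token_type_ids=True)
# paired["token_type_ids"] switches from 0 to 1 at the second segment.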