Example No. 1
    def _text_to_ids(self, Xs, Y=None, pad_token=None):
        Xs = self._format_for_encoding(Xs)
        if self.config.chunk_long_sequences and len(Xs) == 1:
            # can only chunk single sequence inputs

            chunk_size = self.config.max_length - 2

            step_size = chunk_size // 3

            encoded = self.text_encoder.encode_multi_input(
                Xs,
                Y=Y,
                max_length=sys.maxsize,
                pad_token=(pad_token or self.config.pad_token),
            )
            length = len(encoded.token_ids)
            assert length == len(encoded.token_ids)
            starts = list(range(0, length, step_size))
            field_starts_and_ends = dict()
            for field in EncodedOutput._fields:
                field_value = getattr(encoded, field)
                if field_value is not None:
                    field_starts_and_ends[field] = (field_value[0],
                                                    field_value[-1])

            for start in starts:
                d = dict()
                end = start + chunk_size

                for field in EncodedOutput._fields:
                    field_value = getattr(encoded, field)
                    if field_value is not None:
                        fv = field_value[start:end]
                        if self.config.add_eos_bos_to_chunk:
                            start_token, end_token = field_starts_and_ends[
                                field]
                            if fv[0] != start_token:
                                fv = [start_token] + fv
                            if fv[-1] != end_token:
                                fv = fv + [end_token]
                        d[field] = fv
                yield self._array_format(EncodedOutput(**d),
                                         pad_token=pad_token)
        else:
            encoder_out = self.text_encoder.encode_multi_input(
                Xs,
                Y=Y,
                max_length=self.config.max_length,
                pad_token=(pad_token or self.config.pad_token),
            )

            d = dict()
            for field in EncodedOutput._fields:
                field_value = getattr(encoder_out, field)
                if field_value is not None:
                    d[field] = field_value

            yield self._array_format(EncodedOutput(**d),
                                     pad_token=(pad_token
                                                or self.config.pad_token))
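
The chunking branch above slides a window of `chunk_size` tokens over the full encoded sequence in steps of `chunk_size // 3`, so consecutive chunks overlap by roughly two thirds. A standalone sketch of that window arithmetic, with an illustrative `max_length` in place of the real config value:

# Illustrative values only: max_length normally comes from self.config.
max_length = 512
chunk_size = max_length - 2      # leave room for the start/end special tokens
step_size = chunk_size // 3      # consecutive windows overlap by ~2/3

token_ids = list(range(1700))    # stand-in for encoded.token_ids
windows = [token_ids[start:start + chunk_size]
           for start in range(0, len(token_ids), step_size)]

print(len(windows))                   # 10 chunks for 1700 tokens
print([len(w) for w in windows[:3]])  # full windows hold chunk_size (510) tokens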
Example No. 2
    def _encode(self, texts, labels=None):
        """
        Convert a batch of raw text to a batch of byte-pair encoded token indices.
        """
        self._lazy_init()
        batch_tokens = []
        batch_token_idxs = []
        batch_label_idxs = []
        batch_character_locs = []
        label = None

        for i, text in enumerate(texts):
            if labels is not None:
                label = labels[i]

            subtokens, token_idxs = self.tokenizer.tokenize(text)
            subtoken_locs = [l[1] for l in token_idxs]

            batch_tokens.append(subtokens)
            batch_token_idxs.append(self.tokenizer.convert_tokens_to_ids(subtokens))
            batch_character_locs.append(subtoken_locs)
            if labels is not None:
                batch_label_idxs.append([label] * len(subtokens))

        return EncodedOutput(
            token_ids=batch_token_idxs,
            tokens=batch_tokens,
            labels=batch_label_idxs,
            char_locs=batch_character_locs,
        )
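
A detail shared by most of the `_encode` variants here: a single per-text label is repeated once per subtoken so that `labels` stays aligned with `token_ids`. A minimal sketch with made-up texts and labels:

# Made-up values; real subtokens come from the byte-pair encoder.
texts = ["hello world", "good morning"]
labels = ["greeting", "greeting"]

batch_label_idxs = []
for text, label in zip(texts, labels):
    subtokens = text.split()  # stand-in for BPE subtokens
    batch_label_idxs.append([label] * len(subtokens))

print(batch_label_idxs)  # [['greeting', 'greeting'], ['greeting', 'greeting']]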
Example No. 3
    def test_gpt2_featurize(self):
        model = Classifier(base_model=GPT2)
        
        def dataset_encoded():
            yield {"tokens": arr_encoded.token_ids, "mask": arr_encoded.mask}

        def get_input_fn():
            types, shapes = model.input_pipeline.feed_shape_type_def()
            tf_dataset = Dataset.from_generator(dataset_encoded, types[0], shapes[0])
            return tf_dataset.batch(1)

        encoded = model.input_pipeline.text_encoder._encode(self.TEST_DATA)
        encoded = EncodedOutput(token_ids=encoded.token_ids[0])
        estimator, hooks = model.get_estimator(force_build_lm=False)
        predict = estimator.predict(
            input_fn=get_input_fn, predict_keys=[PredictMode.SEQUENCE], hooks=hooks
        )
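        # Note: estimator.predict returns a lazy generator, so dataset_encoded is not
        # consumed until next(predict) below, by which point arr_encoded exists.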
        arr_encoded = model.input_pipeline._array_format(encoded)
        sequence_features = next(predict)[PredictMode.SEQUENCE]

        np.testing.assert_allclose(
            sequence_features[:len(encoded.token_ids),:],
            np.load(
                os.path.join(
                    DIRECTORY, 
                    'data/test-gpt2-activations.npy'
                )
            ),
            atol=1e-1
        )
Example No. 4
    def _encode(self, texts, labels=None):
        """
        Convert a batch of raw text to a batch of byte-pair encoded token indices.
        """
        self._lazy_init()
        batch_tokens = []
        batch_token_idxs = []
        batch_label_idxs = []
        batch_character_locs = []
        label = None

        for i, text in enumerate(texts):
            if labels is not None:
                label = labels[i]
            raw_text = text.lower()
            tokens = NLP(_text_standardize(text))
            subtokens = []
            subtoken_idxs = []
            tok_pos = []
            token_start = 0

            for j, token in enumerate(tokens):
                bpe_toks = self.bpe(token.text).split(' ')

                try:
                    if token.text.strip():
                        token_start = raw_text.index(token.text.strip(),
                                                     token_start)
                except ValueError:
                    # text_standardization oddity
                    continue

                subtokens.extend(bpe_toks)
                subtoken_idxs.extend([
                    self.encoder.get(SUBS.get(t, t), self.UNK_IDX)
                    for t in bpe_toks
                ])

                assert len("".join(bpe_toks).replace("</w>", "")) == len(
                    token.text.replace(' ', ''))
                subtoken_positions = np.cumsum(
                    [len(tok.replace("</w>", ''))
                     for tok in bpe_toks]) + token_start

                token_start += len(token.text.strip())

                tok_pos.extend(subtoken_positions)

            batch_tokens.append(subtokens)
            batch_token_idxs.append(subtoken_idxs)
            batch_character_locs.append(tok_pos)
            if labels is not None:
                batch_label_idxs.append([label] * len(subtoken_idxs))

        return EncodedOutput(
            token_ids=batch_token_idxs,
            tokens=batch_tokens,
            labels=batch_label_idxs,
            char_locs=batch_character_locs,
        )
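
The character positions above come from a cumulative sum over the BPE pieces with the `</w>` end-of-word marker stripped, offset by where the token starts in the text. A short sketch with a hypothetical token split:

import numpy as np

# Hypothetical BPE split of a token starting at character 10 of the text.
bpe_toks = ["stand", "ard", "ize</w>"]
token_start = 10

subtoken_positions = np.cumsum(
    [len(tok.replace("</w>", "")) for tok in bpe_toks]) + token_start
print(subtoken_positions)  # [15 18 21] -> character end of each subtoken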
Example No. 5
    def _text_to_ids(self, Xs, Y=None, pad_token=None):
        Xs = self._format_for_encoding(Xs)
        if self.config.chunk_long_sequences and len(Xs) == 1:
            # can only chunk single sequence inputs
            chunk_size = self.config.max_length - 2
            step_size = chunk_size // 3
            encoded = self.text_encoder.encode_multi_input(
                Xs,
                Y=Y,
                max_length=sys.maxsize,
                pad_token=(pad_token or self.config.pad_token))
            length = len(encoded.token_ids)
            starts = list(range(0, length, step_size))
            for start in starts:
                d = dict()
                end = start + chunk_size
                for field in EncodedOutput._fields:
                    field_value = getattr(encoded, field)
                    if field_value is not None:
                        d[field] = field_value[start:end]
                yield self._array_format(EncodedOutput(**d),
                                         pad_token=pad_token)
        else:
            encoder_out = self.text_encoder.encode_multi_input(
                Xs,
                Y=Y,
                max_length=self.config.max_length,
                pad_token=(pad_token or self.config.pad_token))

            yield self._array_format(encoder_out,
                                     pad_token=(pad_token
                                                or self.config.pad_token))
Example No. 6
    def _encode(self, texts, labels=None, stochastic=False):
        """
        Convert a batch of raw text to a batch of byte-pair encoded token indices.
        """
        self._lazy_init()

        batch_tokens = []
        batch_token_idxs = []
        batch_label_idxs = []
        batch_character_locs = []
        batch_char_starts = []
        label = None

        for i, text in enumerate(texts):
            text = text.replace(WEIRD_SPM_CHAR, "_")
            if labels is not None:
                label = labels[i]

            subtokens = []
            subtoken_idxs = []
            tok_pos = []
            char_starts = []
            token_start = 0
            if stochastic:
                encoded = self.encoder.sample_encode_as_pieces(text, -1, 0.1)
            else:
                encoded = self.encoder.encode_as_pieces(text)

            for j, token in enumerate(encoded):
                subtokens.append(token)
                subtoken_idxs.append(self.encoder.piece_to_id(token))
                raw_text = token.replace(WEIRD_SPM_CHAR, "")
                token_start_temp = text.find(raw_text, token_start)
                if token_start_temp == -1:
                    LOGGER.warning(
                        "SentencePiece produced a token {} not found in the original string {}"
                        .format(raw_text, text))
                else:
                    token_start = token_start_temp
                tok_pos.append(token_start + len(raw_text))
                char_starts.append(token_start)
                token_start += len(raw_text)

            batch_tokens.append(subtokens)
            batch_token_idxs.append(subtoken_idxs)
            batch_character_locs.append(tok_pos)
            batch_char_starts.append(char_starts)
            if labels is not None:
                batch_label_idxs.append([label] * len(subtoken_idxs))

        return EncodedOutput(
            token_ids=batch_token_idxs,
            tokens=batch_tokens,
            labels=batch_label_idxs,
            char_locs=batch_character_locs,
            char_starts=batch_char_starts,
        )
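
The loop above maps each SentencePiece piece back to character offsets by stripping the word-boundary marker (`WEIRD_SPM_CHAR`, the "▁" glyph in SentencePiece output) and searching forward through the text. A standalone sketch with a hand-written piece list, so no SentencePiece model is needed:

# Hand-written pieces standing in for encoder.encode_as_pieces(text);
# "▁" marks a word boundary in SentencePiece output.
text = "hello brave world"
pieces = ["▁hello", "▁bra", "ve", "▁world"]

tok_pos, char_starts, token_start = [], [], 0
for piece in pieces:
    raw = piece.replace("▁", "")
    found = text.find(raw, token_start)
    if found != -1:
        token_start = found
    char_starts.append(token_start)
    tok_pos.append(token_start + len(raw))
    token_start += len(raw)

print(list(zip(char_starts, tok_pos)))  # [(0, 5), (6, 9), (9, 11), (12, 17)]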
Example No. 7
    def generate_text(self, seed_text="", max_length=None, use_extra_toks=None):
        """
        Performs a prediction on the Language modeling objective given some seed text. It uses a noisy greedy decoding.
        Temperature parameter for decoding is set in the config.
        :param max_length: The maximum length to decode to.
        :param seed_text: Defaults to the empty string. This will form the starting point to begin modelling
        :return: A string containing the generated text.
        """
        if use_extra_toks is None:
            use_extra_toks = self._trained
    
        def dataset_encoded():
            while not dataset_encoded.finished:
                yield {"tokens": arr_encoded.token_ids, "mask": arr_encoded.mask}

        dataset_encoded.finished = False

        def get_input_fn():
            types, shapes = self.input_pipeline.feed_shape_type_def()
            tf_dataset = Dataset.from_generator(dataset_encoded, types[0], shapes[0])
            return tf_dataset.batch(1)

        self.config.use_extra_toks = use_extra_toks
        encoded = self.input_pipeline.text_encoder._encode([seed_text])
        if encoded.token_ids == [] and not use_extra_toks:
            raise ValueError(
                "If you are not using the extra tokens, you must provide some non-empty seed text"
            )
        start = [self.input_pipeline.text_encoder.start] if use_extra_toks else []
        token_ids = start 
        if encoded.token_ids is not None and len(encoded.token_ids):
            token_ids += encoded.token_ids[0]
        encoded = EncodedOutput(token_ids=token_ids)

        estimator, hooks = self.get_estimator(force_build_lm=True)
        predict = estimator.predict(
            input_fn=get_input_fn, predict_keys=[PredictMode.GENERATE_TEXT], hooks=hooks
        )

        EOS = self.input_pipeline.text_encoder.clf_token
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore")
            for i in range(
                len(encoded.token_ids) - 1, (max_length or self.config.max_length) - 2
            ):
                arr_encoded = self.input_pipeline._array_format(encoded)
                class_idx = next(predict)[PredictMode.GENERATE_TEXT]
                encoded.token_ids.append(class_idx[i])
                if encoded.token_ids[-1] == EOS:
                    break
            dataset_encoded.finished = True

        del self.config["use_extra_toks"]

        return self.input_pipeline.text_encoder.decode(encoded.token_ids)
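
The feeding pattern in `generate_text` relies on a generator that re-reads `arr_encoded` on every yield, so each `next(predict)` call sees the sequence extended by the previous step. A TensorFlow-free sketch of that closure pattern, with a stand-in "model" that just echoes the last value plus one:

current = [1]  # stand-in for arr_encoded

def dataset_encoded():
    # Re-reads `current` each time it yields, until the finished flag flips.
    while not dataset_encoded.finished:
        yield current

dataset_encoded.finished = False
stream = dataset_encoded()

for step in range(3):
    seen = next(stream)                 # what the model would receive this step
    current = current + [seen[-1] + 1]  # append a new "token", as generate_text does

dataset_encoded.finished = True
print(current)  # [1, 2, 3, 4]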
Example No. 8
    def _encode(self, texts, labels=None):
        """
        Convert a batch of raw text to a batch of byte-pair encoded token indices.
        """
        self._lazy_init()

        batch_tokens = []
        batch_token_idxs = []
        batch_label_idxs = []
        batch_character_locs = []
        label = None

        for i, text in enumerate(texts):
            if labels is not None:
                label = labels[i]

            subtokens = []
            subtoken_idxs = []
            tok_pos = []
            token_start = 0

            tokens = re.findall(self.pat, text)
            for j, token in enumerate(tokens):
                encoded_token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
                bpe_toks = self.bpe(encoded_token).split(' ')
                try:
                    if token.strip():
                        token_start = text.index(token, token_start)
                except ValueError:
                    # text_standardization oddity
                    traceback.print_exc()
                    continue

                subtokens.extend(bpe_toks)
                subtoken_idxs.extend([
                    self.encoder.get(t, self.UNK_IDX)
                    for t in bpe_toks
                ])
                subtoken_positions = np.cumsum([len(tok) for tok in bpe_toks]) + token_start
                token_start += len(token.strip())
                tok_pos.extend(subtoken_positions)

            batch_tokens.append(subtokens)
            batch_token_idxs.append(subtoken_idxs)
            batch_character_locs.append(tok_pos)
            if labels is not None:
                batch_label_idxs.append([label] * len(subtoken_idxs))

        return EncodedOutput(
            token_ids=batch_token_idxs,
            tokens=batch_tokens,
            labels=batch_label_idxs,
            char_locs=batch_character_locs,
        )
Example No. 9
    def _encode(self, texts, labels=None):
        """
        Convert a batch of raw text to a batch of byte-pair encoded token indices.
        """
        self._lazy_init()
        batch_tokens = []
        batch_token_idxs = []
        batch_label_idxs = []
        batch_char_ends = []
        batch_char_starts = []
        label = None
        offset = 0

        skipped = 0
        for i, text in enumerate(texts):
            if labels is not None:
                label = labels[i]
            char_ends = []

            subtokens, _, token_char_ends, starts = self.tokenizer.tokenize(
                text)
            if not subtokens:
                offset += len(text)  # for spans that are just whitespace
                skipped += 1
                continue
            i -= skipped

            char_ends.extend(token_char_ends)
            subtoken_idxs = self.tokenizer.convert_tokens_to_ids(subtokens)
            batch_tokens.append(subtokens)
            batch_token_idxs.append(subtoken_idxs)
            batch_char_ends.append(char_ends)
            batch_char_starts.append(starts)
            if labels is not None:
                batch_label_idxs.append([label] * len(subtokens))

        return EncodedOutput(
            token_ids=batch_token_idxs,
            tokens=batch_tokens,
            labels=batch_label_idxs,
            char_locs=batch_char_ends,
            char_starts=batch_char_starts,
        )
Example No. 10
    def _encode(self, texts, labels=None):
        """
        Convert a sample of raw text to a list of byte-pair encoded token indices.
        """
        self._lazy_init()
        batch_tokens = []
        batch_token_idxs = []
        batch_label_idxs = []
        batch_char_ends = []
        # to account for the fact that some BPEs have different lengths than their original tokens
        # (e.g. special characters such as bullets)
        batch_char_starts = []
        label = None

        skipped = 0
        for i, text in enumerate(texts):  # text = one label span
            if labels is not None:
                label = labels[i]

            subtokens = []
            subtoken_idxs = []
            char_ends = []
            char_starts = []
            token_start = 0

            tokens = re.findall(self.pat, text)
            if not tokens:
                skipped += 1
                continue
            i -= skipped
            for j, token in enumerate(tokens):
                encoded_token = "".join(self.byte_encoder[b]
                                        for b in token.encode("utf-8"))
                bpe_toks = self.bpe(encoded_token).split(" ")
                try:
                    if token.strip():
                        token_start = text.index(token, token_start)
                except ValueError:
                    # text_standardization oddity
                    traceback.print_exc()
                    continue

                subtokens.extend(bpe_toks)
                subtoken_idxs.extend(
                    [self.encoder.get(t, self.UNK_IDX) for t in bpe_toks])

                token_char_starts = [token_start] * len(bpe_toks)

                if np.sum([len(tok) for tok in bpe_toks]) > len(token):
                    token_char_ends = (
                        np.asarray([len(token.strip())
                                    for tok in bpe_toks]) + token_start)
                else:
                    token_char_ends = (
                        np.cumsum([len(tok)
                                   for tok in bpe_toks]) + token_start)

                token_start += len(token.strip())
                char_ends.extend(token_char_ends)
                char_starts.extend(token_char_starts)

            batch_tokens.append(subtokens)
            batch_token_idxs.append(subtoken_idxs)
            batch_char_ends.append(char_ends)
            batch_char_starts.append(char_starts)
            if labels is not None:
                batch_label_idxs.append([label] * len(subtoken_idxs))

        return EncodedOutput(
            token_ids=batch_token_idxs,
            tokens=batch_tokens,
            labels=batch_label_idxs,
            char_locs=batch_char_ends,
            char_starts=batch_char_starts,
        )
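
The `char_ends` branch above covers tokens whose byte-level BPE pieces are longer than the token itself (e.g. a bullet character that expands to several byte symbols): every piece is then assigned the token's end position, otherwise ends are a running sum of piece lengths. A sketch with hypothetical splits:

import numpy as np

def char_ends_for(token, bpe_toks, token_start):
    # Mirrors the branch above: if the pieces are longer than the token itself,
    # give every piece the token's end; otherwise use a running sum of lengths.
    if np.sum([len(tok) for tok in bpe_toks]) > len(token):
        return np.asarray([len(token.strip()) for _ in bpe_toks]) + token_start
    return np.cumsum([len(tok) for tok in bpe_toks]) + token_start

# Ordinary token: piece lengths sum to the token length.
print(char_ends_for("hello", ["hel", "lo"], token_start=0))    # [3 5]
# Illustrative split of a bullet into three byte symbols, longer than the token.
print(char_ends_for("•", ["â", "Ģ", "¢"], token_start=10))     # [11 11 11]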
Example No. 11
    def _encode(self, texts, labels=None, context=None):
        """
        Convert a sample of raw text to a list of byte-pair encoded token indices.
        """
        self._lazy_init()
        batch_tokens = []
        batch_token_idxs = []
        batch_label_idxs = []
        # to account for the fact that some BPEs have different lengths than their
        # original tokens (e.g. special characters such as bullets)
        batch_char_ends = []
        batch_context = []
        batch_char_starts = []
        label = None
        # offset tracks the offset between this field's character_locs, which start at 0,
        # and the 'start' keys in context, which track the entire document (not just this field)
        offset = 0

        skipped = 0
        for i, text in enumerate(texts):  # text = one label span
            if labels is not None:
                label = labels[i]

            subtokens = []
            subtoken_idxs = []
            char_ends = []
            char_starts = []
            token_start = 0

            tokens = re.findall(self.pat, text)
            if not tokens:
                offset += len(text)  # for spans that are just whitespace
                skipped += 1
                continue
            i -= skipped
            for j, token in enumerate(tokens):
                encoded_token = "".join(self.byte_encoder[b]
                                        for b in token.encode("utf-8"))
                bpe_toks = self.bpe(encoded_token).split(" ")
                try:
                    if token.strip():
                        token_start = text.index(token, token_start)
                except ValueError:
                    # text_standardization oddity
                    traceback.print_exc()
                    continue

                subtokens.extend(bpe_toks)
                subtoken_idxs.extend(
                    [self.encoder.get(t, self.UNK_IDX) for t in bpe_toks])

                token_char_starts = [token_start] * len(bpe_toks)

                if np.sum([len(tok) for tok in bpe_toks]) > len(token):
                    token_char_ends = (
                        np.asarray([len(token.strip())
                                    for tok in bpe_toks]) + token_start)
                else:
                    token_char_ends = (
                        np.cumsum([len(tok)
                                   for tok in bpe_toks]) + token_start)

                token_start += len(token.strip())
                char_ends.extend(token_char_ends)
                char_starts.extend(token_char_starts)

            batch_tokens.append(subtokens)
            batch_token_idxs.append(subtoken_idxs)
            batch_char_ends.append(char_ends)
            batch_char_starts.append(char_starts)
            if labels is not None:
                batch_label_idxs.append([label] * len(subtoken_idxs))

            # Context is tokenwise, so we need to duplicate contexts for each subtoken of a token, and to match length of labels
            if context is not None:
                text_context = self.line_up_context(context,
                                                    batch_char_ends[i],
                                                    batch_tokens[i],
                                                    subtoken_idxs, offset)
                batch_context.extend(text_context)
                offset += batch_char_ends[i][-1]

        return EncodedOutput(
            token_ids=batch_token_idxs,
            tokens=batch_tokens,
            labels=batch_label_idxs,
            context=batch_context,
            char_locs=batch_char_ends,
            char_starts=batch_char_starts,
        )
Example No. 12
    def _text_to_ids(self, Xs, Y=None, pad_token=None, context=None):
        if context is None and self.config.use_auxiliary_info:
            context = Xs[0]
            Xs = Xs[1]
        Xs = self._format_for_encoding(Xs)
        if self.config.chunk_long_sequences and len(Xs) == 1:
            # can only chunk single sequence inputs
            chunk_size = self.config.max_length - 2
            step_size = chunk_size // 3
            encoded = self.text_encoder.encode_multi_input(
                Xs,
                Y=Y,
                max_length=sys.maxsize,
                pad_token=(pad_token or self.config.pad_token),
                context=context,
            )
            if self.config.use_auxiliary_info:
                processed_context = np.squeeze(
                    self._context_to_vector([encoded.context]))

            length = len(encoded.token_ids)
            assert length == len(encoded.token_ids)
            starts = list(range(0, length, step_size))
            for start in starts:
                d = dict()
                end = start + chunk_size
                for field in EncodedOutput._fields:
                    field_value = getattr(encoded, field)
                    if field_value is not None:
                        d[field] = field_value[start:end]
                if self.config.use_auxiliary_info:
                    d["context"] = processed_context[
                        start:end]  # forced since encoded is immutable'
                else:
                    d['context'] = None

                yield self._array_format(EncodedOutput(**d),
                                         pad_token=pad_token)
        else:
            encoder_out = self.text_encoder.encode_multi_input(
                Xs,
                Y=Y,
                max_length=self.config.max_length,
                pad_token=(pad_token or self.config.pad_token),
                context=context,
            )

            d = dict()
            for field in EncodedOutput._fields:
                field_value = getattr(encoder_out, field)
                if field_value is not None:
                    d[field] = field_value
            if self.config.use_auxiliary_info:
                d["context"] = np.squeeze(
                    self._context_to_vector([
                        encoder_out.context
                    ]))  # forced since encoded is immutable
            else:
                d["context"] = None

            yield self._array_format(EncodedOutput(**d),
                                     pad_token=(pad_token
                                                or self.config.pad_token))
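
When auxiliary context is enabled, the context vectors are converted to an array once and then sliced with the same `start:end` windows as the token fields, so each chunk's context rows stay aligned with its tokens. A small numpy sketch with invented shapes:

import numpy as np

# Invented shapes: 10 tokens, 3 context features per token.
processed_context = np.arange(30).reshape(10, 3)
token_ids = list(range(10))

chunk_size, step_size = 6, 2
for start in range(0, len(token_ids), step_size):
    end = start + chunk_size
    chunk_tokens = token_ids[start:end]
    chunk_context = processed_context[start:end]
    assert len(chunk_tokens) == chunk_context.shape[0]  # rows stay aligned with tokens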
Example No. 13
    def _encode(self, texts, labels=None):
        """
        Convert a batch of raw text to a batch of byte-pair encoded token indices.
        """

        self._lazy_init()
        batch_tokens = []
        batch_token_idxs = []
        batch_label_idxs = []
        # to account for the fact that some BPEs have different lengths than their
        # original tokens (e.g. special characters such as bullets)
        batch_char_ends = []
        batch_char_starts = []
        label = None
        skipped = 0
        for i, text in enumerate(texts):
            if labels is not None:
                label = labels[i]

            raw_text = text.lower()
            
            # Only fine to apply this fix because it preserves character locations
            ftfy_text = uncurl_quotes(raw_text)
            tokens = NLP(_text_standardize(text))
            if not tokens:
                skipped += 1
                continue
            i -= skipped
            subtokens = []
            subtoken_idxs = []
            char_starts = []
            char_ends = []
            token_start = 0

            for j, token in enumerate(tokens):
                bpe_toks = self.bpe(token.text).split(" ")

                try:
                    if token.text.strip():
                        token_start = ftfy_text.index((token.text.strip()), token_start)
                except ValueError:
                    warnings.warn(
                        "Failed to find token `{}` in text.".format(token.text)
                    )
                    continue

                subtokens.extend(bpe_toks)
                subtoken_idxs.extend(
                    [self.encoder.get(SUBS.get(t, t), self.UNK_IDX) for t in bpe_toks]
                )

                assert len("".join(bpe_toks).replace("</w>", "")) == len(
                    token.text.replace(" ", "")
                )

                if np.sum([len(tok.replace("</w>", "")) for tok in bpe_toks]) > len(
                    token
                ):  # the BPEs comprising a token are longer than the token itself
                    token_char_ends = (
                        np.asarray([len(token.text.strip()) for tok in bpe_toks])
                        + token_start
                    )
                else:
                    token_char_ends = (
                        np.cumsum([len(tok.replace("</w>", "")) for tok in bpe_toks])
                        + token_start
                    )
                
                token_char_starts = [token_start] + token_char_ends[:-1].tolist()
                token_start += len(token.text.strip())
                char_ends.extend(token_char_ends)
                char_starts.extend(token_char_starts)

            batch_tokens.append(subtokens)
            batch_token_idxs.append(subtoken_idxs)
            batch_char_ends.append(char_ends)
            batch_char_starts.append(char_starts)
            if labels is not None:
                batch_label_idxs.append([label] * len(subtoken_idxs))

        return EncodedOutput(
            token_ids=batch_token_idxs,
            tokens=batch_tokens,
            labels=batch_label_idxs,
            char_locs=batch_char_ends,
            char_starts=batch_char_starts,
        )
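
The `char_starts` and `char_locs` fields returned by these encoders let callers map each subtoken back to a span of the original string. A minimal illustration with hand-chosen offsets rather than real encoder output:

# Hand-chosen offsets standing in for one entry of char_starts / char_locs.
text = "Hello, world"
char_starts = [0, 5, 7]
char_ends = [5, 6, 12]

spans = [text[s:e] for s, e in zip(char_starts, char_ends)]
print(spans)  # ['Hello', ',', 'world']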