def _text_to_ids(self, Xs, Y=None, pad_token=None):
    Xs = self._format_for_encoding(Xs)
    if self.config.chunk_long_sequences and len(Xs) == 1:
        # can only chunk single sequence inputs
        chunk_size = self.config.max_length - 2
        step_size = chunk_size // 3
        encoded = self.text_encoder.encode_multi_input(
            Xs,
            Y=Y,
            max_length=sys.maxsize,
            pad_token=(pad_token or self.config.pad_token),
        )
        length = len(encoded.token_ids)
        starts = list(range(0, length, step_size))
        field_starts_and_ends = dict()
        for field in EncodedOutput._fields:
            field_value = getattr(encoded, field)
            if field_value is not None:
                field_starts_and_ends[field] = (field_value[0], field_value[-1])
        for start in starts:
            d = dict()
            end = start + chunk_size
            for field in EncodedOutput._fields:
                field_value = getattr(encoded, field)
                if field_value is not None:
                    fv = field_value[start:end]
                    if self.config.add_eos_bos_to_chunk:
                        start_token, end_token = field_starts_and_ends[field]
                        if fv[0] != start_token:
                            fv = [start_token] + fv
                        if fv[-1] != end_token:
                            fv = fv + [end_token]
                    d[field] = fv
            yield self._array_format(EncodedOutput(**d), pad_token=pad_token)
    else:
        encoder_out = self.text_encoder.encode_multi_input(
            Xs,
            Y=Y,
            max_length=self.config.max_length,
            pad_token=(pad_token or self.config.pad_token),
        )
        d = dict()
        for field in EncodedOutput._fields:
            field_value = getattr(encoder_out, field)
            if field_value is not None:
                d[field] = field_value
        yield self._array_format(
            EncodedOutput(**d), pad_token=(pad_token or self.config.pad_token)
        )
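

# Illustration only (not part of the library): a minimal standalone sketch of how the
# overlapping chunk windows above are laid out. chunk_windows() is a hypothetical helper
# introduced for this example; the real method slices every EncodedOutput field with the
# same (start, end) pairs and pads the final short chunks.
def chunk_windows(num_tokens, max_length):
    chunk_size = max_length - 2   # leave room for the start/end special tokens
    step_size = chunk_size // 3   # consecutive chunks overlap by roughly two thirds
    return [(start, start + chunk_size) for start in range(0, num_tokens, step_size)]

# Example: chunk_windows(10, 8) -> [(0, 6), (2, 8), (4, 10), (6, 12), (8, 14)]
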
def _encode(self, texts, labels=None):
    """
    Convert a batch of raw text to a batch of byte-pair encoded token indices.
    """
    self._lazy_init()
    batch_tokens = []
    batch_token_idxs = []
    batch_label_idxs = []
    batch_character_locs = []
    label = None
    for i, text in enumerate(texts):
        if labels is not None:
            label = labels[i]
        subtokens, token_idxs = self.tokenizer.tokenize(text)
        subtoken_locs = [l[1] for l in token_idxs]
        batch_tokens.append(subtokens)
        batch_token_idxs.append(self.tokenizer.convert_tokens_to_ids(subtokens))
        batch_character_locs.append(subtoken_locs)
        if labels is not None:
            batch_label_idxs.append([label] * len(subtokens))
    return EncodedOutput(
        token_ids=batch_token_idxs,
        tokens=batch_tokens,
        labels=batch_label_idxs,
        char_locs=batch_character_locs,
    )
def test_gpt2_featurize(self):
    model = Classifier(base_model=GPT2)

    def dataset_encoded():
        yield {"tokens": arr_encoded.token_ids, "mask": arr_encoded.mask}

    def get_input_fn():
        types, shapes = model.input_pipeline.feed_shape_type_def()
        tf_dataset = Dataset.from_generator(dataset_encoded, types[0], shapes[0])
        return tf_dataset.batch(1)

    encoded = model.input_pipeline.text_encoder._encode(self.TEST_DATA)
    encoded = EncodedOutput(token_ids=encoded.token_ids[0])
    estimator, hooks = model.get_estimator(force_build_lm=False)
    predict = estimator.predict(
        input_fn=get_input_fn, predict_keys=[PredictMode.SEQUENCE], hooks=hooks
    )
    arr_encoded = model.input_pipeline._array_format(encoded)
    sequence_features = next(predict)[PredictMode.SEQUENCE]
    np.testing.assert_allclose(
        sequence_features[:len(encoded.token_ids), :],
        np.load(os.path.join(DIRECTORY, 'data/test-gpt2-activations.npy')),
        atol=1e-1
    )
def _encode(self, texts, labels=None):
    """
    Convert a batch of raw text to a batch of byte-pair encoded token indices.
    """
    self._lazy_init()
    batch_tokens = []
    batch_token_idxs = []
    batch_label_idxs = []
    batch_character_locs = []
    label = None
    for i, text in enumerate(texts):
        if labels is not None:
            label = labels[i]
        raw_text = text.lower()
        tokens = NLP(_text_standardize(text))
        subtokens = []
        subtoken_idxs = []
        tok_pos = []
        token_start = 0
        for j, token in enumerate(tokens):
            bpe_toks = self.bpe(token.text).split(' ')
            try:
                if token.text.strip():
                    token_start = raw_text.index(token.text.strip(), token_start)
            except ValueError:
                # text_standardization oddity
                continue
            subtokens.extend(bpe_toks)
            subtoken_idxs.extend([
                self.encoder.get(SUBS.get(t, t), self.UNK_IDX) for t in bpe_toks
            ])
            assert len("".join(bpe_toks).replace("</w>", "")) == len(
                token.text.replace(' ', ''))
            subtoken_positions = np.cumsum(
                [len(tok.replace("</w>", '')) for tok in bpe_toks]) + token_start
            token_start += len(token.text.strip())
            tok_pos.extend(subtoken_positions)
        batch_tokens.append(subtokens)
        batch_token_idxs.append(subtoken_idxs)
        batch_character_locs.append(tok_pos)
        if labels is not None:
            batch_label_idxs.append([label] * len(subtoken_idxs))
    return EncodedOutput(
        token_ids=batch_token_idxs,
        tokens=batch_tokens,
        labels=batch_label_idxs,
        char_locs=batch_character_locs,
    )
def _text_to_ids(self, Xs, Y=None, pad_token=None):
    Xs = self._format_for_encoding(Xs)
    if self.config.chunk_long_sequences and len(Xs) == 1:
        # can only chunk single sequence inputs
        chunk_size = self.config.max_length - 2
        step_size = chunk_size // 3
        encoded = self.text_encoder.encode_multi_input(
            Xs,
            Y=Y,
            max_length=sys.maxsize,
            pad_token=(pad_token or self.config.pad_token))
        length = len(encoded.token_ids)
        starts = list(range(0, length, step_size))
        for start in starts:
            d = dict()
            end = start + chunk_size
            for field in EncodedOutput._fields:
                field_value = getattr(encoded, field)
                if field_value is not None:
                    d[field] = field_value[start:end]
            yield self._array_format(EncodedOutput(**d), pad_token=pad_token)
    else:
        encoder_out = self.text_encoder.encode_multi_input(
            Xs,
            Y=Y,
            max_length=self.config.max_length,
            pad_token=(pad_token or self.config.pad_token))
        yield self._array_format(
            encoder_out, pad_token=(pad_token or self.config.pad_token)
        )
def _encode(self, texts, labels=None, stochastic=False):
    """
    Convert a batch of raw text to a batch of byte-pair encoded token indices.
    """
    self._lazy_init()
    batch_tokens = []
    batch_token_idxs = []
    batch_label_idxs = []
    batch_character_locs = []
    batch_char_starts = []
    label = None
    for i, text in enumerate(texts):
        text = text.replace(WEIRD_SPM_CHAR, "_")
        if labels is not None:
            label = labels[i]
        subtokens = []
        subtoken_idxs = []
        tok_pos = []
        char_starts = []
        token_start = 0
        if stochastic:
            encoded = self.encoder.sample_encode_as_pieces(text, -1, 0.1)
        else:
            encoded = self.encoder.encode_as_pieces(text)
        for j, token in enumerate(encoded):
            subtokens.append(token)
            subtoken_idxs.append(self.encoder.piece_to_id(token))
            raw_text = token.replace(WEIRD_SPM_CHAR, "")
            token_start_temp = text.find(raw_text, token_start)
            if token_start_temp == -1:
                LOGGER.warning(
                    "SentencePiece produced a token {} not found in the original string {}".format(
                        raw_text, text))
            else:
                token_start = token_start_temp
            tok_pos.append(token_start + len(raw_text))
            char_starts.append(token_start)
            token_start += len(raw_text)
        batch_tokens.append(subtokens)
        batch_token_idxs.append(subtoken_idxs)
        batch_character_locs.append(tok_pos)
        batch_char_starts.append(char_starts)
        if labels is not None:
            batch_label_idxs.append([label] * len(subtoken_idxs))
    return EncodedOutput(
        token_ids=batch_token_idxs,
        tokens=batch_tokens,
        labels=batch_label_idxs,
        char_locs=batch_character_locs,
        char_starts=batch_char_starts,
    )
def generate_text(self, seed_text="", max_length=None, use_extra_toks=None):
    """
    Performs a prediction on the language modeling objective given some seed text.
    Uses noisy greedy decoding; the temperature parameter for decoding is set in the config.

    :param seed_text: Defaults to the empty string. This forms the starting point for generation.
    :param max_length: The maximum length to decode to.
    :return: A string containing the generated text.
    """
    if use_extra_toks is None:
        use_extra_toks = self._trained

    def dataset_encoded():
        while not dataset_encoded.finished:
            yield {"tokens": arr_encoded.token_ids, "mask": arr_encoded.mask}

    dataset_encoded.finished = False

    def get_input_fn():
        types, shapes = self.input_pipeline.feed_shape_type_def()
        tf_dataset = Dataset.from_generator(dataset_encoded, types[0], shapes[0])
        return tf_dataset.batch(1)

    self.config.use_extra_toks = use_extra_toks
    encoded = self.input_pipeline.text_encoder._encode([seed_text])
    if encoded.token_ids == [] and not use_extra_toks:
        raise ValueError(
            "If you are not using the extra tokens, you must provide some non-empty seed text"
        )
    start = [self.input_pipeline.text_encoder.start] if use_extra_toks else []
    token_ids = start
    if encoded.token_ids is not None and len(encoded.token_ids):
        token_ids += encoded.token_ids[0]
    encoded = EncodedOutput(token_ids=token_ids)

    estimator, hooks = self.get_estimator(force_build_lm=True)
    predict = estimator.predict(
        input_fn=get_input_fn, predict_keys=[PredictMode.GENERATE_TEXT], hooks=hooks
    )

    EOS = self.input_pipeline.text_encoder.clf_token
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore")
        for i in range(
            len(encoded.token_ids) - 1, (max_length or self.config.max_length) - 2
        ):
            arr_encoded = self.input_pipeline._array_format(encoded)
            class_idx = next(predict)[PredictMode.GENERATE_TEXT]
            encoded.token_ids.append(class_idx[i])
            if encoded.token_ids[-1] == EOS:
                break
        dataset_encoded.finished = True

    del self.config["use_extra_toks"]

    return self.input_pipeline.text_encoder.decode(encoded.token_ids)
def _encode(self, texts, labels=None):
    """
    Convert a batch of raw text to a batch of byte-pair encoded token indices.
    """
    self._lazy_init()
    batch_tokens = []
    batch_token_idxs = []
    batch_label_idxs = []
    batch_character_locs = []
    label = None
    for i, text in enumerate(texts):
        if labels is not None:
            label = labels[i]
        subtokens = []
        subtoken_idxs = []
        tok_pos = []
        token_start = 0
        tokens = re.findall(self.pat, text)
        for j, token in enumerate(tokens):
            encoded_token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_toks = self.bpe(encoded_token).split(' ')
            try:
                if token.strip():
                    token_start = text.index(token, token_start)
            except ValueError:
                # text_standardization oddity
                traceback.print_exc()
                continue
            subtokens.extend(bpe_toks)
            subtoken_idxs.extend([
                self.encoder.get(t, self.UNK_IDX) for t in bpe_toks
            ])
            subtoken_positions = np.cumsum([len(tok) for tok in bpe_toks]) + token_start
            token_start += len(token.strip())
            tok_pos.extend(subtoken_positions)
        batch_tokens.append(subtokens)
        batch_token_idxs.append(subtoken_idxs)
        batch_character_locs.append(tok_pos)
        if labels is not None:
            batch_label_idxs.append([label] * len(subtoken_idxs))
    return EncodedOutput(
        token_ids=batch_token_idxs,
        tokens=batch_tokens,
        labels=batch_label_idxs,
        char_locs=batch_character_locs,
    )
def _encode(self, texts, labels=None):
    """
    Convert a batch of raw text to a batch of byte-pair encoded token indices.
    """
    self._lazy_init()
    batch_tokens = []
    batch_token_idxs = []
    batch_label_idxs = []
    batch_char_ends = []
    batch_char_starts = []
    label = None
    offset = 0
    skipped = 0
    for i, text in enumerate(texts):
        if labels is not None:
            label = labels[i]
        char_ends = []
        subtokens, _, token_char_ends, starts = self.tokenizer.tokenize(text)
        if not subtokens:
            offset += len(text)  # for spans that are just whitespace
            skipped += 1
            continue
        i -= skipped
        char_ends.extend(token_char_ends)
        subtoken_idxs = self.tokenizer.convert_tokens_to_ids(subtokens)
        batch_tokens.append(subtokens)
        batch_token_idxs.append(subtoken_idxs)
        batch_char_ends.append(char_ends)
        batch_char_starts.append(starts)
        if labels is not None:
            batch_label_idxs.append([label] * len(subtokens))
    return EncodedOutput(
        token_ids=batch_token_idxs,
        tokens=batch_tokens,
        labels=batch_label_idxs,
        char_locs=batch_char_ends,
        char_starts=batch_char_starts,
    )
def _encode(self, texts, labels=None):
    """
    Convert a sample of raw text to a list of byte-pair encoded token indices.
    """
    self._lazy_init()
    batch_tokens = []
    batch_token_idxs = []
    batch_label_idxs = []
    # char ends are tracked separately because some BPEs have different lengths
    # than their original tokens (e.g. special characters such as bullets)
    batch_char_ends = []
    batch_char_starts = []
    label = None
    skipped = 0
    for i, text in enumerate(texts):  # text = one label span
        if labels is not None:
            label = labels[i]
        subtokens = []
        subtoken_idxs = []
        char_ends = []
        char_starts = []
        token_start = 0
        tokens = re.findall(self.pat, text)
        if not tokens:
            skipped += 1
            continue
        i -= skipped
        for j, token in enumerate(tokens):
            encoded_token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
            bpe_toks = self.bpe(encoded_token).split(" ")
            try:
                if token.strip():
                    token_start = text.index(token, token_start)
            except ValueError:
                # text_standardization oddity
                traceback.print_exc()
                continue
            subtokens.extend(bpe_toks)
            subtoken_idxs.extend(
                [self.encoder.get(t, self.UNK_IDX) for t in bpe_toks])
            token_char_starts = [token_start] * len(bpe_toks)
            if np.sum([len(tok) for tok in bpe_toks]) > len(token):
                token_char_ends = (
                    np.asarray([len(token.strip()) for tok in bpe_toks]) + token_start)
            else:
                token_char_ends = (
                    np.cumsum([len(tok) for tok in bpe_toks]) + token_start)
            token_start += len(token.strip())
            char_ends.extend(token_char_ends)
            char_starts.extend(token_char_starts)
        batch_tokens.append(subtokens)
        batch_token_idxs.append(subtoken_idxs)
        batch_char_ends.append(char_ends)
        batch_char_starts.append(char_starts)
        if labels is not None:
            batch_label_idxs.append([label] * len(subtoken_idxs))
    return EncodedOutput(
        token_ids=batch_token_idxs,
        tokens=batch_tokens,
        labels=batch_label_idxs,
        char_locs=batch_char_ends,
        char_starts=batch_char_starts,
    )
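

# Illustration only (not part of the library): how the char end/start bookkeeping above is
# derived when the BPE pieces have the same total length as the original token. The token
# text and its split are hypothetical values chosen for this example.
import numpy as np

bpe_toks = ["hel", "lo"]   # assume the token "hello" was split into these two pieces
token_start = 10           # assume the token begins at character offset 10
char_ends = np.cumsum([len(tok) for tok in bpe_toks]) + token_start  # -> array([13, 15])
char_starts = [token_start] * len(bpe_toks)                          # -> [10, 10]
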
def _encode(self, texts, labels=None, context=None):
    """
    Convert a sample of raw text to a list of byte-pair encoded token indices.
    """
    self._lazy_init()
    batch_tokens = []
    batch_token_idxs = []
    batch_label_idxs = []
    # char ends are tracked separately because some BPEs have different lengths
    # than their original tokens (e.g. special characters such as bullets)
    batch_char_ends = []
    batch_context = []
    batch_char_starts = []
    label = None
    # offset tracks the difference between this field's character locations, which start
    # at 0, and the 'start' keys in context, which span the entire document (not just this field)
    offset = 0
    skipped = 0
    for i, text in enumerate(texts):  # text = one label span
        if labels is not None:
            label = labels[i]
        subtokens = []
        subtoken_idxs = []
        char_ends = []
        char_starts = []
        token_start = 0
        tokens = re.findall(self.pat, text)
        if not tokens:
            offset += len(text)  # for spans that are just whitespace
            skipped += 1
            continue
        i -= skipped
        for j, token in enumerate(tokens):
            encoded_token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
            bpe_toks = self.bpe(encoded_token).split(" ")
            try:
                if token.strip():
                    token_start = text.index(token, token_start)
            except ValueError:
                # text_standardization oddity
                traceback.print_exc()
                continue
            subtokens.extend(bpe_toks)
            subtoken_idxs.extend(
                [self.encoder.get(t, self.UNK_IDX) for t in bpe_toks])
            token_char_starts = [token_start] * len(bpe_toks)
            if np.sum([len(tok) for tok in bpe_toks]) > len(token):
                token_char_ends = (
                    np.asarray([len(token.strip()) for tok in bpe_toks]) + token_start)
            else:
                token_char_ends = (
                    np.cumsum([len(tok) for tok in bpe_toks]) + token_start)
            token_start += len(token.strip())
            char_ends.extend(token_char_ends)
            char_starts.extend(token_char_starts)
        batch_tokens.append(subtokens)
        batch_token_idxs.append(subtoken_idxs)
        batch_char_ends.append(char_ends)
        batch_char_starts.append(char_starts)
        if labels is not None:
            batch_label_idxs.append([label] * len(subtoken_idxs))
        # Context is tokenwise, so duplicate contexts for each subtoken of a token
        # to match the length of labels
        if context is not None:
            text_context = self.line_up_context(
                context, batch_char_ends[i], batch_tokens[i], subtoken_idxs, offset)
            batch_context.extend(text_context)
            offset += batch_char_ends[i][-1]
    return EncodedOutput(
        token_ids=batch_token_idxs,
        tokens=batch_tokens,
        labels=batch_label_idxs,
        context=batch_context,
        char_locs=batch_char_ends,
        char_starts=batch_char_starts,
    )
def _text_to_ids(self, Xs, Y=None, pad_token=None, context=None):
    if context is None and self.config.use_auxiliary_info:
        context = Xs[0]
        Xs = Xs[1]
    Xs = self._format_for_encoding(Xs)
    if self.config.chunk_long_sequences and len(Xs) == 1:
        # can only chunk single sequence inputs
        chunk_size = self.config.max_length - 2
        step_size = chunk_size // 3
        encoded = self.text_encoder.encode_multi_input(
            Xs,
            Y=Y,
            max_length=sys.maxsize,
            pad_token=(pad_token or self.config.pad_token),
            context=context,
        )
        if self.config.use_auxiliary_info:
            processed_context = np.squeeze(self._context_to_vector([encoded.context]))
        length = len(encoded.token_ids)
        starts = list(range(0, length, step_size))
        for start in starts:
            d = dict()
            end = start + chunk_size
            for field in EncodedOutput._fields:
                field_value = getattr(encoded, field)
                if field_value is not None:
                    d[field] = field_value[start:end]
            if self.config.use_auxiliary_info:
                # overridden here because encoded is immutable
                d["context"] = processed_context[start:end]
            else:
                d["context"] = None
            yield self._array_format(EncodedOutput(**d), pad_token=pad_token)
    else:
        encoder_out = self.text_encoder.encode_multi_input(
            Xs,
            Y=Y,
            max_length=self.config.max_length,
            pad_token=(pad_token or self.config.pad_token),
            context=context,
        )
        d = dict()
        for field in EncodedOutput._fields:
            field_value = getattr(encoder_out, field)
            if field_value is not None:
                d[field] = field_value
        if self.config.use_auxiliary_info:
            # overridden here because encoded is immutable
            d["context"] = np.squeeze(self._context_to_vector([encoder_out.context]))
        else:
            d["context"] = None
        yield self._array_format(
            EncodedOutput(**d), pad_token=(pad_token or self.config.pad_token)
        )
def _encode(self, texts, labels=None):
    """
    Convert a batch of raw text to a batch of byte-pair encoded token indices.
    """
    self._lazy_init()
    batch_tokens = []
    batch_token_idxs = []
    batch_label_idxs = []
    # char ends are tracked separately because some BPEs have different lengths
    # than their original tokens (e.g. special characters such as bullets)
    batch_char_ends = []
    batch_char_starts = []
    label = None
    skipped = 0
    for i, text in enumerate(texts):
        if labels is not None:
            label = labels[i]
        raw_text = text.lower()
        # Only safe to apply this fix because it preserves character locations
        ftfy_text = uncurl_quotes(raw_text)
        tokens = NLP(_text_standardize(text))
        if not tokens:
            skipped += 1
            continue
        i -= skipped
        subtokens = []
        subtoken_idxs = []
        char_starts = []
        char_ends = []
        token_start = 0
        for j, token in enumerate(tokens):
            bpe_toks = self.bpe(token.text).split(" ")
            try:
                if token.text.strip():
                    token_start = ftfy_text.index(token.text.strip(), token_start)
            except ValueError:
                warnings.warn(
                    "Failed to find token `{}` in text.".format(token.text)
                )
                continue
            subtokens.extend(bpe_toks)
            subtoken_idxs.extend(
                [self.encoder.get(SUBS.get(t, t), self.UNK_IDX) for t in bpe_toks]
            )
            assert len("".join(bpe_toks).replace("</w>", "")) == len(
                token.text.replace(" ", "")
            )
            if np.sum([len(tok.replace("</w>", "")) for tok in bpe_toks]) > len(token):
                # the BPEs comprising a token are longer than the token itself
                token_char_ends = (
                    np.asarray([len(token.text.strip()) for tok in bpe_toks]) + token_start
                )
            else:
                token_char_ends = (
                    np.cumsum([len(tok.replace("</w>", "")) for tok in bpe_toks]) + token_start
                )
            token_char_starts = [token_start] + token_char_ends[:-1].tolist()
            token_start += len(token.text.strip())
            char_ends.extend(token_char_ends)
            char_starts.extend(token_char_starts)
        batch_tokens.append(subtokens)
        batch_token_idxs.append(subtoken_idxs)
        batch_char_ends.append(char_ends)
        batch_char_starts.append(char_starts)
        if labels is not None:
            batch_label_idxs.append([label] * len(subtoken_idxs))
    return EncodedOutput(
        token_ids=batch_token_idxs,
        tokens=batch_tokens,
        labels=batch_label_idxs,
        char_locs=batch_char_ends,
        char_starts=batch_char_starts,
    )