def _predict_chars(
    model: tf.keras.Sequential,
    sp: spm.SentencePieceProcessor,
    start_string: str,
    store: _BaseConfig,
) -> str:
    """
    Evaluation step (generating text using the learned model).

    Args:
        model: tf.keras.Sequential model
        sp: SentencePiece tokenizer
        start_string: string to bootstrap model
        store: our config object

    Returns:
        A single line of generated text
    """
    # Converting our start string to numbers (vectorizing)
    input_eval = sp.EncodeAsIds(start_string)
    input_eval = tf.expand_dims(input_eval, 0)

    # Empty list to accumulate the token ids for this line
    sentence_ids = []

    # Here batch size == 1
    model.reset_states()

    while True:
        predictions = model(input_eval)

        # remove the batch dimension
        predictions = tf.squeeze(predictions, 0)

        # using a categorical distribution to
        # predict the word returned by the model
        predictions = predictions / store.gen_temp
        predicted_id = tf.random.categorical(predictions, num_samples=1)[-1, 0].numpy()

        # We pass the predicted word as the next input to the model
        # along with the previous hidden state
        input_eval = tf.expand_dims([predicted_id], 0)
        sentence_ids.append(int(predicted_id))

        decoded = sp.DecodeIds(sentence_ids)
        if store.field_delimiter is not None:
            decoded = decoded.replace(
                store.field_delimiter_token, store.field_delimiter
            )

        if "<n>" in decoded:
            return _pred_string(decoded.replace("<n>", ""))
        elif 0 < store.gen_chars <= len(decoded):
            return _pred_string(decoded)
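# The core of the loop above is temperature-scaled categorical sampling:
# the logits are divided by store.gen_temp before tf.random.categorical
# draws the next token id, so lower temperatures sharpen the distribution
# (more conservative output) and higher temperatures flatten it (more
# varied output). A minimal, self-contained sketch of just that step; the
# function name is illustrative and not part of the module above:

import tensorflow as tf


def _sample_with_temperature(predictions: tf.Tensor, temperature: float) -> int:
    """Draw one token id from [seq_len, vocab_size] logits."""
    scaled = predictions / temperature
    # tf.random.categorical samples one id per row of logits; [-1, 0]
    # keeps only the draw for the final timestep, matching the loop above.
    return int(tf.random.categorical(scaled, num_samples=1)[-1, 0].numpy())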
def _predict_chars(
    model: tf.keras.Sequential,
    tokenizer: BaseTokenizer,
    start_string: Union[str, List[str]],
    store: TensorFlowConfig,
    predict_and_sample: Optional[Callable] = None,
) -> GeneratorType[PredString, None, None]:
    """
    Evaluation step (generating text using the learned model).

    Args:
        model: tf.keras.Sequential model
        tokenizer: A subclass of BaseTokenizer
        start_string: string to bootstrap model. NOTE: this string MUST
            already have had special tokens inserted (i.e. <d>)
        store: our config object

    Returns:
        Yields one line of text per iteration
    """
    # Converting our start string to numbers (vectorizing)
    if isinstance(start_string, str):
        start_string = [start_string]
    _start_string = start_string[0]

    start_vec = tokenizer.encode_to_ids(_start_string)
    input_eval = tf.constant(
        [start_vec for _ in range(store.predict_batch_size)]
    )

    if predict_and_sample is None:

        def predict_and_sample(this_input):
            return _predict_and_sample(model, this_input, store.gen_temp)

    # Batch prediction
    batch_sentence_ids = [[] for _ in range(store.predict_batch_size)]
    not_done = set(i for i in range(store.predict_batch_size))

    if store.reset_states:
        # Resetting RNN model states between each record created
        # guarantees more consistent record creation over time, at the
        # expense of model accuracy
        model.reset_states()

    prediction_prefix = None
    if _start_string != tokenizer.newline_str:
        if store.field_delimiter is not None:
            prediction_prefix = tokenizer.detokenize_delimiter(_start_string)
        else:
            prediction_prefix = _start_string

    while not_done:
        input_eval = predict_and_sample(input_eval)
        for i in not_done:
            batch_sentence_ids[i].append(int(input_eval[i, 0].numpy()))

        batch_decoded = [
            (i, tokenizer.decode_from_ids(batch_sentence_ids[i])) for i in not_done
        ]
        batch_decoded = _replace_prefix(batch_decoded, prediction_prefix)

        for i, decoded in batch_decoded:
            end_idx = decoded.find(tokenizer.newline_str)
            if end_idx >= 0:
                decoded = decoded[:end_idx]
                yield PredString(decoded)
                not_done.remove(i)
            elif 0 < store.gen_chars <= len(decoded):
                yield PredString(decoded)
                not_done.remove(i)
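# _predict_and_sample itself is not shown in this excerpt. A plausible
# sketch of the batched step it performs, assuming the model emits logits
# shaped [batch, seq_len, vocab_size]; the body below is an assumption
# inferred from how the default closure above is called, not the module's
# actual implementation:

import tensorflow as tf


def _predict_and_sample(
    model: tf.keras.Sequential, input_eval: tf.Tensor, temperature: float
) -> tf.Tensor:
    predictions = model(input_eval)[:, -1, :]  # logits at the last timestep
    predictions = predictions / temperature
    # One sampled id per batch row, shaped [batch, 1], so the result can be
    # fed straight back in as the next input_eval.
    return tf.random.categorical(predictions, num_samples=1)


# Consuming the batched generator is then a simple loop (illustrative):
#
#     for pred in _predict_chars(model, tokenizer, tokenizer.newline_str, store):
#         ...  # each pred is a PredString for one completed record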