Example #1
import sentencepiece as spm
import tensorflow as tf

# _BaseConfig and _pred_string are defined elsewhere in the source module

def _predict_chars(
    model: tf.keras.Sequential,
    sp: spm.SentencePieceProcessor,
    start_string: str,
    store: _BaseConfig,
) -> str:
    """
    Evaluation step (generating text using the learned model).

    Args:
        model: tf.keras.Sequential model
        sp: SentencePiece tokenizer
        start_string: string to bootstrap model
        store: our config object
    Returns:
        A single generated line of text, wrapped by _pred_string
    """

    # Converting our start string to numbers (vectorizing)
    input_eval = sp.EncodeAsIds(start_string)
    input_eval = tf.expand_dims(input_eval, 0)

    # Empty list to collect the generated token ids
    sentence_ids = []

    # Here batch size == 1
    model.reset_states()

    while True:
        predictions = model(input_eval)
        # remove the batch dimension
        predictions = tf.squeeze(predictions, 0)

        # sample the next token id from a categorical distribution over
        # the logits, scaled by the sampling temperature
        predictions = predictions / store.gen_temp
        predicted_id = tf.random.categorical(predictions,
                                             num_samples=1)[-1, 0].numpy()

        # We pass the predicted id as the next input to the model
        # along with the previous hidden state
        input_eval = tf.expand_dims([predicted_id], 0)
        sentence_ids.append(int(predicted_id))

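        # Decode everything generated so far; if a field delimiter is
        # configured, swap its token back to the real delimiter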
        decoded = sp.DecodeIds(sentence_ids)
        if store.field_delimiter is not None:
            decoded = decoded.replace(store.field_delimiter_token,
                                      store.field_delimiter)

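        # Emit the line once the newline token appears, or once the
        # optional gen_chars length limit has been reached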
        if "<n>" in decoded:
            return _pred_string(decoded.replace("<n>", ""))
        elif 0 < store.gen_chars <= len(decoded):
            return _pred_string(decoded)
Example #2
from typing import Callable, Generator, List, Optional, Union

import tensorflow as tf

# BaseTokenizer, TensorFlowConfig, PredString, _predict_and_sample and
# _replace_prefix are defined elsewhere in the source module

def _predict_chars(
    model: tf.keras.Sequential,
    tokenizer: BaseTokenizer,
    start_string: Union[str, List[str]],
    store: TensorFlowConfig,
    predict_and_sample: Optional[Callable] = None,
) -> Generator[PredString, None, None]:
    """
    Evaluation step (generating text using the learned model).

    Args:
        model: tf.keras.Sequential model
        tokenizer: A subclass of BaseTokenizer
        start_string: string to bootstrap model. NOTE: this string MUST already have had special tokens
            inserted (e.g. <d>)
        store: our config object
        predict_and_sample: optional callback that runs the model on a batch and
            samples the next token ids; defaults to a closure around _predict_and_sample
    Returns:
        Yields line of text per iteration
    """

    # Converting our start string to numbers (vectorizing)
    if isinstance(start_string, str):
        start_string = [start_string]

    _start_string = start_string[0]

    start_vec = tokenizer.encode_to_ids(_start_string)
    input_eval = tf.constant(
        [start_vec for _ in range(store.predict_batch_size)])

    if predict_and_sample is None:

        def predict_and_sample(this_input):
            return _predict_and_sample(model, this_input, store.gen_temp)

    # Batch prediction
    batch_sentence_ids = [[] for _ in range(store.predict_batch_size)]
    not_done = set(range(store.predict_batch_size))

    if store.reset_states:
        # Resetting the RNN state between each generated record guarantees
        # more consistent record creation over time, at the expense of
        # model accuracy
        model.reset_states()

    prediction_prefix = None
    if _start_string != tokenizer.newline_str:
        if store.field_delimiter is not None:
            prediction_prefix = tokenizer.detokenize_delimiter(_start_string)
        else:
            prediction_prefix = _start_string

    while not_done:
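        # Sample one next token id per batch entry; the sampled ids also
        # serve as the next model input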
        input_eval = predict_and_sample(input_eval)
        for i in not_done:
            batch_sentence_ids[i].append(int(input_eval[i, 0].numpy()))

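        # Decode each unfinished sequence and strip the seed prefix from
        # the generated text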
        batch_decoded = [(i, tokenizer.decode_from_ids(batch_sentence_ids[i]))
                         for i in not_done]
        batch_decoded = _replace_prefix(batch_decoded, prediction_prefix)
        for i, decoded in batch_decoded:
            end_idx = decoded.find(tokenizer.newline_str)
            if end_idx >= 0:
                decoded = decoded[:end_idx]
                yield PredString(decoded)
                not_done.remove(i)
            elif 0 < store.gen_chars <= len(decoded):
                yield PredString(decoded)
                not_done.remove(i)
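The default sampling step delegates to _predict_and_sample, which is defined elsewhere in the module and not shown above. A plausible sketch, assuming the model emits per-timestep logits of shape (batch, seq_len, vocab), is:

# Sketch only -- the real helper lives in the surrounding module; the
# logits shape and temperature handling here are assumptions.
def _predict_and_sample(model, input_eval, gen_temp):
    predictions = model(input_eval)[:, -1, :]   # logits for the last position
    predictions = predictions / gen_temp        # temperature scaling
    # one sampled id per batch entry, shape (batch, 1)
    return tf.random.categorical(predictions, num_samples=1)

Since this version is a generator, callers can consume finished lines lazily; the "<n>" seed below is a hypothetical newline token:

for pred in _predict_chars(model, tokenizer, "<n>", store):
    print(pred)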