Example #1
def test_predict_chars(mock_cat, tf_config, random_cat):
    config = tf_config

    tf_config.gen_chars = 10
    mock_model = Mock(return_value=tf.constant([[[1.0]]]))
    mock_tensor = MagicMock()
    mock_tensor[-1, 0].numpy.return_value = 1
    mock_cat.return_value = mock_tensor

    tokenizer = Mock()
    tokenizer.newline_str = NEWLINE
    tokenizer.encode_to_ids.return_value = [3]
    tokenizer.decode_from_ids.return_value = f"this is the end{NEWLINE}"
    # The trailing newline token should be stripped from the yielded line.
    line = next(_predict_chars(mock_model, tokenizer, NEWLINE, config))
    assert line == PredString(data="this is the end")
 
    # Second case: no newline token is produced, so generation should stop
    # once gen_chars characters have been generated.
    mock_tensor = MagicMock()
    mock_tensor[-1, 0].numpy.side_effect = [0, 1, 2, 3, 4, 5, 6, 7, 8]
    mock_cat.return_value = mock_tensor
    tf_config.gen_chars = 3
    tokenizer = Mock()
    tokenizer.newline_str = NEWLINE
    tokenizer.encode_to_ids.return_value = [3]
    ret_data = [
        partial for partial in ["a", "ab", "abc", "abcd"]
        for _ in range(config.predict_batch_size)
    ]
    tokenizer.decode_from_ids.side_effect = ret_data
    line = next(_predict_chars(mock_model, tokenizer, NEWLINE, config))
    # With gen_chars == 3 and no newline token, generation stops at "abc".
    assert line == PredString(data="abc")
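The fixtures and patch targets this test depends on are not part of the snippet. A minimal harness sketch, assuming the NEWLINE token and the patch path (both are guesses, not the project's actual values):

from unittest.mock import MagicMock

import pytest

NEWLINE = "<n>"  # assumed newline token; the real constant may differ


@pytest.fixture
def tf_config():
    # Hypothetical stand-in for the real TensorFlowConfig fixture; only the
    # attributes _predict_chars reads are stubbed out here.
    config = MagicMock()
    config.predict_batch_size = 1
    config.gen_temp = 1.0
    config.reset_states = False
    config.field_delimiter = None
    return config


# mock_cat and random_cat would be injected by patch decorators such as
# @patch("gretel_synthetics.tensorflow.generator.tf.random.categorical");
# the module path here is an assumption.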
Example #2
def test_predict_chars(mock_dims, mock_cat, global_local_config, random_cat):
    global_local_config.gen_chars = 10
    mock_model = Mock(return_value=[1.0])
    mock_tensor = MagicMock()
    mock_tensor[-1, 0].numpy.return_value = 1
    mock_cat.return_value = mock_tensor

    sp = Mock()
    sp.DecodeIds.return_value = "this is the end<n>"

    line = _predict_chars(mock_model, sp, "\n", global_local_config)
    assert line == PredString(data="this is the end")

    mock_tensor = MagicMock()
    mock_tensor[-1, 0].numpy.side_effect = [0, 1, 2, 3, 4, 5, 6, 7, 8]
    mock_cat.return_value = mock_tensor
    global_local_config.gen_chars = 3
    sp = Mock()
    sp.DecodeIds.side_effect = ["a", "ab", "abc", "abcd"]
    line = _predict_chars(mock_model, sp, "\n", global_local_config)
    assert line.data == "abc"
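Both versions assert against PredString, which neither snippet defines. For these assertions a single-field namedtuple is behaviorally equivalent (an assumption: the real class may carry more fields):

from collections import namedtuple

# Minimal stand-in: tuple equality and the .data attribute are all the
# assertions above depend on.
PredString = namedtuple("PredString", ["data"])

assert PredString("abc") == PredString(data="abc")
assert PredString("abc").data == "abc"

Note the API difference between the two tests: in Example #1 _predict_chars is a generator, so the test calls next() on it; in Example #2 it returns a PredString directly.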
Example #3
def _predict_chars(
    model: tf.keras.Sequential,
    tokenizer: BaseTokenizer,
    start_string: Union[str, List[str]],
    store: TensorFlowConfig,
    predict_and_sample: Optional[Callable] = None,
) -> GeneratorType[PredString, None, None]:
    """
    Evaluation step (generating text using the learned model).

    Args:
        model: tf.keras.Sequential model
        tokenizer: A subclass of BaseTokenizer
        start_string: string used to bootstrap the model. NOTE: this string MUST
            already have had special tokens (e.g. <d>) inserted
        store: the TensorFlowConfig for this generation run
    Yields:
        One line of generated text per iteration
    """

    # Converting our start string to numbers (vectorizing)
    if isinstance(start_string, str):
        start_string = [start_string]

    _start_string = start_string[0]

    start_vec = tokenizer.encode_to_ids(_start_string)
    input_eval = tf.constant(
        [start_vec for _ in range(store.predict_batch_size)])

    if predict_and_sample is None:

        def predict_and_sample(this_input):
            return _predict_and_sample(model, this_input, store.gen_temp)

    # Batch prediction
    batch_sentence_ids = [[] for _ in range(store.predict_batch_size)]
    not_done = set(range(store.predict_batch_size))

    if store.reset_states:
        # Reset RNN model states between each record created
        # guarantees more consistent record creation over time, at the
        # expense of model accuracy
        model.reset_states()

    prediction_prefix = None
    if _start_string != tokenizer.newline_str:
        if store.field_delimiter is not None:
            prediction_prefix = tokenizer.detokenize_delimiter(_start_string)
        else:
            prediction_prefix = _start_string

    while not_done:
        input_eval = predict_and_sample(input_eval)
        for i in not_done:
            batch_sentence_ids[i].append(int(input_eval[i, 0].numpy()))

        # Decode the accumulated ids for every batch member that is still generating
        batch_decoded = [(i, tokenizer.decode_from_ids(batch_sentence_ids[i]))
                         for i in not_done]
        batch_decoded = _replace_prefix(batch_decoded, prediction_prefix)
        for i, decoded in batch_decoded:
            end_idx = decoded.find(tokenizer.newline_str)
            if end_idx >= 0:
                decoded = decoded[:end_idx]
                yield PredString(decoded)
                not_done.remove(i)
            # Stop early once the configured gen_chars cap is reached
            elif 0 < store.gen_chars <= len(decoded):
                yield PredString(decoded)
                not_done.remove(i)
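_predict_and_sample and _replace_prefix are called above but not shown. The sketches below are inferred from the call sites and from the tests patching tf.random.categorical; treat both bodies as assumptions rather than the project's exact code:

import tensorflow as tf


def _predict_and_sample(model, input_eval, gen_temp):
    # Take the logits for the last timestep of each batch member,
    # temperature-scale them, and sample one token id per member. The
    # result has shape (batch, 1) and becomes the next input_eval.
    predictions = model(input_eval)[:, -1, :] / gen_temp
    return tf.random.categorical(predictions, num_samples=1)


def _replace_prefix(batch_decoded, prediction_prefix):
    # Strip the seeded start string from each decoded candidate so that
    # only newly generated text is yielded.
    if prediction_prefix is None:
        return batch_decoded
    return [
        (i, decoded[len(prediction_prefix):] if decoded.startswith(prediction_prefix) else decoded)
        for i, decoded in batch_decoded
    ]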
Example #4
def test_generate_text(_open, pickle, prepare, predict, spm, tf_config):
    tf_config.gen_lines = 10
    predict.side_effect = [[PredString(json.dumps({"foo": i}))] for i in range(0, 10)]
    out = []

    tokenizer = Mock()
    spm.return_value = tokenizer

    for rec in generate_text(tf_config, line_validator=json.loads, parallelism=1):
        out.append(rec.as_dict())

    assert len(out) == 10
    assert out[0] == {
        "valid": True,
        "text": '{"foo": 0}',
        "explain": None,
        "delimiter": ",",
    }

    # now with no validator
    predict.side_effect = [[PredString(json.dumps({"foo": i}))] for i in range(0, 10)]
    out = []
    for rec in generate_text(tf_config, parallelism=1):
        out.append(rec.as_dict())
    assert len(out) == 10
    assert out[0] == {
        "valid": None,
        "text": '{"foo": 0}',
        "explain": None,
        "delimiter": ",",
    }

    # add validator back in, with a few bad json strings
    predict.side_effect = [
        [PredString(json.dumps({"foo": i})) for i in range(0, 3)],
        [PredString("nope"), PredString("foo"), PredString("bar")],
        [PredString(json.dumps({"foo": i})) for i in range(6, 10)],
    ]
    out = []
    try:
        for rec in generate_text(tf_config, line_validator=json.loads, parallelism=1):
            out.append(rec.as_dict())
    except RuntimeError:
        pass
    assert len(out) == 10
    assert not out[4]["valid"]

    # assert max invalid
    predict.side_effect = [
        [PredString(json.dumps({"foo": i})) for i in range(0, 3)],
        [PredString("nope"), PredString("foo"), PredString("bar")],
        [PredString(json.dumps({"foo": i})) for i in range(6, 10)],
    ]
    out = []
    try:
        for rec in generate_text(tf_config, line_validator=json.loads, max_invalid=2, parallelism=1):
            out.append(rec.as_dict())
    except RuntimeError as err:
        assert "Maximum number" in str(err)
    assert len(out) == 6
    assert not out[4]["valid"]

    # max invalid, validator returns a bool
    def _val(line):
        try:
            json.loads(line)
        except Exception:
            return False
        else:
            return True

    predict.side_effect = [
        [PredString(json.dumps({"foo": i})) for i in range(0, 3)],
        [PredString("nope"), PredString("foo"), PredString("bar")],
        [PredString(json.dumps({"foo": i})) for i in range(6, 10)],
    ]
    out = []
    try:
        for rec in generate_text(tf_config, line_validator=_val, max_invalid=2, parallelism=1):
            out.append(rec.as_dict())
    except RuntimeError as err:
        assert "Maximum number" in str(err)
    assert len(out) == 6
    assert not out[4]["valid"]
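The last two blocks exercise the same contract from two directions: a validator can reject a line either by raising (json.loads on bad input) or by returning False (_val). A hedged sketch of how such a dispatch could be normalized, assuming generate_text treats both styles alike:

def _is_valid(line, line_validator):
    # A line counts as invalid if the validator raises any exception or
    # explicitly returns False; any other return value counts as valid.
    try:
        return line_validator(line) is not False
    except Exception:
        return False

In both max_invalid runs the RuntimeError fires only after a third invalid record has been emitted (the invalid count then exceeds max_invalid=2), which is why out still holds six records and out[4] is one of the invalid ones.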