def test_predict_chars(mock_cat, tf_config, random_cat):
    """Exercise ``_predict_chars`` with a mocked model/tokenizer.

    Scenario 1: the tokenizer decodes a full line terminated by the newline
    token, so the generator should yield the text up to (not including) the
    newline.

    Scenario 2: ``gen_chars`` caps generation at 3 characters; the tokenizer
    decodes progressively longer partials ("a", "ab", "abc", ...), so the
    generator should stop and yield "abc".
    """
    config = tf_config
    tf_config.gen_chars = 10
    mock_model = Mock(return_value=tf.constant([[[1.0]]]))
    # MagicMock.__getitem__ returns the same child mock for any key, so
    # every input_eval[i, 0].numpy() call shares this return value.
    mock_tensor = MagicMock()
    mock_tensor[-1, 0].numpy.return_value = 1
    mock_cat.return_value = mock_tensor

    tokenizer = Mock()
    tokenizer.newline_str = NEWLINE
    tokenizer.encode_to_ids.return_value = [3]
    tokenizer.decode_from_ids.return_value = f"this is the end{NEWLINE}"
    line = next(_predict_chars(mock_model, tokenizer, NEWLINE, config))
    assert line == PredString(data="this is the end")

    config = tf_config
    mock_tensor = MagicMock()
    mock_tensor[-1, 0].numpy.side_effect = [0, 1, 2, 3, 4, 5, 6, 7, 8]
    mock_cat.return_value = mock_tensor
    tf_config.gen_chars = 3

    tokenizer = Mock()
    tokenizer.newline_str = NEWLINE
    tokenizer.encode_to_ids.return_value = [3]
    # Each batch member sees the same partial decode on a given step, so
    # repeat each partial once per batch slot.
    ret_data = [
        partial_rep
        for partial in ["a", "ab", "abc", "abcd"]
        for partial_rep in [partial] * config.predict_batch_size
    ]
    tokenizer.decode_from_ids.side_effect = ret_data
    line = next(_predict_chars(mock_model, tokenizer, NEWLINE, config))
    # BUG FIX: the original test computed this line but never asserted on it;
    # gen_chars == 3 means generation stops once the decoded text reaches
    # three characters ("abc").
    assert line.data == "abc"
def test_predict_chars(mock_dims, mock_cat, global_local_config, random_cat):
    """Verify ``_predict_chars`` against a mocked model and sentencepiece.

    First a full line ending in the "<n>" newline token is decoded, and the
    result should be the text with the newline stripped.  Then ``gen_chars``
    is lowered to 3 and the decoder emits growing partials, so generation
    should stop at "abc".

    NOTE(review): this duplicates the name of an earlier test in this file;
    presumably one of the two is a leftover from a refactor — confirm.
    """
    global_local_config.gen_chars = 10
    fake_model = Mock(return_value=[1.0])

    sampled = MagicMock()
    sampled[-1, 0].numpy.return_value = 1
    mock_cat.return_value = sampled

    spm_mock = Mock()
    spm_mock.DecodeIds.return_value = "this is the end<n>"
    result = _predict_chars(fake_model, spm_mock, "\n", global_local_config)
    assert result == PredString(data="this is the end")

    # Second pass: cap generation at 3 chars and feed growing partial decodes.
    global_local_config.gen_chars = 3
    sampled = MagicMock()
    sampled[-1, 0].numpy.side_effect = [0, 1, 2, 3, 4, 5, 6, 7, 8]
    mock_cat.return_value = sampled

    spm_mock = Mock()
    spm_mock.DecodeIds.side_effect = ["a", "ab", "abc", "abcd"]
    result = _predict_chars(fake_model, spm_mock, "\n", global_local_config)
    assert result.data == "abc"
def _predict_chars(
    model: tf.keras.Sequential,
    tokenizer: BaseTokenizer,
    start_string: Union[str, List[str]],
    store: TensorFlowConfig,
    predict_and_sample: Optional[Callable] = None,
) -> GeneratorType[PredString, None, None]:
    """
    Evaluation step (generating text using the learned model).

    Generates ``store.predict_batch_size`` lines concurrently: every
    iteration samples one token per batch member, and a member is retired
    (its line yielded) as soon as its decoded text contains the newline
    token or reaches ``store.gen_chars`` characters.

    Args:
        model: tf.keras.Sequential model
        tokenizer: A subclass of BaseTokenizer
        start_string: string to bootstrap model. NOTE: this string MUST
            already have had special tokens inserted (i.e. <d>)
        store: our config object
        predict_and_sample: optional override for the sampling step; when
            None, defaults to ``_predict_and_sample`` with the configured
            temperature.

    Returns:
        Yields line of text per iteration
    """
    # Converting our start string to numbers (vectorizing).
    # NOTE: only the first start string is used to seed every batch member.
    if isinstance(start_string, str):
        start_string = [start_string]
    _start_string = start_string[0]
    start_vec = tokenizer.encode_to_ids(_start_string)
    input_eval = tf.constant(
        [start_vec for _ in range(store.predict_batch_size)])

    if predict_and_sample is None:
        # Default sampler: closes over the model and configured temperature.
        def predict_and_sample(this_input):
            return _predict_and_sample(model, this_input, store.gen_temp)

    # Batch prediction: one growing id list per batch member, and the set of
    # member indices whose lines are not yet complete.
    batch_sentence_ids = [[] for _ in range(store.predict_batch_size)]
    not_done = set(i for i in range(store.predict_batch_size))

    if store.reset_states:
        # Reset RNN model states between each record created
        # guarantees more consistent record creation over time, at the
        # expense of model accuracy
        model.reset_states()

    # If the seed is a real prefix (not just the newline token), it must be
    # re-attached to each decoded line (with delimiters restored if needed).
    prediction_prefix = None
    if _start_string != tokenizer.newline_str:
        if store.field_delimiter is not None:
            prediction_prefix = tokenizer.detokenize_delimiter(_start_string)
        else:
            prediction_prefix = _start_string

    while not_done:
        input_eval = predict_and_sample(input_eval)
        for i in not_done:
            batch_sentence_ids[i].append(int(input_eval[i, 0].numpy()))

        # Re-decode each unfinished member's full id sequence every step;
        # decoding partial sequences incrementally is not safe for all
        # tokenizers.
        batch_decoded = [(i, tokenizer.decode_from_ids(batch_sentence_ids[i]))
                         for i in not_done]
        batch_decoded = _replace_prefix(batch_decoded, prediction_prefix)
        for i, decoded in batch_decoded:
            end_idx = decoded.find(tokenizer.newline_str)
            if end_idx >= 0:
                # Newline reached: yield the text before it and retire i.
                decoded = decoded[:end_idx]
                yield PredString(decoded)
                not_done.remove(i)
            elif 0 < store.gen_chars <= len(decoded):
                # Hit the configured max line length without a newline.
                yield PredString(decoded)
                not_done.remove(i)
def test_generate_text(_open, pickle, prepare, predict, spm, tf_config):
    """End-to-end scenarios for ``generate_text`` with a mocked predictor.

    Covers: a JSON validator accepting every line, no validator, a batch of
    invalid lines under the default invalid budget, and the same batch
    tripping ``max_invalid=2`` (for both an exception-raising and a
    bool-returning validator).
    """

    def _json_batch(start, stop):
        # One prediction batch of valid JSON lines {"foo": start..stop-1}.
        return [PredString(json.dumps({"foo": n})) for n in range(start, stop)]

    def _mixed_batches():
        # Three batches: valid, all-invalid, valid — fresh lists each call
        # because Mock side_effect sequences are consumed.
        return [
            _json_batch(0, 3),
            [PredString("nope"), PredString("foo"), PredString("bar")],
            _json_batch(6, 10),
        ]

    tf_config.gen_lines = 10
    predict.side_effect = [[line] for line in _json_batch(0, 10)]
    tokenizer = Mock()
    spm.return_value = tokenizer

    records = []
    for rec in generate_text(tf_config, line_validator=json.loads, parallelism=1):
        records.append(rec.as_dict())
    assert len(records) == 10
    assert records[0] == {
        "valid": True,
        "text": '{"foo": 0}',
        "explain": None,
        "delimiter": ",",
    }

    # now with no validator
    predict.side_effect = [[line] for line in _json_batch(0, 10)]
    records = []
    for rec in generate_text(tf_config, parallelism=1):
        records.append(rec.as_dict())
    assert len(records) == 10
    assert records[0] == {
        "valid": None,
        "text": '{"foo": 0}',
        "explain": None,
        "delimiter": ",",
    }

    # add validator back in, with a few bad json strings
    predict.side_effect = _mixed_batches()
    records = []
    try:
        for rec in generate_text(tf_config, line_validator=json.loads, parallelism=1):
            records.append(rec.as_dict())
    except RuntimeError:
        pass
    assert len(records) == 10
    assert not records[4]["valid"]

    # assert max invalid
    predict.side_effect = _mixed_batches()
    records = []
    try:
        for rec in generate_text(tf_config, line_validator=json.loads, max_invalid=2, parallelism=1):
            records.append(rec.as_dict())
    except RuntimeError as err:
        assert "Maximum number" in str(err)
    assert len(records) == 6
    assert not records[4]["valid"]

    # max invalid, validator returns a bool
    def _val(line):
        try:
            json.loads(line)
        except Exception:
            return False
        else:
            return True

    predict.side_effect = _mixed_batches()
    records = []
    try:
        for rec in generate_text(tf_config, line_validator=_val, max_invalid=2, parallelism=1):
            records.append(rec.as_dict())
    except RuntimeError as err:
        assert "Maximum number" in str(err)
    assert len(records) == 6
    assert not records[4]["valid"]