def test_generate_all_batch_lines_raise_on_failed(test_data):
    """Verify generate_all_batch_lines forwards both defaults and explicit
    overrides to the per-batch generate_batch_lines call."""
    batches = DataFrameBatch(df=test_data, config=config_template)
    batches.create_training_data()

    # Default invocation: batch 0 receives the library defaults.
    batches.generate_batch_lines = Mock()
    batches.generate_all_batch_lines()
    _, call_args, call_kwargs = batches.generate_batch_lines.mock_calls[0]
    assert call_args == (0,)
    assert call_kwargs == {
        "max_invalid": MAX_INVALID,
        "raise_on_exceed_invalid": False,
        "num_lines": None,
        "parallelism": 0,
        "seed_fields": None,
    }

    # Explicit overrides: raise_on_failed_batch maps to raise_on_exceed_invalid.
    batches.generate_batch_lines = Mock()
    batches.generate_all_batch_lines(
        max_invalid=10, raise_on_failed_batch=True, num_lines=5
    )
    _, call_args, call_kwargs = batches.generate_batch_lines.mock_calls[0]
    assert call_args == (0,)
    assert call_kwargs == {
        "max_invalid": 10,
        "raise_on_exceed_invalid": True,
        "num_lines": 5,
        "parallelism": 0,
        "seed_fields": None,
    }
def test_train_batch_sp_tok(train_df, tmp_path):
    """End-to-end: train batches with a SentencePiece tokenizer, then generate
    via generate_all_batch_lines and via a RecordFactory."""
    config = TensorFlowConfig(
        epochs=5,
        field_delimiter=",",
        checkpoint_dir=tmp_path,
        input_data_path=PATH_HOLDER,
        learning_rate=0.01,
    )
    tokenizer = SentencePieceTokenizerTrainer(vocab_size=10000, config=config)
    batcher = DataFrameBatch(df=train_df, config=config, tokenizer=tokenizer)
    batcher.create_training_data()
    batcher.train_all_batches()

    batcher.generate_all_batch_lines(num_lines=_tok_gen_count, max_invalid=5000)
    syn_df = batcher.batches_to_df()
    assert syn_df.shape[0] == _tok_gen_count

    # Generate the same number of records through a RecordFactory and check
    # that the resulting frame matches the training schema.
    factory = batcher.create_record_factory(
        num_lines=_tok_gen_count, max_invalid=5000
    )
    syn_df = factory.generate_all(output="df")
    assert syn_df.shape[0] == _tok_gen_count
    assert list(syn_df.columns) == list(train_df.columns)
    assert factory.summary["valid_count"] == _tok_gen_count
def test_train_batch_sp_tok(train_df, tmp_path):
    """Train batches with a SentencePiece tokenizer and generate 100 lines.

    NOTE(review): this has the same name as another test in the corpus —
    presumably they live in different modules; confirm, since duplicates in
    one file would shadow each other under pytest collection.
    """
    config = TensorFlowConfig(
        epochs=5,
        field_delimiter=",",
        checkpoint_dir=tmp_path,
        input_data_path=PATH_HOLDER,
        learning_rate=0.01,
    )
    tokenizer = SentencePieceTokenizerTrainer(vocab_size=10000, config=config)
    batcher = DataFrameBatch(df=train_df, config=config, tokenizer=tokenizer)
    batcher.create_training_data()
    batcher.train_all_batches()
    batcher.generate_all_batch_lines(num_lines=100, max_invalid=5000)
    syn_df = batcher.batches_to_df()
    assert syn_df.shape[0] == 100
def _fit_sample(self, data, metadata): config = { 'max_lines': self.max_lines, 'max_line_len': self.max_line_len, 'epochs': self.epochs or data.shape[1] * 3, # value recommended by Gretel 'vocab_size': self.vocab_size, 'gen_lines': self.gen_lines or data.shape[0], 'dp': self.dp, 'field_delimiter': self.field_delimiter, 'overwrite': self.overwrite, 'checkpoint_dir': self.checkpoint_dir } batcher = DataFrameBatch(df=data, config=config) batcher.create_training_data() batcher.train_all_batches() batcher.generate_all_batch_lines() synth_data = batcher.batches_to_df() return synth_data
def test_train_batch_char_tok(train_df, tmp_path):
    """Train batches with a character tokenizer; verify the persisted tokenizer
    settings record the right type, then generate 100 lines."""
    config = TensorFlowConfig(
        epochs=5,
        field_delimiter=",",
        checkpoint_dir=tmp_path,
        input_data_path=PATH_HOLDER,
        learning_rate=0.01,
    )
    batcher = DataFrameBatch(
        df=train_df,
        config=config,
        tokenizer=CharTokenizerTrainer(config=config),
    )
    batcher.create_training_data()
    batcher.train_all_batches()

    # The tokenizer settings written into batch_0 must identify the char tokenizer.
    settings_path = tmp_path / "batch_0" / BaseTokenizerTrainer.settings_fname
    tok_params = json.loads(open(settings_path).read())
    assert tok_params["tokenizer_type"] == CharTokenizerTrainer.__name__

    batcher.generate_all_batch_lines(num_lines=100, max_invalid=5000)
    syn_df = batcher.batches_to_df()
    assert syn_df.shape[0] == 100
def _generate( self, model_dir: Path, count: int, file_name: str, seed, validator ) -> str: batch_mode = is_model_dir_batch_mode(model_dir) if batch_mode: if seed is not None and not isinstance(seed, dict): raise TypeError("Seed must be a dict in batch mode") out_fname = f"{file_name}.csv" batcher = DataFrameBatch(mode="read", checkpoint_dir=str(model_dir)) batcher.generate_all_batch_lines( num_lines=count, max_invalid=max(count, MAX_INVALID), parallelism=1, seed_fields=seed ) out_df = batcher.batches_to_df() out_df.to_csv(out_fname, index=False) return out_fname else: out = [] # The model data will be inside of a single directory when a simple model is used. If it # was archived correctly, there should only be a single directory inside the archive actual_dir = next(model_dir.glob("*")) config = config_from_model_dir(actual_dir) if seed is not None and not isinstance(seed, str): raise TypeError("seed must be a string") for data in generate_text( config, num_lines=count, line_validator=validator, max_invalid=max(count, MAX_INVALID), parallelism=1, start_string=seed ): if data.valid or data.valid is None: out.append(data.text) out_fname = file_name + ".txt" with open(out_fname, "w") as fout: for line in out: fout.write(line + "\n") return out_fname
def test_init(test_data):
    """Exercise DataFrameBatch construction, training, validators, and
    line generation with mocked train/generate calls.

    NOTE(review): the original formatting was lost; the with-block nesting
    below was reconstructed from the statement order — confirm against the
    canonical source.
    """
    # Constructing with a non-DataFrame must be rejected.
    with pytest.raises(ValueError):
        DataFrameBatch(df="nope", config=config_template)

    # should create the dir structure based on auto
    # batch sizing
    batches = DataFrameBatch(df=test_data, config=config_template, batch_size=15)
    # Expected headers for the first batch (15 columns per batch_size).
    first_row = [
        "ID_code",
        "target",
        "var_0",
        "var_1",
        "var_2",
        "var_3",
        "var_4",
        "var_5",
        "var_6",
        "var_7",
        "var_8",
        "var_9",
        "var_10",
        "var_11",
        "var_12",
    ]
    assert batches.batches[0].headers == first_row
    assert len(batches.batches.keys()) == 14
    # Each batch gets its own on-disk checkpoint directory named batch_<i>.
    for i, batch in batches.batches.items():
        assert Path(batch.checkpoint_dir).is_dir()
        assert Path(batch.checkpoint_dir).name == f"batch_{i}"
    # The original (pre-split) headers are persisted alongside the checkpoints.
    orig_headers = json.loads(
        open(Path(config_template["checkpoint_dir"]) / ORIG_HEADERS).read()
    )
    assert list(set(orig_headers)) == list(set(test_data.columns))

    batches.create_training_data()
    # The first batch's training CSV should carry exactly its 15 columns.
    df = pd.read_csv(
        batches.batches[0].input_data_path, sep=config_template["field_delimiter"]
    )
    assert len(df.columns) == len(first_row)

    # Training an out-of-range batch index is an error.
    with pytest.raises(ValueError):
        batches.train_batch(99)

    # train_batch delegates to the library train() with the batch's config.
    with patch("gretel_synthetics.batch.train") as mock_train:
        batches.train_batch(5)
        arg = batches.batches[5].config
        mock_train.assert_called_with(arg, None)

    # train_all_batches trains every batch, in batch order.
    with patch("gretel_synthetics.batch.train") as mock_train:
        batches.train_all_batches()
        args = [b.config for b in batches.batches.values()]
        called_args = []
        for _, a, _ in mock_train.mock_calls:
            called_args.append(a[0])
        assert args == called_args

    # Validators must be callable and the batch index must exist.
    with pytest.raises(ValueError):
        batches.set_batch_validator(5, "foo")
    with pytest.raises(ValueError):
        batches.set_batch_validator(99, simple_validator)
    batches.set_batch_validator(5, simple_validator)
    assert batches.batches[5].validator("1,2,3,4,5")
    # load validator back from disk
    batches.batches[5].load_validator_from_file()
    assert batches.batches[5].validator("1,2,3,4,5")

    # generate lines, simulating generation the max
    # valid line count
    def good():
        # A line the validator accepts; valid may be True or None (no verdict).
        return GenText(
            text="1,2,3,4,5", valid=random.choice([None, True]), delimiter=","
        )

    def bad():
        return GenText(text="1,2,3", valid=False, delimiter=",")

    # max_invalid is forwarded to generate_text; 2 bad of a 7-line stream still
    # yields enough valid lines, so the summary reports success.
    with patch("gretel_synthetics.batch.generate_text") as mock_gen:
        mock_gen.return_value = [
            good(), good(), good(), bad(), bad(), good(), good()
        ]
        summary = batches.generate_batch_lines(5, max_invalid=1)
        assert summary.is_valid
        check_call = mock_gen.mock_calls[0]
        _, _, kwargs = check_call
        assert kwargs["max_invalid"] == 1

    # Default max_invalid with enough valid lines: still valid.
    with patch("gretel_synthetics.batch.generate_text") as mock_gen:
        mock_gen.return_value = [
            good(), good(), good(), bad(), bad(), good(), good()
        ]
        summary = batches.generate_batch_lines(5)
        assert summary.is_valid

    # One fewer good line than requested: the batch summary must be invalid.
    with patch("gretel_synthetics.batch.generate_text") as mock_gen:
        mock_gen.return_value = [good(), good(), good(), bad(), bad(), good()]
        summary = batches.generate_batch_lines(5)
        assert not summary.is_valid

    # generate_all_batch_lines fans out to every batch with the given kwargs.
    with patch.object(batches, "generate_batch_lines") as mock_gen:
        batches.generate_all_batch_lines(max_invalid=15)
        assert mock_gen.call_count == len(batches.batches.keys())
        check_call = mock_gen.mock_calls[0]
        _, _, kwargs = check_call
        assert kwargs["max_invalid"] == 15

    # get synthetic df
    line = GenText(
        text="1,2,3,4,5,6,7,8,9,10,11,12,13,14,15", valid=True, delimiter=","
    )
    with patch("gretel_synthetics.batch.generate_text") as mock_gen:
        mock_gen.return_value = [line] * len(batches.batches[10].headers)
        batches.generate_batch_lines(10)
        assert len(batches.batches[10].synthetic_df) == len(
            batches.batches[10].headers
        )