# Tests for gretel-synthetics batch mode (DataFrameBatch). The import layout
# below is the assumed one for recent gretel-synthetics releases; adjust to
# the version under test.
import json
import os
import random
from pathlib import Path
from unittest.mock import patch

import pandas as pd
import pytest

from gretel_synthetics import const
from gretel_synthetics.batch import ORIG_HEADERS, DataFrameBatch
from gretel_synthetics.config import TensorFlowConfig
from gretel_synthetics.generate import GenText
from gretel_synthetics.tokenizers import (
    BaseTokenizerTrainer,
    CharTokenizerTrainer,
    SentencePieceTokenizerTrainer,
)
from gretel_synthetics.train import EpochState


def test_epoch_callback(train_df, tmp_path):
    # The callback appends one CSV row per epoch per batch.
    def epoch_callback(s: EpochState):
        with open(tmp_path / "callback_dump.txt", "a") as f:
            f.write(f"{s.epoch},{s.accuracy},{s.loss},{s.batch}\n")

    config = TensorFlowConfig(
        epochs=5,
        field_delimiter=",",
        checkpoint_dir=tmp_path,
        input_data_path=PATH_HOLDER,
        learning_rate=0.01,
        epoch_callback=epoch_callback,
    )
    tokenizer = SentencePieceTokenizerTrainer(vocab_size=10000, config=config)
    batcher = DataFrameBatch(
        batch_size=4, df=train_df, config=config, tokenizer=tokenizer
    )
    batcher.create_training_data()
    batcher.train_all_batches()

    with open(tmp_path / "callback_dump.txt", "r") as f:
        lines = f.readlines()

    # With batch_size=4 the training frame splits into 4 column batches,
    # so 4 batches x 5 epochs == 20 callback invocations.
    assert len(lines) == 20
    for i, line in enumerate(lines):
        fields = line.strip().split(",")
        assert len(fields) == 4
        assert int(fields[0]) == i % 5   # epoch counter resets for each batch
        assert int(fields[3]) == i // 5  # batch index advances every 5 epochs
        # accuracy and loss only need to parse as floats
        float(fields[1])
        float(fields[2])
    os.remove(tmp_path / "callback_dump.txt")
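# The tests in this module reference scaffolding that is not part of this
# excerpt: the PATH_HOLDER placeholder, the _tok_gen_count constant, and the
# train_df fixture. The definitions below are a minimal sketch of what they
# could look like so the tests are readable in isolation; the concrete values
# are assumptions, not the project's originals.

PATH_HOLDER = "it_doesnt_matter"  # batch mode rewrites input_data_path per batch
_tok_gen_count = 10000  # synthetic line count requested by generation tests


@pytest.fixture(scope="module")
def train_df():
    # Any reasonably wide numeric frame works; the column count drives how
    # many batches are created (15 columns -> 4 batches at batch_size=4).
    return pd.DataFrame(
        {f"var_{i}": [random.random() for _ in range(2000)] for i in range(15)}
    )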
def test_train_batch_sp_tok(train_df, tmp_path):
    config = TensorFlowConfig(
        epochs=5,
        field_delimiter=",",
        checkpoint_dir=tmp_path,
        input_data_path=PATH_HOLDER,
        learning_rate=0.01,
    )
    tokenizer = SentencePieceTokenizerTrainer(vocab_size=10000, config=config)
    batcher = DataFrameBatch(df=train_df, config=config, tokenizer=tokenizer)
    batcher.create_training_data()
    batcher.train_all_batches()
    batcher.generate_all_batch_lines(num_lines=_tok_gen_count, max_invalid=5000)
    syn_df = batcher.batches_to_df()
    assert syn_df.shape[0] == _tok_gen_count

    # Generate with a RecordFactory
    factory = batcher.create_record_factory(
        num_lines=_tok_gen_count, max_invalid=5000
    )
    syn_df = factory.generate_all(output="df")
    assert syn_df.shape[0] == _tok_gen_count
    assert list(syn_df.columns) == list(train_df.columns)
    assert factory.summary["valid_count"] == _tok_gen_count
def test_train_small_df(train_df, tmp_path):
    # Too few training records: training should fail loudly rather than
    # silently produce a degenerate model.
    small_df = train_df.sample(n=50)
    config = TensorFlowConfig(
        epochs=5,
        field_delimiter=",",
        checkpoint_dir=tmp_path,
        input_data_path=PATH_HOLDER,
    )
    batcher = DataFrameBatch(df=small_df, config=config)
    batcher.create_training_data()
    with pytest.raises(RuntimeError) as excinfo:
        batcher.train_all_batches()
    assert "Model training failed" in str(excinfo.value)
def test_train_batch_sp_tok_fixed_count(train_df, tmp_path):
    # Near-duplicate of test_train_batch_sp_tok above, using a fixed line
    # count; renamed so it does not shadow that test if both live in one
    # module.
    config = TensorFlowConfig(
        epochs=5,
        field_delimiter=",",
        checkpoint_dir=tmp_path,
        input_data_path=PATH_HOLDER,
        learning_rate=0.01,
    )
    tokenizer = SentencePieceTokenizerTrainer(vocab_size=10000, config=config)
    batcher = DataFrameBatch(df=train_df, config=config, tokenizer=tokenizer)
    batcher.create_training_data()
    batcher.train_all_batches()
    batcher.generate_all_batch_lines(num_lines=100, max_invalid=5000)
    syn_df = batcher.batches_to_df()
    assert syn_df.shape[0] == 100
def test_train_batch_sp(train_df, tmp_path):
    config = TensorFlowConfig(
        epochs=1,
        field_delimiter=",",
        checkpoint_dir=tmp_path,
        input_data_path=PATH_HOLDER,
    )
    batcher = DataFrameBatch(df=train_df, config=config)
    batcher.create_training_data()
    batcher.train_all_batches()

    # SentencePiece is the default tokenizer when none is supplied.
    model_params = json.loads(
        open(tmp_path / "batch_0" / const.MODEL_PARAMS).read()
    )
    assert model_params[const.MODEL_TYPE] == TensorFlowConfig.__name__
    tok_params = json.loads(
        open(tmp_path / "batch_0" / BaseTokenizerTrainer.settings_fname).read()
    )
    assert tok_params["tokenizer_type"] == SentencePieceTokenizerTrainer.__name__
def test_train_batch_sp_regression(train_df, tmp_path):
    """Batch mode with the default SentencePiece tokenizer, using the
    backwards-compatibility mode for versions <= 0.14.0, where the config
    is a plain dict rather than a TensorFlowConfig instance.
    """
    config = {"epochs": 1, "field_delimiter": ",", "checkpoint_dir": tmp_path}
    batcher = DataFrameBatch(df=train_df, config=config)
    batcher.create_training_data()
    batcher.train_all_batches()
    model_params = json.loads(
        open(tmp_path / "batch_0" / const.MODEL_PARAMS).read()
    )
    assert model_params[const.MODEL_TYPE] == TensorFlowConfig.__name__
    tok_params = json.loads(
        open(tmp_path / "batch_0" / BaseTokenizerTrainer.settings_fname).read()
    )
    assert tok_params["tokenizer_type"] == SentencePieceTokenizerTrainer.__name__
def _fit_sample(self, data, metadata):
    config = {
        "max_lines": self.max_lines,
        "max_line_len": self.max_line_len,
        # 3 epochs per column is the value recommended by Gretel
        "epochs": self.epochs or data.shape[1] * 3,
        "vocab_size": self.vocab_size,
        "gen_lines": self.gen_lines or data.shape[0],
        "dp": self.dp,
        "field_delimiter": self.field_delimiter,
        "overwrite": self.overwrite,
        "checkpoint_dir": self.checkpoint_dir,
    }
    batcher = DataFrameBatch(df=data, config=config)
    batcher.create_training_data()
    batcher.train_all_batches()
    batcher.generate_all_batch_lines()
    synth_data = batcher.batches_to_df()
    return synth_data
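# _fit_sample above is a method of a benchmarking wrapper class that is not
# included in this excerpt. A minimal, hypothetical host class is sketched
# below so the method can be read in context; the class name and attribute
# defaults are assumptions, not the original project's code.


class GretelSynthesizer:  # hypothetical wrapper, e.g. for an SDGym-style benchmark
    def __init__(
        self,
        max_lines=0,
        max_line_len=2048,
        epochs=None,      # None -> derived from column count in _fit_sample
        vocab_size=20000,
        gen_lines=None,   # None -> generate as many rows as the input
        dp=False,
        field_delimiter=",",
        overwrite=True,
        checkpoint_dir="checkpoints",
    ):
        self.max_lines = max_lines
        self.max_line_len = max_line_len
        self.epochs = epochs
        self.vocab_size = vocab_size
        self.gen_lines = gen_lines
        self.dp = dp
        self.field_delimiter = field_delimiter
        self.overwrite = overwrite
        self.checkpoint_dir = checkpoint_dir


# Attach the method defined above to the hypothetical wrapper.
GretelSynthesizer._fit_sample = _fit_sample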
def test_train_batch_char_tok(train_df, tmp_path):
    config = TensorFlowConfig(
        epochs=5,
        field_delimiter=",",
        checkpoint_dir=tmp_path,
        input_data_path=PATH_HOLDER,
        learning_rate=0.01,
    )
    batcher = DataFrameBatch(
        df=train_df, config=config, tokenizer=CharTokenizerTrainer(config=config)
    )
    batcher.create_training_data()
    batcher.train_all_batches()
    tok_params = json.loads(
        open(tmp_path / "batch_0" / BaseTokenizerTrainer.settings_fname).read()
    )
    assert tok_params["tokenizer_type"] == CharTokenizerTrainer.__name__
    batcher.generate_all_batch_lines(num_lines=100, max_invalid=5000)
    syn_df = batcher.batches_to_df()
    assert syn_df.shape[0] == 100
def test_init(test_data):
    with pytest.raises(ValueError):
        DataFrameBatch(df="nope", config=config_template)

    # should create the dir structure based on auto batch sizing
    batches = DataFrameBatch(df=test_data, config=config_template, batch_size=15)

    first_row = [
        "ID_code", "target", "var_0", "var_1", "var_2", "var_3", "var_4",
        "var_5", "var_6", "var_7", "var_8", "var_9", "var_10", "var_11",
        "var_12",
    ]
    assert batches.batches[0].headers == first_row
    assert len(batches.batches.keys()) == 14
    for i, batch in batches.batches.items():
        assert Path(batch.checkpoint_dir).is_dir()
        assert Path(batch.checkpoint_dir).name == f"batch_{i}"

    orig_headers = json.loads(
        open(Path(config_template["checkpoint_dir"]) / ORIG_HEADERS).read()
    )
    assert list(set(orig_headers)) == list(set(test_data.columns))

    batches.create_training_data()
    df = pd.read_csv(
        batches.batches[0].input_data_path, sep=config_template["field_delimiter"]
    )
    assert len(df.columns) == len(first_row)

    with pytest.raises(ValueError):
        batches.train_batch(99)

    with patch("gretel_synthetics.batch.train") as mock_train:
        batches.train_batch(5)
        arg = batches.batches[5].config
        mock_train.assert_called_with(arg, None)

    with patch("gretel_synthetics.batch.train") as mock_train:
        batches.train_all_batches()
        args = [b.config for b in batches.batches.values()]
        called_args = []
        for _, a, _ in mock_train.mock_calls:
            called_args.append(a[0])
        assert args == called_args

    with pytest.raises(ValueError):
        batches.set_batch_validator(5, "foo")
    with pytest.raises(ValueError):
        batches.set_batch_validator(99, simple_validator)
    batches.set_batch_validator(5, simple_validator)
    assert batches.batches[5].validator("1,2,3,4,5")
    # load validator back from disk
    batches.batches[5].load_validator_from_file()
    assert batches.batches[5].validator("1,2,3,4,5")

    # generate lines, simulating generation of the max valid line count
    def good():
        return GenText(
            text="1,2,3,4,5", valid=random.choice([None, True]), delimiter=","
        )

    def bad():
        return GenText(text="1,2,3", valid=False, delimiter=",")

    with patch("gretel_synthetics.batch.generate_text") as mock_gen:
        mock_gen.return_value = [good(), good(), good(), bad(), bad(), good(), good()]
        summary = batches.generate_batch_lines(5, max_invalid=1)
        assert summary.is_valid
        check_call = mock_gen.mock_calls[0]
        _, _, kwargs = check_call
        assert kwargs["max_invalid"] == 1

    with patch("gretel_synthetics.batch.generate_text") as mock_gen:
        mock_gen.return_value = [good(), good(), good(), bad(), bad(), good(), good()]
        summary = batches.generate_batch_lines(5)
        assert summary.is_valid

    with patch("gretel_synthetics.batch.generate_text") as mock_gen:
        mock_gen.return_value = [good(), good(), good(), bad(), bad(), good()]
        summary = batches.generate_batch_lines(5)
        assert not summary.is_valid

    with patch.object(batches, "generate_batch_lines") as mock_gen:
        batches.generate_all_batch_lines(max_invalid=15)
        assert mock_gen.call_count == len(batches.batches.keys())
        check_call = mock_gen.mock_calls[0]
        _, _, kwargs = check_call
        assert kwargs["max_invalid"] == 15

    # get synthetic df
    line = GenText(
        text="1,2,3,4,5,6,7,8,9,10,11,12,13,14,15", valid=True, delimiter=","
    )
    with patch("gretel_synthetics.batch.generate_text") as mock_gen:
        mock_gen.return_value = [line] * len(batches.batches[10].headers)
        batches.generate_batch_lines(10)
    assert len(batches.batches[10].synthetic_df) == len(batches.batches[10].headers)
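# test_init relies on further scaffolding not shown in this excerpt: a
# config_template dict, a test_data fixture, and simple_validator. The
# sketches below are assumptions chosen to satisfy the test's assertions
# (14 batches of 15 columns, 5-field records), not the project's originals.

config_template = {
    "epochs": 1,
    "field_delimiter": ",",
    "checkpoint_dir": "checkpoints",  # the real template likely uses a tmp dir
}


@pytest.fixture
def test_data():
    # Santander-style headers: ID_code, target, var_0 .. var_199 gives 202
    # columns, which splits into 14 batches at batch_size=15, and the first
    # batch carries exactly the first_row headers asserted above.
    columns = ["ID_code", "target"] + [f"var_{i}" for i in range(200)]
    return pd.DataFrame([[0] * len(columns)] * 5, columns=columns)


def simple_validator(line: str) -> bool:
    # Accept any record with exactly five fields.
    return len(line.split(",")) == 5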