def test_generate_all_batch_lines_raise_on_failed(test_data):
    """Check how generate_all_batch_lines forwards its options per batch.

    ``generate_batch_lines`` is replaced by a fresh ``Mock`` for every
    scenario. The test then inspects the first recorded call: it must target
    batch ``0`` and carry the expected keyword arguments. Note that the public
    ``raise_on_failed_batch`` option surfaces as ``raise_on_exceed_invalid``.
    """
    batches = DataFrameBatch(df=test_data, config=config_template)
    batches.create_training_data()

    # (keyword arguments passed to generate_all_batch_lines,
    #  keyword arguments expected on the first generate_batch_lines call)
    scenarios = [
        (
            {},
            {
                "max_invalid": MAX_INVALID,
                "raise_on_exceed_invalid": False,
                "num_lines": None,
                "parallelism": 0,
                "seed_fields": None,
            },
        ),
        (
            {"max_invalid": 10, "raise_on_failed_batch": True, "num_lines": 5},
            {
                "max_invalid": 10,
                "raise_on_exceed_invalid": True,
                "num_lines": 5,
                "parallelism": 0,
                "seed_fields": None,
            },
        ),
    ]

    for call_kwargs, expected_kwargs in scenarios:
        batches.generate_batch_lines = Mock()
        batches.generate_all_batch_lines(**call_kwargs)
        first_call = batches.generate_batch_lines.mock_calls[0]
        _, first_args, first_kwargs = first_call
        assert first_args == (0,)
        assert first_kwargs == expected_kwargs
def test_generate_batch_lines_raise_on_exceed(test_data):
    """generate_batch_lines only propagates TooManyInvalidError when asked to.

    ``generate_text`` is patched to raise ``TooManyInvalidError``.

    * By default, ``generate_batch_lines`` must return a falsy result instead
      of raising.
    * With ``raise_on_exceed_invalid=True``, the error must propagate.
    """
    batches = DataFrameBatch(df=test_data, config=config_template)
    batches.create_training_data()

    # Default: the error is swallowed and a falsy value comes back.
    with patch("gretel_synthetics.batch.generate_text") as mock_gen:
        mock_gen.side_effect = TooManyInvalidError()
        assert not batches.generate_batch_lines(0)

    # Opt-in: the error propagates. The call is deliberately not wrapped in
    # `assert` -- it is expected to raise, so an assertion on its return
    # value would be unreachable and only mislead readers.
    with patch("gretel_synthetics.batch.generate_text") as mock_gen:
        mock_gen.side_effect = TooManyInvalidError()
        with pytest.raises(TooManyInvalidError):
            batches.generate_batch_lines(0, raise_on_exceed_invalid=True)
def test_init(test_data):
    """Exercise DataFrameBatch from construction through synthetic output.

    The test covers these stages in order:

    * constructor validation and batch directory layout;
    * training-data creation;
    * ``train_batch`` / ``train_all_batches``, with ``train`` patched;
    * validator registration and reload;
    * line generation, with ``generate_text`` patched;
    * assembly of the synthetic DataFrame.
    """
    with pytest.raises(ValueError):
        DataFrameBatch(df="nope", config=config_template)

    # should create the dir structure based on auto
    # batch sizing
    batches = DataFrameBatch(df=test_data, config=config_template, batch_size=15)
    first_row = [
        "ID_code",
        "target",
        "var_0",
        "var_1",
        "var_2",
        "var_3",
        "var_4",
        "var_5",
        "var_6",
        "var_7",
        "var_8",
        "var_9",
        "var_10",
        "var_11",
        "var_12",
    ]
    assert batches.batches[0].headers == first_row
    assert len(batches.batches) == 14
    for i, batch in batches.batches.items():
        assert Path(batch.checkpoint_dir).is_dir()
        assert Path(batch.checkpoint_dir).name == f"batch_{i}"

    # Path.read_text closes the file handle (the previous bare open() leaked it).
    # Compare as sets: the iteration order of two equal sets is an
    # implementation detail, so comparing list(set(...)) values could flake.
    orig_headers_path = Path(config_template["checkpoint_dir"]) / ORIG_HEADERS
    orig_headers = json.loads(orig_headers_path.read_text())
    assert set(orig_headers) == set(test_data.columns)

    batches.create_training_data()
    df = pd.read_csv(
        batches.batches[0].input_data_path, sep=config_template["field_delimiter"]
    )
    assert len(df.columns) == len(first_row)

    with pytest.raises(ValueError):
        batches.train_batch(99)

    with patch("gretel_synthetics.batch.train") as mock_train:
        batches.train_batch(5)
        arg = batches.batches[5].config
        mock_train.assert_called_with(arg, None)

    with patch("gretel_synthetics.batch.train") as mock_train:
        batches.train_all_batches()
        args = [b.config for b in batches.batches.values()]
        # First positional argument of every recorded train() call.
        called_args = [call_args[0] for _, call_args, _ in mock_train.mock_calls]
        assert args == called_args

    with pytest.raises(ValueError):
        batches.set_batch_validator(5, "foo")
    with pytest.raises(ValueError):
        batches.set_batch_validator(99, simple_validator)

    batches.set_batch_validator(5, simple_validator)
    assert batches.batches[5].validator("1,2,3,4,5")
    # load validator back from disk
    batches.batches[5].load_validator_from_file()
    assert batches.batches[5].validator("1,2,3,4,5")

    # Simulate line generation with mocked GenText results. Whether the
    # summary counts as valid depends on how many good lines come back
    # (presumably measured against the configured line count -- confirm
    # against config_template).
    def good():
        return GenText(
            text="1,2,3,4,5", valid=random.choice([None, True]), delimiter=","
        )

    def bad():
        return GenText(text="1,2,3", valid=False, delimiter=",")

    # Five good lines, with max_invalid forwarded to generate_text.
    with patch("gretel_synthetics.batch.generate_text") as mock_gen:
        mock_gen.return_value = [good(), good(), good(), bad(), bad(), good(), good()]
        summary = batches.generate_batch_lines(5, max_invalid=1)
        assert summary.is_valid
        _, _, kwargs = mock_gen.mock_calls[0]
        assert kwargs["max_invalid"] == 1

    # Five good lines with default options.
    with patch("gretel_synthetics.batch.generate_text") as mock_gen:
        mock_gen.return_value = [good(), good(), good(), bad(), bad(), good(), good()]
        summary = batches.generate_batch_lines(5)
        assert summary.is_valid

    # Only four good lines: the summary must not be valid.
    with patch("gretel_synthetics.batch.generate_text") as mock_gen:
        mock_gen.return_value = [good(), good(), good(), bad(), bad(), good()]
        summary = batches.generate_batch_lines(5)
        assert not summary.is_valid

    # generate_all_batch_lines calls generate_batch_lines once per batch and
    # forwards max_invalid.
    with patch.object(batches, "generate_batch_lines") as mock_gen:
        batches.generate_all_batch_lines(max_invalid=15)
        assert mock_gen.call_count == len(batches.batches)
        _, _, kwargs = mock_gen.mock_calls[0]
        assert kwargs["max_invalid"] == 15

    # get synthetic df
    line = GenText(
        text="1,2,3,4,5,6,7,8,9,10,11,12,13,14,15", valid=True, delimiter=","
    )
    with patch("gretel_synthetics.batch.generate_text") as mock_gen:
        mock_gen.return_value = [line] * len(batches.batches[10].headers)
        batches.generate_batch_lines(10)

    assert len(batches.batches[10].synthetic_df) == len(batches.batches[10].headers)