Ejemplo n.º 1
0
def test_generate_all_batch_lines_raise_on_failed(test_data):
    """Check which kwargs ``generate_all_batch_lines`` forwards per batch.

    ``generate_batch_lines`` is swapped for a ``Mock`` so nothing is generated.
    Two scenarios are checked, defaults and explicit overrides. In each one,
    only the first recorded call (batch index 0) is inspected.
    """
    batches = DataFrameBatch(df=test_data, config=config_template)
    batches.create_training_data()

    scenarios = [
        # (kwargs passed to generate_all_batch_lines, kwargs expected downstream)
        (
            {},
            {
                "max_invalid": MAX_INVALID,
                "raise_on_exceed_invalid": False,
                "num_lines": None,
                "parallelism": 0,
                "seed_fields": None,
            },
        ),
        (
            {"max_invalid": 10, "raise_on_failed_batch": True, "num_lines": 5},
            {
                "max_invalid": 10,
                "raise_on_exceed_invalid": True,
                "num_lines": 5,
                "parallelism": 0,
                "seed_fields": None,
            },
        ),
    ]

    for call_kwargs, expected_kwargs in scenarios:
        # Fresh mock per scenario so mock_calls[0] belongs to this run.
        batches.generate_batch_lines = Mock()
        batches.generate_all_batch_lines(**call_kwargs)
        _, first_args, first_kwargs = batches.generate_batch_lines.mock_calls[0]
        assert first_args == (0,)
        assert first_kwargs == expected_kwargs
Ejemplo n.º 2
0
def test_generate_batch_lines_raise_on_exceed(test_data):
    """``generate_batch_lines`` handles ``TooManyInvalidError`` according to a flag.

    By default the error is not propagated and the call returns a falsy value.
    With ``raise_on_exceed_invalid=True`` the error reaches the caller.
    """
    batches = DataFrameBatch(df=test_data, config=config_template)
    batches.create_training_data()

    with patch("gretel_synthetics.batch.generate_text") as mock_gen:
        # Passing the exception class makes the mock raise a fresh instance
        # on every call.
        mock_gen.side_effect = TooManyInvalidError

        # Default behaviour: the error is swallowed.
        assert not batches.generate_batch_lines(0)

        # Opt-in behaviour: the error propagates. No assertion on the return
        # value here, because the call is expected never to return.
        with pytest.raises(TooManyInvalidError):
            batches.generate_batch_lines(0, raise_on_exceed_invalid=True)
Ejemplo n.º 3
0
def test_init(test_data):
    """End-to-end exercise of ``DataFrameBatch`` with training and generation mocked.

    Covers:
    - rejecting a string ``df``;
    - splitting columns into batches of ``batch_size``, each with its own
      ``batch_{i}`` checkpoint directory;
    - persisting the original headers;
    - writing per-batch training data;
    - training one batch or all batches;
    - per-batch validators;
    - line generation with mixed valid and invalid records.
    """
    with pytest.raises(ValueError):
        DataFrameBatch(df="nope", config=config_template)

    # should create the dir structure based on auto
    # batch sizing
    batches = DataFrameBatch(df=test_data, config=config_template, batch_size=15)
    first_row = [
        "ID_code",
        "target",
        "var_0",
        "var_1",
        "var_2",
        "var_3",
        "var_4",
        "var_5",
        "var_6",
        "var_7",
        "var_8",
        "var_9",
        "var_10",
        "var_11",
        "var_12",
    ]
    assert batches.batches[0].headers == first_row
    assert len(batches.batches) == 14
    for i, batch in batches.batches.items():
        assert Path(batch.checkpoint_dir).is_dir()
        assert Path(batch.checkpoint_dir).name == f"batch_{i}"

    # read_text() opens and closes the file. The previous bare open() leaked
    # the handle.
    orig_headers = json.loads(
        (Path(config_template["checkpoint_dir"]) / ORIG_HEADERS).read_text()
    )
    # Compare the sets themselves. Turning each set into a list depends on set
    # iteration order, which can differ even when the contents are equal.
    assert set(orig_headers) == set(test_data.columns)

    batches.create_training_data()
    df = pd.read_csv(
        batches.batches[0].input_data_path, sep=config_template["field_delimiter"]
    )
    assert len(df.columns) == len(first_row)

    with pytest.raises(ValueError):
        batches.train_batch(99)

    with patch("gretel_synthetics.batch.train") as mock_train:
        batches.train_batch(5)
        arg = batches.batches[5].config
        mock_train.assert_called_with(arg, None)

    with patch("gretel_synthetics.batch.train") as mock_train:
        batches.train_all_batches()
        expected_configs = [b.config for b in batches.batches.values()]
        # call_args_list records only direct calls to ``train``. mock_calls
        # would also include calls made on its return value.
        called_configs = [args[0] for args, _ in mock_train.call_args_list]
        assert expected_configs == called_configs

    with pytest.raises(ValueError):
        batches.set_batch_validator(5, "foo")

    with pytest.raises(ValueError):
        batches.set_batch_validator(99, simple_validator)

    batches.set_batch_validator(5, simple_validator)
    assert batches.batches[5].validator("1,2,3,4,5")
    # load validator back from disk
    batches.batches[5].load_validator_from_file()
    assert batches.batches[5].validator("1,2,3,4,5")

    # generate lines, simulating generation the max
    # valid line count
    def good():
        # valid is randomly None or True; the test expects either to count as valid.
        return GenText(
            text="1,2,3,4,5", valid=random.choice([None, True]), delimiter=","
        )

    def bad():
        return GenText(text="1,2,3", valid=False, delimiter=",")

    with patch("gretel_synthetics.batch.generate_text") as mock_gen:
        mock_gen.return_value = [good(), good(), good(), bad(), bad(), good(), good()]
        summary = batches.generate_batch_lines(5, max_invalid=1)
        assert summary.is_valid
        check_call = mock_gen.mock_calls[0]
        _, _, kwargs = check_call
        assert kwargs["max_invalid"] == 1

    with patch("gretel_synthetics.batch.generate_text") as mock_gen:
        mock_gen.return_value = [good(), good(), good(), bad(), bad(), good(), good()]
        summary = batches.generate_batch_lines(5)
        assert summary.is_valid

    # One fewer good line than above: the summary is no longer valid.
    with patch("gretel_synthetics.batch.generate_text") as mock_gen:
        mock_gen.return_value = [good(), good(), good(), bad(), bad(), good()]
        summary = batches.generate_batch_lines(5)
        assert not summary.is_valid

    with patch.object(batches, "generate_batch_lines") as mock_gen:
        batches.generate_all_batch_lines(max_invalid=15)
        assert mock_gen.call_count == len(batches.batches)
        check_call = mock_gen.mock_calls[0]
        _, _, kwargs = check_call
        assert kwargs["max_invalid"] == 15

    # get synthetic df
    line = GenText(
        text="1,2,3,4,5,6,7,8,9,10,11,12,13,14,15", valid=True, delimiter=","
    )
    with patch("gretel_synthetics.batch.generate_text") as mock_gen:
        mock_gen.return_value = [line] * len(batches.batches[10].headers)
        batches.generate_batch_lines(10)

    assert len(batches.batches[10].synthetic_df) == len(batches.batches[10].headers)