Code Example #1: batch training with a SentencePiece tokenizer, then generation through both generate_all_batch_lines and a RecordFactory
def test_train_batch_sp_tok(train_df, tmp_path):
    config = TensorFlowConfig(
        epochs=5,
        field_delimiter=",",
        checkpoint_dir=tmp_path,
        input_data_path=PATH_HOLDER,
        learning_rate=.01
    )
    tokenizer = SentencePieceTokenizerTrainer(
        vocab_size=10000,
        config=config
    )
    batcher = DataFrameBatch(
        df=train_df,
        config=config,
        tokenizer=tokenizer
    )
    batcher.create_training_data()
    batcher.train_all_batches()

    batcher.generate_all_batch_lines(num_lines=_tok_gen_count, max_invalid=5000)
    syn_df = batcher.batches_to_df()
    assert syn_df.shape[0] == _tok_gen_count

    # Generate with a RecordFactory
    factory = batcher.create_record_factory(num_lines=_tok_gen_count, max_invalid=5000)
    syn_df = factory.generate_all(output="df")
    assert syn_df.shape[0] == _tok_gen_count
    assert list(syn_df.columns) == list(train_df.columns)
    assert factory.summary["valid_count"] == _tok_gen_count
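
The test snippets on this page reference names defined at module scope in the original test files (the train_df fixture, PATH_HOLDER, _tok_gen_count, config_template, and so on; tmp_path is pytest's built-in temporary-directory fixture). The following is a minimal sketch of what that shared setup could look like; the concrete values and the fixture body are illustrative assumptions, not the library's actual test harness.

import pandas as pd
import pytest

# Placeholder input path: DataFrameBatch rewrites input_data_path for each
# column batch, so the initial value is never read (assumed value).
PATH_HOLDER = "pretend/path"

# Number of lines requested during generation in these tests (assumed value).
_tok_gen_count = 50

# Dict-style config used by several snippets below (assumed values).
config_template = {
    "epochs": 1,
    "field_delimiter": ",",
    "checkpoint_dir": "checkpoints",
    "input_data_path": PATH_HOLDER,
}

@pytest.fixture
def train_df():
    # Assumed fixture: any small, purely tabular DataFrame works here.
    return pd.DataFrame({f"var_{i}": range(200) for i in range(15)})
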
Code Example #2: reloading trained batches from a checkpoint directory with mode=READ
def test_read_mode(mock_gen, test_data):
    writer = DataFrameBatch(df=test_data, config=config_template)
    writer.create_training_data()

    # missing checkpoint dir
    with pytest.raises(ValueError):
        DataFrameBatch(mode=READ)

    # bad checkpoint dir
    with pytest.raises(ValueError):
        DataFrameBatch(mode=READ, checkpoint_dir="bad_dir")

    # NOTE: normally saving the params is done during training,
    # but we do it here manually since we won't actually train
    for _, batch in writer.batches.items():
        batch.config.save_model_params()

    # checkpoint dir exists in config
    DataFrameBatch(config=config_template, mode=READ)

    # checkpoint dir as a kwarg
    reader = DataFrameBatch(checkpoint_dir=checkpoint_dir, mode=READ)

    write_batch = writer.batches[0]
    read_batch = reader.batches[0]

    assert write_batch.checkpoint_dir == read_batch.checkpoint_dir
    assert write_batch.headers == read_batch.headers
    assert asdict(write_batch.config) == asdict(read_batch.config)
    assert reader.master_header_list == writer.master_header_list
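
Outside of tests, read mode is how previously trained batches are reloaded for generation without retraining. A minimal sketch, assuming READ is importable from gretel_synthetics.batch as it is in the original test module, and that the checkpoint directory already contains trained batches:

from gretel_synthetics.batch import READ, DataFrameBatch

# Reload every batch found under an existing checkpoint directory (assumed path).
reader = DataFrameBatch(checkpoint_dir="checkpoints", mode=READ)

# Generation then works the same as in write mode.
reader.generate_all_batch_lines(num_lines=100, max_invalid=5000)
synthetic_df = reader.batches_to_df()
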
Code Example #3: capturing per-epoch training state with an epoch_callback
def test_epoch_callback(train_df, tmp_path):
    def epoch_callback(s: EpochState):
        with open(tmp_path / 'callback_dump.txt', 'a') as f:
            f.write(f'{s.epoch},{s.accuracy},{s.loss},{s.batch}\n')

    config = TensorFlowConfig(epochs=5,
                              field_delimiter=",",
                              checkpoint_dir=tmp_path,
                              input_data_path=PATH_HOLDER,
                              learning_rate=.01,
                              epoch_callback=epoch_callback)
    tokenizer = SentencePieceTokenizerTrainer(vocab_size=10000, config=config)
    batcher = DataFrameBatch(batch_size=4,
                             df=train_df,
                             config=config,
                             tokenizer=tokenizer)
    batcher.create_training_data()
    batcher.train_all_batches()
    with open(tmp_path / 'callback_dump.txt', 'r') as f:
        lines = f.readlines()
        assert len(lines) == 20  # 5 epochs for each of the 4 batches
        for i, line in enumerate(lines):
            fields = line.strip().split(',')
            assert len(fields) == 4
            # epoch cycles 0-4 within a batch; the batch index advances every 5 lines
            assert int(fields[0]) == i % 5
            assert int(fields[3]) == i // 5
            # accuracy and loss must at least parse as floats
            float(fields[1])
            float(fields[2])
    os.remove(tmp_path / 'callback_dump.txt')
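
Since the callback receives an EpochState for every (batch, epoch) pair, it is a natural hook for custom progress tracking. A small sketch of a stateful callback, assuming EpochState is importable from gretel_synthetics.config and exposes the epoch, accuracy, loss, and batch fields used above:

from gretel_synthetics.config import EpochState

class LossTracker:
    """Collect per-epoch metrics so training progress can be inspected later."""

    def __init__(self):
        self.history = []

    def __call__(self, state: EpochState):
        self.history.append((state.batch, state.epoch, state.loss))

# Hypothetical usage: pass the instance as the epoch_callback
# tracker = LossTracker()
# config = TensorFlowConfig(..., epoch_callback=tracker)
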
Code Example #4: verifying that generate_all_batch_lines forwards its keyword arguments to generate_batch_lines
def test_generate_all_batch_lines_raise_on_failed(test_data):
    batches = DataFrameBatch(df=test_data, config=config_template)
    batches.create_training_data()

    batches.generate_batch_lines = Mock()
    batches.generate_all_batch_lines()
    _, args, kwargs = batches.generate_batch_lines.mock_calls[0]
    assert args == (0,)
    assert kwargs == {
        "max_invalid": MAX_INVALID,
        "raise_on_exceed_invalid": False,
        "num_lines": None,
        "parallelism": 0,
        "seed_fields": None,
    }

    batches.generate_batch_lines = Mock()
    batches.generate_all_batch_lines(
        max_invalid=10, raise_on_failed_batch=True, num_lines=5
    )
    _, args, kwargs = batches.generate_batch_lines.mock_calls[0]
    assert args == (0,)
    assert kwargs == {
        "max_invalid": 10,
        "raise_on_exceed_invalid": True,
        "num_lines": 5,
        "parallelism": 0,
        "seed_fields": None,
    }
Code Example #5: training on too little data raises a RuntimeError
def test_train_small_df(train_df, tmp_path):
    small_df = train_df.sample(n=50)
    config = TensorFlowConfig(epochs=5,
                              field_delimiter=",",
                              checkpoint_dir=tmp_path,
                              input_data_path=PATH_HOLDER)
    batcher = DataFrameBatch(df=small_df, config=config)
    batcher.create_training_data()
    with pytest.raises(RuntimeError) as excinfo:
        batcher.train_all_batches()
    assert "Model training failed" in str(excinfo.value)
Code Example #6: raise_on_exceed_invalid controls whether TooManyInvalidError propagates from generation
def test_generate_batch_lines_raise_on_exceed(test_data):
    batches = DataFrameBatch(df=test_data, config=config_template)
    batches.create_training_data()

    with patch("gretel_synthetics.batch.generate_text") as mock_gen:
        mock_gen.side_effect = TooManyInvalidError()
        assert not batches.generate_batch_lines(0)

    with patch("gretel_synthetics.batch.generate_text") as mock_gen:
        mock_gen.side_effect = TooManyInvalidError()
        with pytest.raises(TooManyInvalidError):
            assert not batches.generate_batch_lines(0, raise_on_exceed_invalid=True)
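
In real (unmocked) usage the same flag decides whether a failed batch returns False or raises, which matters when a caller wants to abort early instead of silently keeping partial results. A sketch, given a trained batcher as in the earlier examples and assuming TooManyInvalidError is importable from gretel_synthetics.errors:

from gretel_synthetics.errors import TooManyInvalidError

try:
    ok = batcher.generate_batch_lines(
        0, max_invalid=1000, raise_on_exceed_invalid=True
    )
except TooManyInvalidError:
    # Too many generated lines failed validation for this batch.
    ok = False
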
Code Example #7: a compact SentencePiece train-and-generate round trip
def test_train_batch_sp_tok(train_df, tmp_path):
    config = TensorFlowConfig(epochs=5,
                              field_delimiter=",",
                              checkpoint_dir=tmp_path,
                              input_data_path=PATH_HOLDER,
                              learning_rate=.01)
    tokenizer = SentencePieceTokenizerTrainer(vocab_size=10000, config=config)
    batcher = DataFrameBatch(df=train_df, config=config, tokenizer=tokenizer)
    batcher.create_training_data()
    batcher.train_all_batches()

    batcher.generate_all_batch_lines(num_lines=100, max_invalid=5000)
    syn_df = batcher.batches_to_df()
    assert syn_df.shape[0] == 100
Code Example #8: inspecting the model and tokenizer params written to disk after training
def test_train_batch_sp(train_df, tmp_path):
    config = TensorFlowConfig(epochs=1,
                              field_delimiter=",",
                              checkpoint_dir=tmp_path,
                              input_data_path=PATH_HOLDER)
    batcher = DataFrameBatch(df=train_df, config=config)
    batcher.create_training_data()
    batcher.train_all_batches()

    model_params = json.loads(
        open(tmp_path / "batch_0" / const.MODEL_PARAMS).read())
    assert model_params[const.MODEL_TYPE] == TensorFlowConfig.__name__

    tok_params = json.loads(
        open(tmp_path / "batch_0" /
             BaseTokenizerTrainer.settings_fname).read())
    assert tok_params[
        "tokenizer_type"] == SentencePieceTokenizerTrainer.__name__
Code Example #9: dict-based config through the backwards-compat path
def test_train_batch_sp_regression(train_df, tmp_path):
    """Batch mode with default SentencePiece tokenizer. Using the backwards
    compat mode for <= 0.14.0.
    """
    config = {"epochs": 1, "field_delimiter": ",", "checkpoint_dir": tmp_path}
    batcher = DataFrameBatch(df=train_df, config=config)
    batcher.create_training_data()
    batcher.train_all_batches()

    model_params = json.loads(
        open(tmp_path / "batch_0" / const.MODEL_PARAMS).read())
    assert model_params[const.MODEL_TYPE] == TensorFlowConfig.__name__

    tok_params = json.loads(
        open(tmp_path / "batch_0" /
             BaseTokenizerTrainer.settings_fname).read())
    assert tok_params[
        "tokenizer_type"] == SentencePieceTokenizerTrainer.__name__
Code Example #10: a fit-and-sample wrapper method that drives DataFrameBatch end to end
def _fit_sample(self, data, metadata):
    config = {
        'max_lines': self.max_lines,
        'max_line_len': self.max_line_len,
        'epochs': self.epochs or data.shape[1] * 3,  # value recommended by Gretel
        'vocab_size': self.vocab_size,
        'gen_lines': self.gen_lines or data.shape[0],
        'dp': self.dp,
        'field_delimiter': self.field_delimiter,
        'overwrite': self.overwrite,
        'checkpoint_dir': self.checkpoint_dir
    }
    batcher = DataFrameBatch(df=data, config=config)
    batcher.create_training_data()
    batcher.train_all_batches()
    batcher.generate_all_batch_lines()
    synth_data = batcher.batches_to_df()
    return synth_data
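
This method comes from a benchmark-style wrapper class around gretel-synthetics rather than from the test suite; the attributes it reads (self.epochs, self.vocab_size, and so on) would be set in the wrapper's constructor. A hypothetical sketch of such a class, purely for context; the default values are assumptions:

class GretelSynthesizer:
    """Hypothetical wrapper showing where _fit_sample's attributes come from."""

    def __init__(self, checkpoint_dir, epochs=None, gen_lines=None,
                 max_lines=0, max_line_len=2048, vocab_size=20000,
                 dp=False, field_delimiter=",", overwrite=True):
        self.checkpoint_dir = checkpoint_dir
        self.epochs = epochs
        self.gen_lines = gen_lines
        self.max_lines = max_lines
        self.max_line_len = max_line_len
        self.vocab_size = vocab_size
        self.dp = dp
        self.field_delimiter = field_delimiter
        self.overwrite = overwrite

With that constructor in place, calling _fit_sample(real_df, metadata=None) trains one model per column batch of real_df and returns a synthetic frame of the same width.
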
Code Example #11: batch training with a CharTokenizerTrainer instead of SentencePiece
def test_train_batch_char_tok(train_df, tmp_path):
    config = TensorFlowConfig(epochs=5,
                              field_delimiter=",",
                              checkpoint_dir=tmp_path,
                              input_data_path=PATH_HOLDER,
                              learning_rate=.01)
    batcher = DataFrameBatch(df=train_df,
                             config=config,
                             tokenizer=CharTokenizerTrainer(config=config))
    batcher.create_training_data()
    batcher.train_all_batches()

    tok_params = json.loads(
        open(tmp_path / "batch_0" /
             BaseTokenizerTrainer.settings_fname).read())
    assert tok_params["tokenizer_type"] == CharTokenizerTrainer.__name__

    batcher.generate_all_batch_lines(num_lines=100, max_invalid=5000)
    syn_df = batcher.batches_to_df()
    assert syn_df.shape[0] == 100
Code Example #12: exercising construction, batch splitting, validators, and line generation
def test_init(test_data):
    with pytest.raises(ValueError):
        DataFrameBatch(df="nope", config=config_template)

    # should create the dir structure based on auto
    # batch sizing
    batches = DataFrameBatch(df=test_data, config=config_template, batch_size=15)
    first_row = [
        "ID_code",
        "target",
        "var_0",
        "var_1",
        "var_2",
        "var_3",
        "var_4",
        "var_5",
        "var_6",
        "var_7",
        "var_8",
        "var_9",
        "var_10",
        "var_11",
        "var_12",
    ]
    assert batches.batches[0].headers == first_row
    assert len(batches.batches.keys()) == 14
    for i, batch in batches.batches.items():
        assert Path(batch.checkpoint_dir).is_dir()
        assert Path(batch.checkpoint_dir).name == f"batch_{i}"

    orig_headers = json.loads(
        open(Path(config_template["checkpoint_dir"]) / ORIG_HEADERS).read()
    )
    assert list(set(orig_headers)) == list(set(test_data.columns))

    batches.create_training_data()
    df = pd.read_csv(
        batches.batches[0].input_data_path, sep=config_template["field_delimiter"]
    )
    assert len(df.columns) == len(first_row)

    with pytest.raises(ValueError):
        batches.train_batch(99)

    with patch("gretel_synthetics.batch.train") as mock_train:
        batches.train_batch(5)
        arg = batches.batches[5].config
        mock_train.assert_called_with(arg, None)

    with patch("gretel_synthetics.batch.train") as mock_train:
        batches.train_all_batches()
        args = [b.config for b in batches.batches.values()]
        called_args = []
        for _, a, _ in mock_train.mock_calls:
            called_args.append(a[0])
        assert args == called_args

    with pytest.raises(ValueError):
        batches.set_batch_validator(5, "foo")

    with pytest.raises(ValueError):
        batches.set_batch_validator(99, simple_validator)

    batches.set_batch_validator(5, simple_validator)
    assert batches.batches[5].validator("1,2,3,4,5")
    # load validator back from disk
    batches.batches[5].load_validator_from_file()
    assert batches.batches[5].validator("1,2,3,4,5")

    # generate lines, simulating generation of the max
    # valid line count
    def good():
        return GenText(
            text="1,2,3,4,5", valid=random.choice([None, True]), delimiter=","
        )

    def bad():
        return GenText(text="1,2,3", valid=False, delimiter=",")

    with patch("gretel_synthetics.batch.generate_text") as mock_gen:
        mock_gen.return_value = [good(), good(), good(), bad(), bad(), good(), good()]
        summary = batches.generate_batch_lines(5, max_invalid=1)
        assert summary.is_valid
        check_call = mock_gen.mock_calls[0]
        _, _, kwargs = check_call
        assert kwargs["max_invalid"] == 1

    with patch("gretel_synthetics.batch.generate_text") as mock_gen:
        mock_gen.return_value = [good(), good(), good(), bad(), bad(), good(), good()]
        summary = batches.generate_batch_lines(5)
        assert summary.is_valid

    with patch("gretel_synthetics.batch.generate_text") as mock_gen:
        mock_gen.return_value = [good(), good(), good(), bad(), bad(), good()]
        summary = batches.generate_batch_lines(5)
        assert not summary.is_valid

    with patch.object(batches, "generate_batch_lines") as mock_gen:
        batches.generate_all_batch_lines(max_invalid=15)
        assert mock_gen.call_count == len(batches.batches.keys())
        check_call = mock_gen.mock_calls[0]
        _, _, kwargs = check_call
        assert kwargs["max_invalid"] == 15

    # get synthetic df
    line = GenText(
        text="1,2,3,4,5,6,7,8,9,10,11,12,13,14,15", valid=True, delimiter=","
    )
    with patch("gretel_synthetics.batch.generate_text") as mock_gen:
        mock_gen.return_value = [line] * len(batches.batches[10].headers)
        batches.generate_batch_lines(10)

    assert len(batches.batches[10].synthetic_df) == len(batches.batches[10].headers)
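
Distilled from the examples above, the core DataFrameBatch workflow is: build a config, optionally choose a tokenizer trainer, split the source DataFrame into column batches, train a model per batch, generate lines, and reassemble a synthetic DataFrame. A minimal end-to-end sketch with placeholder paths and sizes; import paths follow the modules used by the tests on this page:

import pandas as pd

from gretel_synthetics.batch import DataFrameBatch
from gretel_synthetics.config import TensorFlowConfig
from gretel_synthetics.tokenizers import SentencePieceTokenizerTrainer

source_df = pd.read_csv("source.csv")  # placeholder input

config = TensorFlowConfig(
    epochs=5,
    field_delimiter=",",
    checkpoint_dir="checkpoints",    # placeholder path
    input_data_path="pretend/path",  # rewritten for each column batch
)
tokenizer = SentencePieceTokenizerTrainer(vocab_size=10000, config=config)

batcher = DataFrameBatch(df=source_df, config=config, tokenizer=tokenizer)
batcher.create_training_data()   # writes one training file per column batch
batcher.train_all_batches()      # trains one model per batch

batcher.generate_all_batch_lines(num_lines=100, max_invalid=5000)
synthetic_df = batcher.batches_to_df()  # reassemble columns into one frame
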