Code Example #1
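The snippets below are excerpted from test and integration modules, so their imports, pytest fixtures (test_data, train_df, tmp_path), and module-level constants (config_template, PATH_HOLDER, _tok_gen_count, ORIG_HEADERS, simple_validator) are defined elsewhere in the original files. A plausible shared import block, assuming a recent gretel-synthetics layout (the exact module paths are an assumption), would be:

import json
import random
from pathlib import Path
from unittest.mock import Mock, patch

import pandas as pd
import pytest

from gretel_synthetics.batch import MAX_INVALID, DataFrameBatch
from gretel_synthetics.config import TensorFlowConfig
from gretel_synthetics.generate import GenText, generate_text
from gretel_synthetics.tokenizers import (
    BaseTokenizerTrainer,
    CharTokenizerTrainer,
    SentencePieceTokenizerTrainer,
)

This first test verifies that generate_all_batch_lines forwards its arguments to the per-batch generate_batch_lines call: with no arguments the library defaults apply (max_invalid=MAX_INVALID, raise_on_exceed_invalid=False, num_lines=None), and explicit arguments pass straight through, with raise_on_failed_batch surfacing as raise_on_exceed_invalid.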
def test_generate_all_batch_lines_raise_on_failed(test_data):
    batches = DataFrameBatch(df=test_data, config=config_template)
    batches.create_training_data()

    # With no arguments, the library defaults should be forwarded
    batches.generate_batch_lines = Mock()
    batches.generate_all_batch_lines()
    _, args, kwargs = batches.generate_batch_lines.mock_calls[0]
    assert args == (0,)
    assert kwargs == {
        "max_invalid": MAX_INVALID,
        "raise_on_exceed_invalid": False,
        "num_lines": None,
        "parallelism": 0,
        "seed_fields": None,
    }

    # Explicit arguments pass through; raise_on_failed_batch maps to
    # raise_on_exceed_invalid on the per-batch call
    batches.generate_batch_lines = Mock()
    batches.generate_all_batch_lines(
        max_invalid=10, raise_on_failed_batch=True, num_lines=5
    )
    _, args, kwargs = batches.generate_batch_lines.mock_calls[0]
    assert args == (0,)
    assert kwargs == {
        "max_invalid": 10,
        "raise_on_exceed_invalid": True,
        "num_lines": 5,
        "parallelism": 0,
        "seed_fields": None,
    }
Code Example #2
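End-to-end batch training with a SentencePiece tokenizer: the test builds a TensorFlowConfig and SentencePieceTokenizerTrainer, trains every batch, then generates rows two ways, via generate_all_batch_lines and via a RecordFactory, asserting on row counts, column order, and the factory's valid_count summary. _tok_gen_count is a module-level constant giving the number of rows to generate.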
def test_train_batch_sp_tok(train_df, tmp_path):
    config = TensorFlowConfig(
        epochs=5,
        field_delimiter=",",
        checkpoint_dir=tmp_path,
        input_data_path=PATH_HOLDER,
        learning_rate=0.01
    )
    tokenizer = SentencePieceTokenizerTrainer(
        vocab_size=10000,
        config=config
    )
    batcher = DataFrameBatch(
        df=train_df,
        config=config,
        tokenizer=tokenizer
    )
    batcher.create_training_data()
    batcher.train_all_batches()

    batcher.generate_all_batch_lines(num_lines=_tok_gen_count, max_invalid=5000)
    syn_df = batcher.batches_to_df()
    assert syn_df.shape[0] == _tok_gen_count

    # Generate with a RecordFactory
    factory = batcher.create_record_factory(num_lines=_tok_gen_count, max_invalid=5000)
    syn_df = factory.generate_all(output="df")
    assert syn_df.shape[0] == _tok_gen_count
    assert list(syn_df.columns) == list(train_df.columns)
    assert factory.summary["valid_count"] == _tok_gen_count
Code Example #3
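A condensed variant of the previous test, likely from a different revision of the same suite: identical setup, but it stops after generate_all_batch_lines and only checks that 100 rows came back.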
def test_train_batch_sp_tok(train_df, tmp_path):
    config = TensorFlowConfig(epochs=5,
                              field_delimiter=",",
                              checkpoint_dir=tmp_path,
                              input_data_path=PATH_HOLDER,
                              learning_rate=0.01)
    tokenizer = SentencePieceTokenizerTrainer(vocab_size=10000, config=config)
    batcher = DataFrameBatch(df=train_df, config=config, tokenizer=tokenizer)
    batcher.create_training_data()
    batcher.train_all_batches()

    batcher.generate_all_batch_lines(num_lines=100, max_invalid=5000)
    syn_df = batcher.batches_to_df()
    assert syn_df.shape[0] == 100
Code Example #4
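A wrapper method, apparently from a larger synthesizer or benchmarking class (note the self parameter): it maps the wrapper's own hyperparameters into the dict-style config accepted by DataFrameBatch, defaulting epochs to three times the column count (the value recommended by Gretel, per the inline comment), then runs the full train-and-generate cycle and returns the synthetic DataFrame.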
def _fit_sample(self, data, metadata):
    config = {
        'max_lines': self.max_lines,
        'max_line_len': self.max_line_len,
        'epochs': self.epochs or data.shape[1] * 3,  # value recommended by Gretel
        'vocab_size': self.vocab_size,
        'gen_lines': self.gen_lines or data.shape[0],
        'dp': self.dp,
        'field_delimiter': self.field_delimiter,
        'overwrite': self.overwrite,
        'checkpoint_dir': self.checkpoint_dir
    }
    batcher = DataFrameBatch(df=data, config=config)
    batcher.create_training_data()
    batcher.train_all_batches()
    batcher.generate_all_batch_lines()
    synth_data = batcher.batches_to_df()
    return synth_data
Code Example #5
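The same training flow with an explicit CharTokenizerTrainer; after training, the test reads the tokenizer settings file written under batch_0 and asserts that the recorded tokenizer type matches, then generates 100 rows.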
def test_train_batch_char_tok(train_df, tmp_path):
    config = TensorFlowConfig(epochs=5,
                              field_delimiter=",",
                              checkpoint_dir=tmp_path,
                              input_data_path=PATH_HOLDER,
                              learning_rate=0.01)
    batcher = DataFrameBatch(df=train_df,
                             config=config,
                             tokenizer=CharTokenizerTrainer(config=config))
    batcher.create_training_data()
    batcher.train_all_batches()

    tok_params = json.loads(
        (tmp_path / "batch_0" / BaseTokenizerTrainer.settings_fname).read_text()
    )
    assert tok_params["tokenizer_type"] == CharTokenizerTrainer.__name__

    batcher.generate_all_batch_lines(num_lines=100, max_invalid=5000)
    syn_df = batcher.batches_to_df()
    assert syn_df.shape[0] == 100
Code Example #6
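A generation helper that supports two model layouts. For batch-mode model directories it restores a DataFrameBatch in read mode (seeds must be dicts) and writes the output to CSV; for a simple single-model archive it loads the config from the lone subdirectory and streams valid lines from generate_text (seeds must be strings) into a text file. is_model_dir_batch_mode and config_from_model_dir are helpers from the surrounding module.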
def _generate(
    self, model_dir: Path, count: int, file_name: str, seed, validator
) -> str:
    batch_mode = is_model_dir_batch_mode(model_dir)
    if batch_mode:
        if seed is not None and not isinstance(seed, dict):
            raise TypeError("Seed must be a dict in batch mode")
        out_fname = f"{file_name}.csv"
        # Restore previously trained batches from the checkpoint directory
        batcher = DataFrameBatch(mode="read", checkpoint_dir=str(model_dir))
        batcher.generate_all_batch_lines(
            num_lines=count,
            max_invalid=max(count, MAX_INVALID),
            parallelism=1,
            seed_fields=seed
        )
        out_df = batcher.batches_to_df()
        out_df.to_csv(out_fname, index=False)
        return out_fname
    else:
        out = []
        # The model data will be inside of a single directory when a simple model is used. If it
        # was archived correctly, there should only be a single directory inside the archive
        actual_dir = next(model_dir.glob("*"))
        config = config_from_model_dir(actual_dir)
        if seed is not None and not isinstance(seed, str):
            raise TypeError("seed must be a string")
        for data in generate_text(
            config,
            num_lines=count,
            line_validator=validator,
            max_invalid=max(count, MAX_INVALID),
            parallelism=1,
            start_string=seed
        ):
            if data.valid or data.valid is None:
                out.append(data.text)
        out_fname = file_name + ".txt"
        with open(out_fname, "w") as fout:
            for line in out:
                fout.write(line + "\n")
        return out_fname
Code Example #7
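A broad constructor test covering most of the DataFrameBatch surface: rejecting non-DataFrame input, splitting the source columns into 15-column batches (14 in total), per-batch checkpoint directories, the persisted original header list, mocked training and generation, per-batch validators (including reloading one from disk), and assembling a batch's synthetic_df. ORIG_HEADERS names the saved header file and simple_validator is a test helper, both defined elsewhere.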
def test_init(test_data):
    with pytest.raises(ValueError):
        DataFrameBatch(df="nope", config=config_template)

    # should create the dir structure based on the
    # requested 15-column batch size
    batches = DataFrameBatch(df=test_data, config=config_template, batch_size=15)
    first_row = [
        "ID_code",
        "target",
        "var_0",
        "var_1",
        "var_2",
        "var_3",
        "var_4",
        "var_5",
        "var_6",
        "var_7",
        "var_8",
        "var_9",
        "var_10",
        "var_11",
        "var_12",
    ]
    assert batches.batches[0].headers == first_row
    assert len(batches.batches.keys()) == 14
    for i, batch in batches.batches.items():
        assert Path(batch.checkpoint_dir).is_dir()
        assert Path(batch.checkpoint_dir).name == f"batch_{i}"

    orig_headers = json.loads(
        (Path(config_template["checkpoint_dir"]) / ORIG_HEADERS).read_text()
    )
    assert set(orig_headers) == set(test_data.columns)

    batches.create_training_data()
    df = pd.read_csv(
        batches.batches[0].input_data_path, sep=config_template["field_delimiter"]
    )
    assert len(df.columns) == len(first_row)

    with pytest.raises(ValueError):
        batches.train_batch(99)

    with patch("gretel_synthetics.batch.train") as mock_train:
        batches.train_batch(5)
        arg = batches.batches[5].config
        mock_train.assert_called_with(arg, None)

    with patch("gretel_synthetics.batch.train") as mock_train:
        batches.train_all_batches()
        args = [b.config for b in batches.batches.values()]
        called_args = []
        for _, a, _ in mock_train.mock_calls:
            called_args.append(a[0])
        assert args == called_args

    with pytest.raises(ValueError):
        batches.set_batch_validator(5, "foo")

    with pytest.raises(ValueError):
        batches.set_batch_validator(99, simple_validator)

    batches.set_batch_validator(5, simple_validator)
    assert batches.batches[5].validator("1,2,3,4,5")
    # load validator back from disk
    batches.batches[5].load_validator_from_file()
    assert batches.batches[5].validator("1,2,3,4,5")

    # generate lines, simulating generation of the max
    # valid line count
    def good():
        return GenText(
            text="1,2,3,4,5", valid=random.choice([None, True]), delimiter=","
        )

    def bad():
        return GenText(text="1,2,3", valid=False, delimiter=",")

    with patch("gretel_synthetics.batch.generate_text") as mock_gen:
        mock_gen.return_value = [good(), good(), good(), bad(), bad(), good(), good()]
        summary = batches.generate_batch_lines(5, max_invalid=1)
        assert summary.is_valid
        check_call = mock_gen.mock_calls[0]
        _, _, kwargs = check_call
        assert kwargs["max_invalid"] == 1

    with patch("gretel_synthetics.batch.generate_text") as mock_gen:
        mock_gen.return_value = [good(), good(), good(), bad(), bad(), good(), good()]
        summary = batches.generate_batch_lines(5)
        assert summary.is_valid

    with patch("gretel_synthetics.batch.generate_text") as mock_gen:
        mock_gen.return_value = [good(), good(), good(), bad(), bad(), good()]
        summary = batches.generate_batch_lines(5)
        assert not summary.is_valid

    with patch.object(batches, "generate_batch_lines") as mock_gen:
        batches.generate_all_batch_lines(max_invalid=15)
        assert mock_gen.call_count == len(batches.batches.keys())
        check_call = mock_gen.mock_calls[0]
        _, _, kwargs = check_call
        assert kwargs["max_invalid"] == 15

    # get synthetic df
    line = GenText(
        text="1,2,3,4,5,6,7,8,9,10,11,12,13,14,15", valid=True, delimiter=","
    )
    with patch("gretel_synthetics.batch.generate_text") as mock_gen:
        mock_gen.return_value = [line] * len(batches.batches[10].headers)
        batches.generate_batch_lines(10)

    assert len(batches.batches[10].synthetic_df) == len(batches.batches[10].headers)
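Taken together, Examples #2 through #5 follow a single workflow. Below is a minimal end-to-end sketch of that pattern; the file names, epoch count, and input DataFrame are illustrative placeholders, not values from the sources above:

import pandas as pd

from gretel_synthetics.batch import DataFrameBatch
from gretel_synthetics.config import TensorFlowConfig

train_df = pd.read_csv("train.csv")  # hypothetical input data

config = TensorFlowConfig(
    epochs=5,
    field_delimiter=",",
    checkpoint_dir="checkpoints",    # hypothetical output directory
    input_data_path="placeholder",   # assumption: DataFrameBatch sets the real path per batch
)

batcher = DataFrameBatch(df=train_df, config=config)
batcher.create_training_data()   # split columns into batches, write training CSVs
batcher.train_all_batches()      # train one model per batch
batcher.generate_all_batch_lines(num_lines=100, max_invalid=5000)
synthetic_df = batcher.batches_to_df()  # stitch generated lines back into a DataFrame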