Code example #1
def test_record_factory_multi_batch(hr_model_dir):
    batcher = DataFrameBatch(mode="read", checkpoint_dir=hr_model_dir)

    factory = batcher.create_record_factory(num_lines=50, max_invalid=5000)

    df = factory.generate_all(output="df")
    assert len(df) == 50
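
Note: the snippets in this collection are excerpts from a pytest suite around gretel-synthetics' DataFrameBatch, and they omit the suite's shared module-level setup. The sketch below reconstructs the scaffolding they appear to assume. The gretel_synthetics.batch module path is confirmed by the patch targets used in later examples; every other import path, fixture body, and constant value here is an assumption, not the project's actual conftest.

# Hedged sketch of the shared test scaffolding assumed by these excerpts.
import json
import os
import random
from copy import deepcopy
from dataclasses import asdict
from pathlib import Path
from unittest.mock import Mock, patch

import pandas as pd
import pytest

from gretel_synthetics.batch import DataFrameBatch  # module path confirmed by the patch targets below
# Assumed homes for the remaining names the excerpts reference:
from gretel_synthetics.batch import MAX_INVALID, ORIG_HEADERS, READ
from gretel_synthetics.config import TensorFlowConfig
from gretel_synthetics.generate import GenText
from gretel_synthetics.tokenizers import (
    BaseTokenizerTrainer,
    CharTokenizerTrainer,
    SentencePieceTokenizerTrainer,
)

checkpoint_dir = "test-checkpoints"  # assumed value
config_template = {
    "epochs": 1,  # assumed
    "field_delimiter": "|",  # implied by the "foo|" seed-value assertions
    "checkpoint_dir": checkpoint_dir,
}
_tok_gen_count = 100  # assumed module constant used by the sp_tok test


@pytest.fixture
def test_data() -> pd.DataFrame:
    # A wide frame with the headers the seed tests assert on: "ID_code",
    # "target", then var_0, var_1, ... (the exact width is an assumption,
    # though test_batch_size implies at least 60 columns).
    data = {"ID_code": ["train_0"], "target": [0]}
    data.update({f"var_{i}": [float(i)] for i in range(200)})
    return pd.DataFrame(data)

# train_df, hr_model_dir, and safecast_model_dir are assumed fixtures
# supplying a training frame and pre-trained checkpoint directories, and
# MyValidator, simple_validator, PATH_HOLDER, const, EpochState,
# GenerationProgress, and TooManyInvalidError are assumed to come from the
# suite's other imports.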
Code example #2
def test_generate_all_batch_lines_raise_on_failed(test_data):
    batches = DataFrameBatch(df=test_data, config=config_template)
    batches.create_training_data()

    batches.generate_batch_lines = Mock()
    batches.generate_all_batch_lines()
    _, args, kwargs = batches.generate_batch_lines.mock_calls[0]
    assert args == (0,)
    assert kwargs == {
        "max_invalid": MAX_INVALID,
        "raise_on_exceed_invalid": False,
        "num_lines": None,
        "parallelism": 0,
        "seed_fields": None,
    }

    batches.generate_batch_lines = Mock()
    batches.generate_all_batch_lines(
        max_invalid=10, raise_on_failed_batch=True, num_lines=5
    )
    _, args, kwargs = batches.generate_batch_lines.mock_calls[0]
    assert args == (0,)
    assert kwargs == {
        "max_invalid": 10,
        "raise_on_exceed_invalid": True,
        "num_lines": 5,
        "parallelism": 0,
        "seed_fields": None,
    }
Code example #3
def test_train_batch_sp_tok(train_df, tmp_path):
    config = TensorFlowConfig(
        epochs=5,
        field_delimiter=",",
        checkpoint_dir=tmp_path,
        input_data_path=PATH_HOLDER,
        learning_rate=.01
    )
    tokenizer = SentencePieceTokenizerTrainer(
        vocab_size=10000,
        config=config
    )
    batcher = DataFrameBatch(
        df=train_df,
        config=config,
        tokenizer=tokenizer
    )
    batcher.create_training_data()
    batcher.train_all_batches()

    batcher.generate_all_batch_lines(num_lines=_tok_gen_count, max_invalid=5000)
    syn_df = batcher.batches_to_df()
    assert syn_df.shape[0] == _tok_gen_count

    # Generate with a RecordFactory
    factory = batcher.create_record_factory(num_lines=_tok_gen_count, max_invalid=5000)
    syn_df = factory.generate_all(output="df")
    assert syn_df.shape[0] == _tok_gen_count
    assert list(syn_df.columns) == list(train_df.columns)
    assert factory.summary["valid_count"] == _tok_gen_count
Code example #4
def test_validate_seed_lines_ok_one_field(test_data):
    batches = DataFrameBatch(df=test_data,
                             config=config_template,
                             batch_size=3)

    check = batches._validate_batch_seed_values(batches.batches[0], {
        "ID_code": "foo",
    })
    assert check == "foo|"
Code example #5
def test_record_factory_exc_fail_validator(safecast_model_dir):
    batcher = DataFrameBatch(mode="read", checkpoint_dir=safecast_model_dir)

    def _validator(rec: dict):
        raise ValueError

    factory = batcher.create_record_factory(num_lines=10, validator=_validator, max_invalid=10)
    with pytest.raises(RuntimeError):
        list(factory)
Code example #6
def test_record_factory_exhaust_iter(safecast_model_dir):
    batcher = DataFrameBatch(mode="read", checkpoint_dir=safecast_model_dir)
    factory = batcher.create_record_factory(num_lines=10)
    records = list(factory)
    assert len(records) == 10
    assert factory._counter.valid_count == 10
    summary = factory.summary
    assert summary["num_lines"] == 10
    assert summary["max_invalid"] == 1000
    assert summary["valid_count"] == 10
Code example #7
def test_batch_size(test_data):
    test_data = test_data.iloc[:, :60]
    batches = DataFrameBatch(df=test_data, config=config_template, batch_size=15)
    assert batches.batch_size == 15
    assert [len(x) for x in batches.batch_headers] == [15, 15, 15, 15]

    test_data = test_data.iloc[:, :59]
    batches = DataFrameBatch(df=test_data, config=config_template, batch_size=15)
    assert batches.batch_size == 15
    assert [len(x) for x in batches.batch_headers] == [15, 15, 15, 14]
Code example #8
def test_validate_seed_lines_ok_full_size(test_data):
    batches = DataFrameBatch(df=test_data,
                             config=config_template,
                             batch_size=3)

    check = batches._validate_batch_seed_values(batches.batches[0], {
        "ID_code": "foo",
        "target": 0,
        "var_0": 33,
    })
    assert check == "foo|0|33|"
Code example #9
def test_record_factory_multi_batch_seed_static(hr_model_dir):
    batcher = DataFrameBatch(mode="read", checkpoint_dir=hr_model_dir)

    factory = batcher.create_record_factory(num_lines=10,
                                            max_invalid=5000,
                                            seed_fields={"age": 5},
                                            validator=MyValidator())

    df = factory.generate_all(output="df")
    assert len(df) == 10
    assert df["age"].nunique() == 1
    assert df.iloc[0]["age"] == 5
Code example #10
def test_validate_seed_lines_field_not_present(test_data):
    batches = DataFrameBatch(df=test_data,
                             config=config_template,
                             batch_size=3)

    with pytest.raises(RuntimeError) as err:
        batches._validate_batch_seed_values(batches.batches[0], {
            "ID_code": "foo",
            "target": 0,
            "var_1": 33,
        })
    assert "The header: var_0 is not in the seed" in str(err.value)
Code example #11
def test_record_factory_simple_validator(safecast_model_dir):
    batcher = DataFrameBatch(mode="read", checkpoint_dir=safecast_model_dir)

    def _validator(rec: dict):
        # ensure that returning None and returning True both count as valid
        which = random.randint(0, 1)
        if which:
            assert float(rec["payload.loc_lat"])
        else:
            return True

    factory = batcher.create_record_factory(num_lines=10, validator=_validator)
    assert len(list(factory)) == 10
Code example #12
def test_validate_seed_lines_too_many_fields(test_data):
    batches = DataFrameBatch(df=test_data,
                             config=config_template,
                             batch_size=3)

    with pytest.raises(RuntimeError) as err:
        batches._validate_batch_seed_values(batches.batches[0], {
            "ID_Code": "foo",
            "target": 0,
            "var_0": 33,
            "var_1": 33
        })
    assert "number of seed fields" in str(err.value)
Code example #13
def test_record_factory_multi_batch_seed_list(hr_model_dir):
    batcher = DataFrameBatch(mode="read", checkpoint_dir=hr_model_dir)

    age_seeds = [{"age": i} for i in range(1, 51)]

    factory = batcher.create_record_factory(
        num_lines=50,  # doesn't matter w/ smart seed
        max_invalid=5000,
        seed_fields=age_seeds,
        validator=MyValidator())

    df = factory.generate_all(output="df")
    assert len(df) == 50
    assert df["age"].nunique() == 50
Code example #14
@patch("gretel_synthetics.batch.generate_text")  # assumed: the mock_gen parameter implies this missing decorator
def test_read_mode(mock_gen, test_data):
    writer = DataFrameBatch(df=test_data, config=config_template)
    writer.create_training_data()

    # missing checkpoint dir
    with pytest.raises(ValueError):
        DataFrameBatch(mode=READ)

    # bad checkpoint dir
    with pytest.raises(ValueError):
        DataFrameBatch(mode=READ, checkpoint_dir="bad_dir")

    # NOTE: normally saving the params is done during training,
    # but we do it here manually since we won't actually train
    for _, batch in writer.batches.items():
        batch.config.save_model_params()

    # checkpoint dir exists in config
    DataFrameBatch(config=config_template, mode=READ)

    # checkpoint dir as a kwarg
    reader = DataFrameBatch(checkpoint_dir=checkpoint_dir, mode=READ)

    write_batch = writer.batches[0]
    read_batch = reader.batches[0]

    assert write_batch.checkpoint_dir == read_batch.checkpoint_dir
    assert write_batch.headers == read_batch.headers
    assert asdict(write_batch.config) == asdict(read_batch.config)
    assert reader.master_header_list == writer.master_header_list
Code example #15
def test_batches_to_df(test_data):
    batches = DataFrameBatch(
        df=pd.DataFrame([{"foo": "bar", "foo1": "bar1", "foo2": "bar2", "foo3": 3}]),
        config=config_template,
        batch_size=2,
    )

    batches.batches[0].add_valid_data(
        GenText(text="baz|baz1", valid=True, delimiter="|")
    )
    batches.batches[1].add_valid_data(
        GenText(text="baz2|5", valid=True, delimiter="|")
    )

    check = batches.batches_to_df()
    assert list(check.columns) == ["foo", "foo1", "foo2", "foo3"]
    assert check.shape == (1, 4)
    assert [t.name for t in list(check.dtypes)] == ['object', 'object', 'object', 'int64']
Code example #16
def test_record_factory_generate_all(safecast_model_dir):
    batcher = DataFrameBatch(mode="read", checkpoint_dir=safecast_model_dir)

    def _validator(rec: dict):
        assert float(rec["payload.loc_lat"])

    factory = batcher.create_record_factory(num_lines=10, validator=_validator)
    next(factory)
    next(factory)

    # generate_all should reset our iterator for the full 10 records
    assert len(factory.generate_all()) == 10

    df = factory.generate_all(output="df")
    assert df.shape == (10, 16)
    assert str(df["payload.loc_lat"].dtype) == "float64"
Code example #17
def test_record_factory_smart_seed(safecast_model_dir):
    seeds = [{"payload.service_handler": "i-051a2a353509414f0"},
             {"payload.service_handler": "i-051a2a353509414f1"},
             {"payload.service_handler": "i-051a2a353509414f2"},
             {"payload.service_handler": "i-051a2a353509414f3"}]

    batcher = DataFrameBatch(mode="read", checkpoint_dir=safecast_model_dir)
    factory = batcher.create_record_factory(
        num_lines=1000,
        seed_fields=seeds*1000
    )

    # list of seeds should reset num_lines
    assert factory.num_lines == len(seeds) * 1000

    for seed, record in zip(seeds, factory):
        assert seed["payload.service_handler"] == record["payload.service_handler"]
Code example #18
def test_train_batch_sp_tok(train_df, tmp_path):
    config = TensorFlowConfig(epochs=5,
                              field_delimiter=",",
                              checkpoint_dir=tmp_path,
                              input_data_path=PATH_HOLDER,
                              learning_rate=.01)
    tokenizer = SentencePieceTokenizerTrainer(vocab_size=10000, config=config)
    batcher = DataFrameBatch(df=train_df, config=config, tokenizer=tokenizer)
    batcher.create_training_data()
    batcher.train_all_batches()

    batcher.generate_all_batch_lines(num_lines=100, max_invalid=5000)
    syn_df = batcher.batches_to_df()
    assert syn_df.shape[0] == 100
Code example #19
def test_record_factory_smart_seed_buffer(safecast_model_dir):
    seeds = [{"payload.service_handler": "i-051a2a353509414f0"},
             {"payload.service_handler": "i-051a2a353509414f1"},
             {"payload.service_handler": "i-051a2a353509414f2"},
             {"payload.service_handler": "i-051a2a353509414f3"}] * 2

    batcher = DataFrameBatch(mode="read", checkpoint_dir=safecast_model_dir)

    factory = batcher.create_record_factory(
        num_lines=100,  # doesn't matter w/ smart seed
        seed_fields=seeds,
        validator=MyValidator(),
        max_invalid=5000
    )

    df = factory.generate_all(output="df")
    assert len(df) == 8
    assert factory.summary["num_lines"] == 8
    assert factory.summary["valid_count"] == 8
Code example #20
def _generate(
    self, model_dir: Path, count: int, file_name: str, seed, validator
) -> str:
    batch_mode = is_model_dir_batch_mode(model_dir)
    if batch_mode:
        if seed is not None and not isinstance(seed, dict):
            raise TypeError("Seed must be a dict in batch mode")
        out_fname = f"{file_name}.csv"
        batcher = DataFrameBatch(mode="read", checkpoint_dir=str(model_dir))
        batcher.generate_all_batch_lines(
            num_lines=count,
            max_invalid=max(count, MAX_INVALID),
            parallelism=1,
            seed_fields=seed
        )
        out_df = batcher.batches_to_df()
        out_df.to_csv(out_fname, index=False)
        return out_fname
    else:
        out = []
        # The model data will be inside a single directory when a simple
        # model is used; if it was archived correctly, there should be only
        # a single directory inside the archive.
        actual_dir = next(model_dir.glob("*"))
        config = config_from_model_dir(actual_dir)
        if seed is not None and not isinstance(seed, str):
            raise TypeError("seed must be a string")
        for data in generate_text(
            config,
            num_lines=count,
            line_validator=validator,
            max_invalid=max(count, MAX_INVALID),
            parallelism=1,
            start_string=seed
        ):
            if data.valid or data.valid is None:
                out.append(data.text)
        out_fname = file_name + ".txt"
        with open(out_fname, "w") as fout:
            for line in out:
                fout.write(line + "\n")
        return out_fname
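
_generate above relies on an is_model_dir_batch_mode helper that the excerpt never defines. Example #29 below shows DataFrameBatch writing per-batch checkpoints into batch_{i} subdirectories, so a plausible sketch of the helper (an assumption, not the original code) is:

from pathlib import Path

def is_model_dir_batch_mode(model_dir: Path) -> bool:
    # DataFrameBatch lays checkpoints out as batch_0/, batch_1/, ... (see
    # the f"batch_{i}" assertion in example #29), so the presence of any
    # such subdirectory marks a batch-mode model archive.
    return any(p.is_dir() and p.name.startswith("batch_") for p in model_dir.iterdir())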
Code example #21
def test_epoch_callback(train_df, tmp_path):
    def epoch_callback(s: EpochState):
        with open(tmp_path / 'callback_dump.txt', 'a') as f:
            f.write(f'{s.epoch},{s.accuracy},{s.loss},{s.batch}\n')

    config = TensorFlowConfig(epochs=5,
                              field_delimiter=",",
                              checkpoint_dir=tmp_path,
                              input_data_path=PATH_HOLDER,
                              learning_rate=.01,
                              epoch_callback=epoch_callback)
    tokenizer = SentencePieceTokenizerTrainer(vocab_size=10000, config=config)
    batcher = DataFrameBatch(batch_size=4,
                             df=train_df,
                             config=config,
                             tokenizer=tokenizer)
    batcher.create_training_data()
    batcher.train_all_batches()
    with open(tmp_path / 'callback_dump.txt', 'r') as f:
        lines = f.readlines()
        assert len(lines) == 20
        for i, line in enumerate(lines):
            fields = line.strip().split(',')
            assert len(fields) == 4
            assert int(fields[0]) == i % 5
            assert int(fields[3]) == i // 5
            float(fields[1])
            float(fields[2])
    os.remove(tmp_path / 'callback_dump.txt')
Code example #22
def _fit_sample(self, data, metadata):
    config = {
        'max_lines': self.max_lines,
        'max_line_len': self.max_line_len,
        'epochs': self.epochs or data.shape[1] * 3,  # value recommended by Gretel
        'vocab_size': self.vocab_size,
        'gen_lines': self.gen_lines or data.shape[0],
        'dp': self.dp,
        'field_delimiter': self.field_delimiter,
        'overwrite': self.overwrite,
        'checkpoint_dir': self.checkpoint_dir
    }
    batcher = DataFrameBatch(df=data, config=config)
    batcher.create_training_data()
    batcher.train_all_batches()
    batcher.generate_all_batch_lines()
    synth_data = batcher.batches_to_df()
    return synth_data
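
Since _fit_sample reads a number of self.* attributes, it presumably lives on a wrapper class (SDV's benchmark baselines have this shape). A hedged sketch of the enclosing class, with the constructor inferred solely from the attributes the method touches and with defaults that are pure assumptions:

class GretelBaseline:  # hypothetical name
    def __init__(self, max_lines=0, max_line_len=2048, epochs=None,
                 vocab_size=20000, gen_lines=None, dp=False,
                 field_delimiter=",", overwrite=True,
                 checkpoint_dir="checkpoints"):
        self.max_lines = max_lines
        self.max_line_len = max_line_len
        self.epochs = epochs        # None -> 3x the column count, per the comment above
        self.vocab_size = vocab_size
        self.gen_lines = gen_lines  # None -> match the input row count
        self.dp = dp                # differentially-private training flag
        self.field_delimiter = field_delimiter
        self.overwrite = overwrite
        self.checkpoint_dir = checkpoint_dir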
Code example #23
@pytest.mark.parametrize("threading", [True, False])  # assumed: the threading parameter implies a parametrize decorator
def test_record_factory_generate_all_with_callback(safecast_model_dir, threading):
    batcher = DataFrameBatch(mode="read", checkpoint_dir=safecast_model_dir)

    def _validator(rec: dict):
        assert float(rec["payload.loc_lat"])

    factory = batcher.create_record_factory(num_lines=1000,
                                            validator=_validator,
                                            invalid_cache_size=5)

    callback_fn = Mock()

    df = factory.generate_all(output="df",
                              callback=callback_fn,
                              callback_interval=1,
                              callback_threading=threading)
    assert df.shape == (1000, 16)

    # assuming we get at least 5 bad records
    assert len(factory.invalid_cache) == 5
    assert str(df["payload.loc_lat"].dtype) == "float64"

    # at least one call during generation and another with the final update
    assert callback_fn.call_count >= 2
    assert callback_fn.call_count < 1000, "Progress updates should only be sent periodically"

    args, _ = callback_fn.call_args  # pylint: disable=unpacking-non-sequence
    last_update: GenerationProgress = args[0]
    assert last_update.current_valid_count == 1000
    assert last_update.completion_percent == 100

    # calculate sum from all updates
    valid_total_count = 0
    for call_args in callback_fn.call_args_list:
        args, _ = call_args
        valid_total_count += args[0].new_valid_count

    assert valid_total_count == 1000
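
The test above exercises the callback with a Mock; a real callback only needs the GenerationProgress fields the assertions read (current_valid_count, completion_percent, new_valid_count). A minimal non-mock sketch:

def log_progress(update) -> None:
    # `update` is the GenerationProgress instance the factory passes to the
    # callback at each progress interval
    print(
        f"{update.completion_percent:.0f}% complete: "
        f"{update.current_valid_count} valid records "
        f"(+{update.new_valid_count} since the last update)"
    )

# usage: factory.generate_all(output="df", callback=log_progress, callback_interval=10)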
Code example #24
def test_train_batch_char_tok(train_df, tmp_path):
    config = TensorFlowConfig(epochs=5,
                              field_delimiter=",",
                              checkpoint_dir=tmp_path,
                              input_data_path=PATH_HOLDER,
                              learning_rate=.01)
    batcher = DataFrameBatch(df=train_df,
                             config=config,
                             tokenizer=CharTokenizerTrainer(config=config))
    batcher.create_training_data()
    batcher.train_all_batches()

    tok_params = json.loads(
        open(tmp_path / "batch_0" / BaseTokenizerTrainer.settings_fname).read()
    )
    assert tok_params["tokenizer_type"] == CharTokenizerTrainer.__name__

    batcher.generate_all_batch_lines(num_lines=100, max_invalid=5000)
    syn_df = batcher.batches_to_df()
    assert syn_df.shape[0] == 100
Code example #25
def test_train_small_df(train_df, tmp_path):
    small_df = train_df.sample(n=50)
    config = TensorFlowConfig(epochs=5,
                              field_delimiter=",",
                              checkpoint_dir=tmp_path,
                              input_data_path=PATH_HOLDER)
    batcher = DataFrameBatch(df=small_df, config=config)
    batcher.create_training_data()
    with pytest.raises(RuntimeError) as excinfo:
        batcher.train_all_batches()
    assert "Model training failed" in str(excinfo.value)
Code example #26
def test_generate_batch_lines_raise_on_exceed(test_data):
    batches = DataFrameBatch(df=test_data, config=config_template)
    batches.create_training_data()

    with patch("gretel_synthetics.batch.generate_text") as mock_gen:
        mock_gen.side_effect = TooManyInvalidError()
        assert not batches.generate_batch_lines(0)

    with patch("gretel_synthetics.batch.generate_text") as mock_gen:
        mock_gen.side_effect = TooManyInvalidError()
        with pytest.raises(TooManyInvalidError):
            assert not batches.generate_batch_lines(0, raise_on_exceed_invalid=True)
Code example #27
def test_train_batch_sp(train_df, tmp_path):
    config = TensorFlowConfig(epochs=1,
                              field_delimiter=",",
                              checkpoint_dir=tmp_path,
                              input_data_path=PATH_HOLDER)
    batcher = DataFrameBatch(df=train_df, config=config)
    batcher.create_training_data()
    batcher.train_all_batches()

    model_params = json.loads(
        open(tmp_path / "batch_0" / const.MODEL_PARAMS).read())
    assert model_params[const.MODEL_TYPE] == TensorFlowConfig.__name__

    tok_params = json.loads(
        open(tmp_path / "batch_0" / BaseTokenizerTrainer.settings_fname).read()
    )
    assert tok_params["tokenizer_type"] == SentencePieceTokenizerTrainer.__name__
Code example #28
def test_train_batch_sp_regression(train_df, tmp_path):
    """Batch mode with default SentencePiece tokenizer. Using the backwards
    compat mode for <= 0.14.0.
    """
    config = {"epochs": 1, "field_delimiter": ",", "checkpoint_dir": tmp_path}
    batcher = DataFrameBatch(df=train_df, config=config)
    batcher.create_training_data()
    batcher.train_all_batches()

    model_params = json.loads(
        open(tmp_path / "batch_0" / const.MODEL_PARAMS).read())
    assert model_params[const.MODEL_TYPE] == TensorFlowConfig.__name__

    tok_params = json.loads(
        open(tmp_path / "batch_0" / BaseTokenizerTrainer.settings_fname).read()
    )
    assert tok_params["tokenizer_type"] == SentencePieceTokenizerTrainer.__name__
Code example #29
def test_init(test_data):
    with pytest.raises(ValueError):
        DataFrameBatch(df="nope", config=config_template)

    # should create the dir structure based on auto
    # batch sizing
    batches = DataFrameBatch(df=test_data, config=config_template, batch_size=15)
    first_row = [
        "ID_code",
        "target",
        "var_0",
        "var_1",
        "var_2",
        "var_3",
        "var_4",
        "var_5",
        "var_6",
        "var_7",
        "var_8",
        "var_9",
        "var_10",
        "var_11",
        "var_12",
    ]
    assert batches.batches[0].headers == first_row
    assert len(batches.batches.keys()) == 14
    for i, batch in batches.batches.items():
        assert Path(batch.checkpoint_dir).is_dir()
        assert Path(batch.checkpoint_dir).name == f"batch_{i}"

    orig_headers = json.loads(
        open(Path(config_template["checkpoint_dir"]) / ORIG_HEADERS).read()
    )
    assert list(set(orig_headers)) == list(set(test_data.columns))

    batches.create_training_data()
    df = pd.read_csv(
        batches.batches[0].input_data_path, sep=config_template["field_delimiter"]
    )
    assert len(df.columns) == len(first_row)

    with pytest.raises(ValueError):
        batches.train_batch(99)

    with patch("gretel_synthetics.batch.train") as mock_train:
        batches.train_batch(5)
        arg = batches.batches[5].config
        mock_train.assert_called_with(arg, None)

    with patch("gretel_synthetics.batch.train") as mock_train:
        batches.train_all_batches()
        args = [b.config for b in batches.batches.values()]
        called_args = []
        for _, a, _ in mock_train.mock_calls:
            called_args.append(a[0])
        assert args == called_args

    with pytest.raises(ValueError):
        batches.set_batch_validator(5, "foo")

    with pytest.raises(ValueError):
        batches.set_batch_validator(99, simple_validator)

    batches.set_batch_validator(5, simple_validator)
    assert batches.batches[5].validator("1,2,3,4,5")
    # load validator back from disk
    batches.batches[5].load_validator_from_file()
    assert batches.batches[5].validator("1,2,3,4,5")

    # generate lines, simulating generation the max
    # valid line count
    def good():
        return GenText(
            text="1,2,3,4,5", valid=random.choice([None, True]), delimiter=","
        )

    def bad():
        return GenText(text="1,2,3", valid=False, delimiter=",")

    with patch("gretel_synthetics.batch.generate_text") as mock_gen:
        mock_gen.return_value = [good(), good(), good(), bad(), bad(), good(), good()]
        summary = batches.generate_batch_lines(5, max_invalid=1)
        assert summary.is_valid
        check_call = mock_gen.mock_calls[0]
        _, _, kwargs = check_call
        assert kwargs["max_invalid"] == 1

    with patch("gretel_synthetics.batch.generate_text") as mock_gen:
        mock_gen.return_value = [good(), good(), good(), bad(), bad(), good(), good()]
        summary = batches.generate_batch_lines(5)
        assert summary.is_valid

    with patch("gretel_synthetics.batch.generate_text") as mock_gen:
        mock_gen.return_value = [good(), good(), good(), bad(), bad(), good()]
        summary = batches.generate_batch_lines(5)
        assert not summary.is_valid

    with patch.object(batches, "generate_batch_lines") as mock_gen:
        batches.generate_all_batch_lines(max_invalid=15)
        assert mock_gen.call_count == len(batches.batches.keys())
        check_call = mock_gen.mock_calls[0]
        _, _, kwargs = check_call
        assert kwargs["max_invalid"] == 15

    # get synthetic df
    line = GenText(
        text="1,2,3,4,5,6,7,8,9,10,11,12,13,14,15", valid=True, delimiter=","
    )
    with patch("gretel_synthetics.batch.generate_text") as mock_gen:
        mock_gen.return_value = [line] * len(batches.batches[10].headers)
        batches.generate_batch_lines(10)

    assert len(batches.batches[10].synthetic_df) == len(batches.batches[10].headers)
Code example #30
def test_missing_delim(test_data):
    config = deepcopy(config_template)
    config.pop("field_delimiter")
    with pytest.raises(ValueError):
        DataFrameBatch(df=test_data, config=config)