Example #1
0
def test_record_factory_multi_batch(hr_model_dir):
    """A factory from a read-mode batcher yields exactly the requested row count."""
    batcher = DataFrameBatch(mode="read", checkpoint_dir=hr_model_dir)
    record_factory = batcher.create_record_factory(num_lines=50, max_invalid=5000)

    generated = record_factory.generate_all(output="df")
    assert len(generated) == 50
def test_train_batch_sp_tok(train_df, tmp_path):
    """End-to-end: train all batches with a SentencePiece tokenizer, then
    generate rows both via the batch API and via a RecordFactory and check
    that both honor the requested line count and column layout."""
    config = TensorFlowConfig(
        epochs=5,
        field_delimiter=",",
        checkpoint_dir=tmp_path,
        input_data_path=PATH_HOLDER,
        learning_rate=0.01,
    )
    sp_trainer = SentencePieceTokenizerTrainer(vocab_size=10000, config=config)
    batcher = DataFrameBatch(df=train_df, config=config, tokenizer=sp_trainer)
    batcher.create_training_data()
    batcher.train_all_batches()

    # Direct batch-line generation
    batcher.generate_all_batch_lines(num_lines=_tok_gen_count, max_invalid=5000)
    direct_df = batcher.batches_to_df()
    assert direct_df.shape[0] == _tok_gen_count

    # Generate with a RecordFactory
    factory = batcher.create_record_factory(num_lines=_tok_gen_count, max_invalid=5000)
    factory_df = factory.generate_all(output="df")
    assert factory_df.shape[0] == _tok_gen_count
    assert list(factory_df.columns) == list(train_df.columns)
    assert factory.summary["valid_count"] == _tok_gen_count
Example #3
0
def test_record_factory_exc_fail_validator(safecast_model_dir):
    """A validator that rejects every record should exhaust ``max_invalid``
    and surface a ``RuntimeError`` when the factory is drained."""
    batcher = DataFrameBatch(mode="read", checkpoint_dir=safecast_model_dir)

    def _always_fail(rec: dict):
        # Reject every record by raising
        raise ValueError

    factory = batcher.create_record_factory(
        num_lines=10, validator=_always_fail, max_invalid=10
    )
    with pytest.raises(RuntimeError):
        list(factory)
Example #4
0
def test_record_factory_exhaust_iter(safecast_model_dir):
    """Draining the factory iterator yields ``num_lines`` records and leaves
    a consistent counter and summary behind."""
    batcher = DataFrameBatch(mode="read", checkpoint_dir=safecast_model_dir)
    factory = batcher.create_record_factory(num_lines=10)

    produced = [record for record in factory]
    assert len(produced) == 10
    assert factory._counter.valid_count == 10

    stats = factory.summary
    assert stats["num_lines"] == 10
    assert stats["max_invalid"] == 1000
    assert stats["valid_count"] == 10
Example #5
0
def test_record_factory_multi_batch_seed_static(hr_model_dir):
    """A single static seed dict pins the seeded column to one value in every row."""
    batcher = DataFrameBatch(mode="read", checkpoint_dir=hr_model_dir)

    factory = batcher.create_record_factory(
        num_lines=10,
        max_invalid=5000,
        seed_fields={"age": 5},
        validator=MyValidator(),
    )

    out = factory.generate_all(output="df")
    assert len(out) == 10
    assert out["age"].nunique() == 1
    assert out.iloc[0]["age"] == 5
Example #6
0
def test_record_factory_simple_validator(safecast_model_dir):
    """Validators may signal success either by returning ``True`` or by
    returning ``None`` (i.e. falling through without raising)."""
    batcher = DataFrameBatch(mode="read", checkpoint_dir=safecast_model_dir)

    def _validator(rec: dict):
        # ensure returning None and True works
        if random.randint(0, 1):
            assert float(rec["payload.loc_lat"])
        else:
            return True

    factory = batcher.create_record_factory(num_lines=10, validator=_validator)
    assert len(list(factory)) == 10
Example #7
0
def test_record_factory_multi_batch_seed_list(hr_model_dir):
    """A list of seed dicts produces one record per seed, each carrying its
    own distinct seeded value."""
    batcher = DataFrameBatch(mode="read", checkpoint_dir=hr_model_dir)

    seed_list = [{"age": value} for value in range(1, 51)]

    factory = batcher.create_record_factory(
        num_lines=50,  # doesn't matter w/ smart seed
        max_invalid=5000,
        seed_fields=seed_list,
        validator=MyValidator(),
    )

    result_df = factory.generate_all(output="df")
    assert len(result_df) == 50
    assert result_df["age"].nunique() == 50
Example #8
0
def test_record_factory_generate_all(safecast_model_dir):
    """``generate_all`` restarts generation from scratch even after the
    iterator has been partially consumed via ``next``."""
    batcher = DataFrameBatch(mode="read", checkpoint_dir=safecast_model_dir)

    def _validator(rec: dict):
        assert float(rec["payload.loc_lat"])

    factory = batcher.create_record_factory(num_lines=10, validator=_validator)

    # Partially consume the iterator first
    for _ in range(2):
        next(factory)

    # generate_all should reset our iterator for the full 10 records
    assert len(factory.generate_all()) == 10

    frame = factory.generate_all(output="df")
    assert frame.shape == (10, 16)
    assert str(frame["payload.loc_lat"].dtype) == "float64"
Example #9
0
def test_record_factory_smart_seed(safecast_model_dir):
    """When a list of seeds is supplied, ``num_lines`` is overridden by the
    seed count and generated records line up with their seeds in order."""
    handler_ids = [
        "i-051a2a353509414f0",
        "i-051a2a353509414f1",
        "i-051a2a353509414f2",
        "i-051a2a353509414f3",
    ]
    seeds = [{"payload.service_handler": hid} for hid in handler_ids]

    batcher = DataFrameBatch(mode="read", checkpoint_dir=safecast_model_dir)
    factory = batcher.create_record_factory(
        num_lines=1000,
        seed_fields=seeds * 1000,
    )

    # list of seeds should reset num_lines
    assert factory.num_lines == len(seeds) * 1000

    for seed, record in zip(seeds, factory):
        assert seed["payload.service_handler"] == record["payload.service_handler"]
Example #10
0
def test_record_factory_smart_seed_buffer(safecast_model_dir):
    """With smart seeds, the output size equals the seed count (8 here),
    regardless of the ``num_lines`` argument."""
    base_seeds = [
        {"payload.service_handler": "i-051a2a353509414f0"},
        {"payload.service_handler": "i-051a2a353509414f1"},
        {"payload.service_handler": "i-051a2a353509414f2"},
        {"payload.service_handler": "i-051a2a353509414f3"},
    ]
    seeds = base_seeds * 2

    batcher = DataFrameBatch(mode="read", checkpoint_dir=safecast_model_dir)

    factory = batcher.create_record_factory(
        num_lines=100,  # doesn't matter w/ smart seed
        seed_fields=seeds,
        validator=MyValidator(),
        max_invalid=5000,
    )

    frame = factory.generate_all(output="df")
    assert len(frame) == 8
    assert factory.summary["num_lines"] == 8
    assert factory.summary["valid_count"] == 8
Example #11
0
def test_record_factory_generate_all_with_callback(safecast_model_dir,
                                                   threading):
    """Progress callbacks fire periodically during ``generate_all`` and the
    per-update valid counts sum to the total number of generated rows."""
    batcher = DataFrameBatch(mode="read", checkpoint_dir=safecast_model_dir)

    def _validator(rec: dict):
        assert float(rec["payload.loc_lat"])

    factory = batcher.create_record_factory(
        num_lines=1000,
        validator=_validator,
        invalid_cache_size=5,
    )

    progress_cb = Mock()

    frame = factory.generate_all(
        output="df",
        callback=progress_cb,
        callback_interval=1,
        callback_threading=threading,
    )
    assert frame.shape == (1000, 16)

    # assuming we get at least 5 bad records
    assert len(factory.invalid_cache) == 5
    assert str(frame["payload.loc_lat"].dtype) == "float64"

    # at least 1 call during generation and another one with final update
    assert progress_cb.call_count >= 2
    assert progress_cb.call_count < 1000, "Progress update should be only called periodically"

    args, _ = progress_cb.call_args  # pylint: disable=unpacking-non-sequence
    last_update: GenerationProgress = args[0]
    assert last_update.current_valid_count == 1000
    assert last_update.completion_percent == 100

    # calculate sum from all updates
    total_new_valid = sum(
        call_args[0][0].new_valid_count for call_args in progress_cb.call_args_list
    )
    assert total_new_valid == 1000
Example #12
0
def test_record_factory_single_line(safecast_model_dir):
    """``next`` on a fresh factory yields one record and bumps the valid counter."""
    batcher = DataFrameBatch(mode="read", checkpoint_dir=safecast_model_dir)
    factory = batcher.create_record_factory(num_lines=10)

    first_record = next(factory)
    assert "payload.service_handler" in first_record
    assert factory._counter.valid_count == 1
Example #13
0
def test_record_factory_bad_validator(safecast_model_dir):
    """A non-callable validator is rejected with a descriptive ValueError.

    Fix: assert against ``err.value`` (the raised exception instance) instead
    of the ``ExceptionInfo`` object itself — ``str(ExceptionInfo)`` is the
    repr of the failing traceback entry and is not guaranteed to contain the
    exception message (pytest explicitly recommends ``.value``).
    """
    batcher = DataFrameBatch(mode="read", checkpoint_dir=safecast_model_dir)
    with pytest.raises(ValueError) as err:
        batcher.create_record_factory(num_lines=10, validator="foo")
    assert "must be callable" in str(err.value)