def test_record_factory_multi_batch(hr_model_dir):
    """generate_all across a multi-batch model returns exactly num_lines rows."""
    batcher = DataFrameBatch(mode="read", checkpoint_dir=hr_model_dir)
    factory = batcher.create_record_factory(num_lines=50, max_invalid=5000)
    frame = factory.generate_all(output="df")
    assert len(frame) == 50
def test_train_batch_sp_tok(train_df, tmp_path):
    """End-to-end train/generate with a SentencePiece tokenizer, then verify
    that a RecordFactory built from the same batcher produces the expected
    number of valid rows with the original column layout."""
    config = TensorFlowConfig(
        epochs=5,
        field_delimiter=",",
        checkpoint_dir=tmp_path,
        input_data_path=PATH_HOLDER,
        learning_rate=0.01,
    )
    tokenizer = SentencePieceTokenizerTrainer(vocab_size=10000, config=config)
    batcher = DataFrameBatch(df=train_df, config=config, tokenizer=tokenizer)
    batcher.create_training_data()
    batcher.train_all_batches()

    batcher.generate_all_batch_lines(num_lines=_tok_gen_count, max_invalid=5000)
    synthetic = batcher.batches_to_df()
    assert synthetic.shape[0] == _tok_gen_count

    # Generate with a RecordFactory
    factory = batcher.create_record_factory(
        num_lines=_tok_gen_count, max_invalid=5000
    )
    synthetic = factory.generate_all(output="df")
    assert synthetic.shape[0] == _tok_gen_count
    assert list(synthetic.columns) == list(train_df.columns)
    assert factory.summary["valid_count"] == _tok_gen_count
def test_record_factory_exc_fail_validator(safecast_model_dir):
    """A validator that raises on every record exhausts max_invalid and the
    factory surfaces the failure as a RuntimeError during iteration."""
    batcher = DataFrameBatch(mode="read", checkpoint_dir=safecast_model_dir)

    def _always_fail(rec: dict):
        raise ValueError

    factory = batcher.create_record_factory(
        num_lines=10, validator=_always_fail, max_invalid=10
    )
    with pytest.raises(RuntimeError):
        list(factory)
def test_record_factory_exhaust_iter(safecast_model_dir):
    """Fully draining the factory iterator yields num_lines records and the
    summary/counter state reflects the completed run."""
    batcher = DataFrameBatch(mode="read", checkpoint_dir=safecast_model_dir)
    factory = batcher.create_record_factory(num_lines=10)

    produced = list(factory)
    assert len(produced) == 10
    assert factory._counter.valid_count == 10

    stats = factory.summary
    assert stats["num_lines"] == 10
    assert stats["max_invalid"] == 1000
    assert stats["valid_count"] == 10
def test_record_factory_multi_batch_seed_static(hr_model_dir):
    """A static seed dict pins the seeded field to one constant value in
    every generated row."""
    batcher = DataFrameBatch(mode="read", checkpoint_dir=hr_model_dir)
    factory = batcher.create_record_factory(
        num_lines=10,
        max_invalid=5000,
        seed_fields={"age": 5},
        validator=MyValidator(),
    )
    frame = factory.generate_all(output="df")
    assert len(frame) == 10
    assert frame["age"].nunique() == 1
    assert frame.iloc[0]["age"] == 5
def test_record_factory_simple_validator(safecast_model_dir):
    """Validators may signal success either by returning True or by simply
    not raising (implicit None); both paths must count as valid."""
    batcher = DataFrameBatch(mode="read", checkpoint_dir=safecast_model_dir)

    def _validator(rec: dict):
        # Randomly exercise both accepted success styles: a passing assert
        # (returns None) and an explicit True.
        if random.randint(0, 1):
            assert float(rec["payload.loc_lat"])
        else:
            return True

    factory = batcher.create_record_factory(num_lines=10, validator=_validator)
    assert len(list(factory)) == 10
def test_record_factory_multi_batch_seed_list(hr_model_dir):
    """A list of seed dicts drives one row per seed, each carrying its own
    distinct seeded value."""
    batcher = DataFrameBatch(mode="read", checkpoint_dir=hr_model_dir)
    seed_list = [{"age": value} for value in range(1, 51)]
    factory = batcher.create_record_factory(
        num_lines=50,  # doesn't matter w/ smart seed
        max_invalid=5000,
        seed_fields=seed_list,
        validator=MyValidator(),
    )
    frame = factory.generate_all(output="df")
    assert len(frame) == 50
    assert frame["age"].nunique() == 50
def test_record_factory_generate_all(safecast_model_dir):
    """generate_all restarts the iterator even after partial consumption and
    the DataFrame output carries coerced dtypes."""
    batcher = DataFrameBatch(mode="read", checkpoint_dir=safecast_model_dir)

    def _validator(rec: dict):
        assert float(rec["payload.loc_lat"])

    factory = batcher.create_record_factory(num_lines=10, validator=_validator)

    # Partially consume the iterator first...
    next(factory)
    next(factory)

    # ...generate_all should reset it and still deliver all 10 records.
    assert len(factory.generate_all()) == 10

    frame = factory.generate_all(output="df")
    assert frame.shape == (10, 16)
    assert str(frame["payload.loc_lat"].dtype) == "float64"
def test_record_factory_smart_seed(safecast_model_dir):
    """Seed lists override num_lines, and generated records echo their seed
    values back in order."""
    seeds = [
        {"payload.service_handler": "i-051a2a353509414f0"},
        {"payload.service_handler": "i-051a2a353509414f1"},
        {"payload.service_handler": "i-051a2a353509414f2"},
        {"payload.service_handler": "i-051a2a353509414f3"},
    ]
    batcher = DataFrameBatch(mode="read", checkpoint_dir=safecast_model_dir)
    factory = batcher.create_record_factory(
        num_lines=1000,
        seed_fields=seeds * 1000,
    )

    # list of seeds should reset num_lines
    assert factory.num_lines == len(seeds) * 1000

    for seed, record in zip(seeds, factory):
        assert seed["payload.service_handler"] == record["payload.service_handler"]
def test_record_factory_smart_seed_buffer(safecast_model_dir):
    """With a seed list the effective line count equals the seed count, not
    the num_lines argument; summary stats must agree."""
    seeds = [
        {"payload.service_handler": "i-051a2a353509414f0"},
        {"payload.service_handler": "i-051a2a353509414f1"},
        {"payload.service_handler": "i-051a2a353509414f2"},
        {"payload.service_handler": "i-051a2a353509414f3"},
    ] * 2
    batcher = DataFrameBatch(mode="read", checkpoint_dir=safecast_model_dir)
    factory = batcher.create_record_factory(
        num_lines=100,  # doesn't matter w/ smart seed
        seed_fields=seeds,
        validator=MyValidator(),
        max_invalid=5000,
    )
    frame = factory.generate_all(output="df")
    assert len(frame) == 8
    assert factory.summary["num_lines"] == 8
    assert factory.summary["valid_count"] == 8
def test_record_factory_generate_all_with_callback(safecast_model_dir, threading):
    """Progress callbacks fire periodically (not per-record), the final update
    reports 100% completion, and per-update valid counts sum to the total."""
    batcher = DataFrameBatch(mode="read", checkpoint_dir=safecast_model_dir)

    def _validator(rec: dict):
        assert float(rec["payload.loc_lat"])

    factory = batcher.create_record_factory(
        num_lines=1000, validator=_validator, invalid_cache_size=5
    )
    callback_fn = Mock()
    frame = factory.generate_all(
        output="df",
        callback=callback_fn,
        callback_interval=1,
        callback_threading=threading,
    )
    assert frame.shape == (1000, 16)

    # assuming we get at least 5 bad records
    assert len(factory.invalid_cache) == 5

    assert str(frame["payload.loc_lat"].dtype) == "float64"

    # at least 1 call during generation and another one with final update
    assert callback_fn.call_count >= 2
    assert callback_fn.call_count < 1000, "Progress update should be only called periodically"

    args, _ = callback_fn.call_args  # pylint: disable=unpacking-non-sequence
    last_update: GenerationProgress = args[0]
    assert last_update.current_valid_count == 1000
    assert last_update.completion_percent == 100

    # calculate sum from all updates
    valid_total_count = 0
    for call_args in callback_fn.call_args_list:
        args, _ = call_args
        valid_total_count += args[0].new_valid_count
    assert valid_total_count == 1000
def test_record_factory_single_line(safecast_model_dir):
    """Pulling a single record via next() yields a dict with expected keys
    and bumps the valid counter by one."""
    batcher = DataFrameBatch(mode="read", checkpoint_dir=safecast_model_dir)
    factory = batcher.create_record_factory(num_lines=10)
    first_record = next(factory)
    assert "payload.service_handler" in first_record
    assert factory._counter.valid_count == 1
def test_record_factory_bad_validator(safecast_model_dir):
    """A non-callable validator must be rejected with a ValueError whose
    message mentions callability.

    Fix: the original asserted on ``str(err)`` where ``err`` is the pytest
    ``ExceptionInfo`` object — that stringifies to a traceback-entry repr,
    not the exception message. Use ``pytest.raises(..., match=...)`` to
    check the actual exception text directly.
    """
    batcher = DataFrameBatch(mode="read", checkpoint_dir=safecast_model_dir)
    with pytest.raises(ValueError, match="must be callable"):
        batcher.create_record_factory(num_lines=10, validator="foo")