def test_train_batch_sp_tok(train_df, tmp_path):
    """End-to-end batch training with a SentencePiece tokenizer.

    Trains all batches, generates synthetic rows both directly and via a
    RecordFactory, and checks row counts, column alignment, and the
    factory's valid-record summary.
    """
    model_config = TensorFlowConfig(
        epochs=5,
        field_delimiter=",",
        checkpoint_dir=tmp_path,
        input_data_path=PATH_HOLDER,
        learning_rate=0.01,
    )
    sp_trainer = SentencePieceTokenizerTrainer(vocab_size=10000, config=model_config)
    df_batcher = DataFrameBatch(
        df=train_df, config=model_config, tokenizer=sp_trainer
    )

    df_batcher.create_training_data()
    df_batcher.train_all_batches()

    # Direct generation path.
    df_batcher.generate_all_batch_lines(num_lines=_tok_gen_count, max_invalid=5000)
    generated = df_batcher.batches_to_df()
    assert generated.shape[0] == _tok_gen_count

    # Generate with a RecordFactory
    rec_factory = df_batcher.create_record_factory(
        num_lines=_tok_gen_count, max_invalid=5000
    )
    generated = rec_factory.generate_all(output="df")
    assert generated.shape[0] == _tok_gen_count
    assert list(generated.columns) == list(train_df.columns)
    assert rec_factory.summary["valid_count"] == _tok_gen_count
def test_train_batch_sp_tok(train_df, tmp_path):
    """Smoke-test batch training and generation with a SentencePiece tokenizer."""
    tf_config = TensorFlowConfig(
        epochs=5,
        field_delimiter=",",
        checkpoint_dir=tmp_path,
        input_data_path=PATH_HOLDER,
        learning_rate=0.01,
    )
    batch_runner = DataFrameBatch(
        df=train_df,
        config=tf_config,
        tokenizer=SentencePieceTokenizerTrainer(vocab_size=10000, config=tf_config),
    )
    batch_runner.create_training_data()
    batch_runner.train_all_batches()

    # Generation should yield exactly the requested number of rows.
    batch_runner.generate_all_batch_lines(num_lines=100, max_invalid=5000)
    assert batch_runner.batches_to_df().shape[0] == 100
def test_batches_to_df(test_data):
    """Valid generated lines spread across batches reassemble into one typed row."""
    source_df = pd.DataFrame(
        [{"foo": "bar", "foo1": "bar1", "foo2": "bar2", "foo3": 3}]
    )
    batch_set = DataFrameBatch(df=source_df, config=config_template, batch_size=2)

    # One valid record added to each batch; together they form a single row.
    batch_set.batches[0].add_valid_data(
        GenText(text="baz|baz1", valid=True, delimiter="|")
    )
    batch_set.batches[1].add_valid_data(
        GenText(text="baz2|5", valid=True, delimiter="|")
    )

    combined = batch_set.batches_to_df()
    assert list(combined.columns) == ["foo", "foo1", "foo2", "foo3"]
    assert combined.shape == (1, 4)
    # Numeric trailing field should be inferred as int64; the rest stay object.
    assert [t.name for t in list(combined.dtypes)] == [
        "object",
        "object",
        "object",
        "int64",
    ]
def _fit_sample(self, data, metadata):
    """Fit a batched model on ``data`` and return a synthesized DataFrame.

    ``metadata`` is accepted for interface compatibility but is not used
    in this implementation.
    """
    model_params = {
        'max_lines': self.max_lines,
        'max_line_len': self.max_line_len,
        # value recommended by Gretel
        'epochs': self.epochs or data.shape[1] * 3,
        'vocab_size': self.vocab_size,
        'gen_lines': self.gen_lines or data.shape[0],
        'dp': self.dp,
        'field_delimiter': self.field_delimiter,
        'overwrite': self.overwrite,
        'checkpoint_dir': self.checkpoint_dir,
    }
    batch_trainer = DataFrameBatch(df=data, config=model_params)
    batch_trainer.create_training_data()
    batch_trainer.train_all_batches()
    batch_trainer.generate_all_batch_lines()
    return batch_trainer.batches_to_df()
def test_train_batch_char_tok(train_df, tmp_path):
    """Batch training with a character tokenizer persists the tokenizer type.

    Trains all batches, verifies the saved tokenizer settings in the first
    batch's checkpoint directory name the char tokenizer, then generates
    and checks the synthetic row count.
    """
    config = TensorFlowConfig(
        epochs=5,
        field_delimiter=",",
        checkpoint_dir=tmp_path,
        input_data_path=PATH_HOLDER,
        learning_rate=0.01,
    )
    batcher = DataFrameBatch(
        df=train_df,
        config=config,
        tokenizer=CharTokenizerTrainer(config=config),
    )
    batcher.create_training_data()
    batcher.train_all_batches()

    # Fix: the original used json.loads(open(...).read()), leaking the file
    # handle; a context manager guarantees it is closed.
    settings_path = tmp_path / "batch_0" / BaseTokenizerTrainer.settings_fname
    with open(settings_path) as fin:
        tok_params = json.load(fin)
    assert tok_params["tokenizer_type"] == CharTokenizerTrainer.__name__

    batcher.generate_all_batch_lines(num_lines=100, max_invalid=5000)
    syn_df = batcher.batches_to_df()
    assert syn_df.shape[0] == 100
def _generate(
    self, model_dir: Path, count: int, file_name: str, seed, validator
) -> str:
    """Generate ``count`` synthetic lines from a trained model directory.

    Dispatches on whether ``model_dir`` holds a batch-mode model (writes a
    CSV) or a simple model (writes plain text) and returns the name of the
    file written.

    Raises:
        TypeError: if ``seed`` is not a dict (batch mode) or str (simple mode).
    """
    if is_model_dir_batch_mode(model_dir):
        # Batch mode: a seed, when supplied, must map fields to values.
        if seed is not None and not isinstance(seed, dict):
            raise TypeError("Seed must be a dict in batch mode")
        out_fname = f"{file_name}.csv"
        batcher = DataFrameBatch(mode="read", checkpoint_dir=str(model_dir))
        batcher.generate_all_batch_lines(
            num_lines=count,
            max_invalid=max(count, MAX_INVALID),
            parallelism=1,
            seed_fields=seed,
        )
        batcher.batches_to_df().to_csv(out_fname, index=False)
        return out_fname

    # The model data will be inside of a single directory when a simple
    # model is used. If it was archived correctly, there should only be a
    # single directory inside the archive.
    actual_dir = next(model_dir.glob("*"))
    config = config_from_model_dir(actual_dir)
    if seed is not None and not isinstance(seed, str):
        raise TypeError("seed must be a string")

    kept_lines = []
    for data in generate_text(
        config,
        num_lines=count,
        line_validator=validator,
        max_invalid=max(count, MAX_INVALID),
        parallelism=1,
        start_string=seed,
    ):
        # valid is None when no validator ran; keep those lines too.
        if data.valid or data.valid is None:
            kept_lines.append(data.text)

    out_fname = file_name + ".txt"
    with open(out_fname, "w") as fout:
        for line in kept_lines:
            fout.write(line + "\n")
    return out_fname