def test_generate_text(_open, pickle, prepare, predict, spm, global_local_config):
    """End-to-end checks of ``generate_text`` with a JSON validator,
    with no validator, and with a validator that rejects some lines."""
    global_local_config.gen_lines = 10
    spm.return_value = Mock()

    # Case 1: JSON validator and all lines parse -> every record is valid.
    predict.side_effect = [_pred_string(json.dumps({"foo": i})) for i in range(10)]
    records = [
        rec.as_dict()
        for rec in generate_text(global_local_config, line_validator=json.loads)
    ]
    assert len(records) == 10
    assert records[0] == {
        "valid": True,
        "text": '{"foo": 0}',
        "explain": None,
        "delimiter": ",",
    }

    # Case 2: no validator -> validity is unknown (None).
    predict.side_effect = [_pred_string(json.dumps({"foo": i})) for i in range(10)]
    records = [rec.as_dict() for rec in generate_text(global_local_config)]
    assert len(records) == 10
    assert records[0] == {
        "valid": None,
        "text": '{"foo": 0}',
        "explain": None,
        "delimiter": ",",
    }

    # Case 3: validator back in, with a few non-JSON strings mixed in ->
    # those records come back flagged invalid.
    leading = [_pred_string(json.dumps({"foo": i})) for i in range(3)]
    broken = [_pred_string(s) for s in ("nope", "foo", "bar")]
    trailing = [_pred_string(json.dumps({"foo": i})) for i in range(6, 10)]
    predict.side_effect = leading + broken + trailing
    records = [
        rec.as_dict()
        for rec in generate_text(global_local_config, line_validator=json.loads)
    ]
    assert len(records) == 10
    assert not records[4]["valid"]
def generate_batch_lines(
    self,
    batch_idx: int,
    max_invalid=MAX_INVALID,
    raise_on_exceed_invalid: bool = False,
    num_lines: int = None,
) -> bool:
    """Generate synthetic lines for one batch.

    Generated lines accumulate on the batch's underlying ``Batch``
    object, from which they can later be retrieved and re-assembled
    into a DataFrame.

    Args:
        batch_idx: Index of the batch to generate lines for.
        max_invalid: Upper bound on invalid lines; generation stops
            once this many invalid lines have been produced.
        raise_on_exceed_invalid: When ``True``, re-raise the error
            thrown by the generation module if the invalid-line limit
            is exceeded, interrupting the running process. When
            ``False``, swallow it and return ``False``.
        num_lines: How many lines to generate; when ``None``, the
            batch config's ``gen_lines`` value is used.

    Returns:
        ``True`` if the batch generated exactly the requested number
        of valid lines, ``False`` otherwise.

    Raises:
        ValueError: If ``batch_idx`` does not refer to a known batch.
    """
    try:
        batch = self.batches[batch_idx]
    except KeyError:  # pragma: no cover
        raise ValueError("invalid batch index")
    batch: Batch
    batch.reset_gen_data()
    validator = batch.get_validator()
    if num_lines is None:
        num_lines = batch.config.gen_lines
    # Two progress bars: one tracking valid lines, one tracking invalid.
    valid_bar = tqdm(total=num_lines, desc="Valid record count ")
    invalid_bar = tqdm(total=max_invalid, desc="Invalid record count ")
    line: gen_text
    try:
        for line in generate_text(
            batch.config,
            line_validator=validator,
            max_invalid=max_invalid,
            num_lines=num_lines,
        ):
            # A missing validator (valid is None) counts the same as an
            # explicitly valid line.
            if line.valid is True or line.valid is None:
                batch.add_valid_data(line)
                valid_bar.update(1)
            else:
                invalid_bar.update(1)
                batch.gen_data_invalid.append(line)
    except RuntimeError:
        if raise_on_exceed_invalid:
            raise
        return False
    valid_bar.close()
    invalid_bar.close()
    return batch.gen_data_count == num_lines
def _generate(
    self, model_dir: Path, count: int, file_name: str, seed, validator
) -> str:
    """Generate records from a trained model directory.

    Dispatches on whether ``model_dir`` holds a batch-mode model
    (writing a CSV) or a simple model (writing a text file).

    Args:
        model_dir: Directory containing the unpacked model.
        count: Number of lines to generate.
        file_name: Base name (without extension) for the output file.
        seed: Optional seed -- a dict of field values in batch mode,
            or a start string for a simple model.
        validator: Optional line validator used for simple models.

    Returns:
        The name of the file that was written.

    Raises:
        TypeError: If ``seed`` has the wrong type for the model mode.
    """
    if is_model_dir_batch_mode(model_dir):
        if seed is not None and not isinstance(seed, dict):
            raise TypeError("Seed must be a dict in batch mode")
        out_fname = f"{file_name}.csv"
        batcher = DataFrameBatch(mode="read", checkpoint_dir=str(model_dir))
        batcher.generate_all_batch_lines(
            num_lines=count,
            max_invalid=max(count, MAX_INVALID),
            parallelism=1,
            seed_fields=seed,
        )
        batcher.batches_to_df().to_csv(out_fname, index=False)
        return out_fname

    # Simple model: a correctly built archive contains exactly one
    # directory, which holds the model data and config.
    actual_dir = next(model_dir.glob("*"))
    config = config_from_model_dir(actual_dir)
    if seed is not None and not isinstance(seed, str):
        raise TypeError("seed must be a string")
    texts = []
    for data in generate_text(
        config,
        num_lines=count,
        line_validator=validator,
        max_invalid=max(count, MAX_INVALID),
        parallelism=1,
        start_string=seed,
    ):
        # Keep lines that validated, or every line when no validator ran.
        if data.valid or data.valid is None:
            texts.append(data.text)
    out_fname = file_name + ".txt"
    with open(out_fname, "w") as fout:
        fout.writelines(text + "\n" for text in texts)
    return out_fname
def start():
    """Stream generated, validated records to stdout."""
    record_stream = generate_text(
        config, line_validator=validate_record, parallelism=PARALLELISM
    )
    for record in record_stream:
        print(record)
def generate_batch_lines(
    self,
    batch_idx: int,
    max_invalid=MAX_INVALID,
    raise_on_exceed_invalid: bool = False,
    num_lines: int = None,
    seed_fields: dict = None,
    parallelism: int = 0,
) -> bool:
    """Generate lines for a single batch.

    Lines generated are added to the underlying ``Batch`` object for
    each batch. The lines can be accessed after generation and
    re-assembled into a DataFrame.

    Args:
        batch_idx: The batch number.
        max_invalid: The max number of invalid lines that can be
            generated; if this is exceeded, generation will stop.
        raise_on_exceed_invalid: If true and if the number of lines
            generated exceeds the ``max_invalid`` amount, we will
            re-raise the error thrown by the generation module, which
            will interrupt the running process. Otherwise, we will not
            raise the caught exception and just return ``False``
            indicating that the batch failed to generate all lines.
        num_lines: The number of lines to generate; if ``None``, then
            we use the number from the batch's config.
        seed_fields: A dictionary that maps field/column names to
            initial seed values for those columns. This seed will only
            apply to the first batch that gets trained and generated.
            Additionally, the fields provided in the mapping MUST exist
            at the front of the first batch.
        parallelism: The number of concurrent workers to use. ``1``
            disables parallelization, while a non-positive value means
            "number of CPUs + x" (the signature default ``0`` uses as
            many workers as there are CPUs). A floating-point value is
            interpreted as a fraction of the available CPUs, rounded
            down.

    Returns:
        ``True`` when at least ``num_lines`` valid lines were
        generated, ``False`` when the invalid limit stopped generation
        early (and ``raise_on_exceed_invalid`` is false).

    Raises:
        ValueError: If ``batch_idx`` is not a known batch.
    """
    try:
        batch = self.batches[batch_idx]
    except KeyError:  # pragma: no cover
        raise ValueError("invalid batch index")
    seed_string = None
    # If we are on batch 0 and we have seed values, we want to validate
    # that the seed values line up properly with the first N columns,
    # and build the corresponding start string for generation.
    if batch_idx == 0 and seed_fields is not None:
        seed_string = self._validate_batch_seed_values(batch, seed_fields)
    batch: Batch
    batch.reset_gen_data()
    validator = batch.get_validator()
    if num_lines is None:
        num_lines = batch.config.gen_lines
    # Progress bars tracking valid vs. invalid record counts.
    t = tqdm(total=num_lines, desc="Valid record count ")
    t2 = tqdm(total=max_invalid, desc="Invalid record count ")
    line: GenText
    try:
        for line in generate_text(
            batch.config,
            line_validator=validator,
            max_invalid=max_invalid,
            num_lines=num_lines,
            start_string=seed_string,
            parallelism=parallelism,
        ):
            # No validator configured (valid is None) is treated the
            # same as an explicitly valid line.
            if line.valid is None or line.valid is True:
                batch.add_valid_data(line)
                t.update(1)
            else:
                t2.update(1)
                batch.gen_data_invalid.append(line)
    except TooManyInvalidError:
        if raise_on_exceed_invalid:
            raise
        else:
            return False
    t.close()
    t2.close()
    return batch.gen_data_count >= num_lines
def _get_record(self) -> IteratorType[dict]:
    """Yield fully assembled records built from every batch's generator.

    Builds one line generator per batch up front, then on each
    iteration pulls one valid partial line from every batch, merges
    them into a single dict, optionally re-validates the combined
    record, and yields it. Stops once ``self.num_lines`` valid records
    have been yielded; raises ``RuntimeError`` if ``self.max_invalid``
    is exceeded along the way.
    """
    # our actual batch line generators
    generators = []
    # if we have a list of seed fields, we do special
    # handling to create the proper generator
    seed_generator = None  # assume no seeds to start
    if isinstance(self._seed_fields, list):
        seed_generator = SeedingGenerator(
            self._batches[0].config,
            seed_list=self._seed_fields,
            line_validator=self._batches[0].get_validator(),
            max_invalid=self.max_invalid * 10000,
        )
        generators.append((self._batches[0], seed_generator))
    for idx, batch in self._batches.items():
        start_string = None
        if idx == 0 and seed_generator:
            # We've already added the first batch's generator to the list
            # so we just continue on to the next one
            continue
        if idx == 0:
            # In the event we have seeds that aren't a list, (i.e. static seeds)
            start_string = self._seed_fields
        generators.append((
            batch,
            # We seed the low level API with much higher limits on
            # valid / invalid generation because we will enforce
            # those limits in this high level instance.
            generate_text(
                batch.config,
                line_validator=batch.get_validator(),
                max_invalid=self.max_invalid * 10000,
                num_lines=self.num_lines * 10000,
                start_string=start_string,
                parallelism=self._parallelism,
            ),
        ))
    # At this point, we've created our list of generators. Below here
    # is what gets run on every next() call, which tries to construct
    # a full record from all the underlying batches.
    # keep looping as long as our target line count is less than
    # our total line count
    while self.valid_count < self.num_lines:
        # loop over each batch line generator and attempt
        # to construct a full line; we'll only count a
        # full line once we get through each generator
        if self.invalid_count >= self.max_invalid:
            raise RuntimeError(
                "Invalid record count exceeded during generation")
        seed_cache = None
        if seed_generator:
            # If we're using a seeding generator (from a list of seeds)
            # we cache the next seed we are about to use to generate
            # the next record, so it can be re-queued if the combined
            # record turns out to be invalid.
            seed_cache = seed_generator.settings.start_string[0]
        record = {}
        batch: Batch
        for batch, gen in generators:
            # Pull lines until this batch produces a valid one; each
            # invalid line counts against the global invalid budget.
            while True:
                line = next(gen)  # type: GenText
                if line.valid is False:
                    self.invalid_count += 1
                    if self.invalid_count > self.max_invalid:
                        raise RuntimeError(
                            "Invalid record count exceeded during generation"
                        )
                    continue
                partial_rec = dict(
                    zip(batch.headers, line.values_as_list()))
                record.update(partial_rec)
                break
        # Do a final validation, if configured, on the fully constructed
        # record; if this validation fails, we'll still increment our
        # invalid count.
        valid = True  # assume we have a valid record
        if self.validator is not None:
            try:
                _valid = self.validator(record)
                # Validators may return False (invalid) or raise; any
                # other return value (including None) counts as valid.
                if _valid is False:
                    valid = False
            except Exception:
                valid = False
        if not valid:
            self.invalid_count += 1
            if seed_cache:
                # Put the consumed seed back so the next attempt reuses it.
                seed_generator.settings.start_string.insert(0, seed_cache)
            continue  # back to the while start
        self.valid_count += 1
        yield record
def test_generate_text(_open, pickle, prepare, predict, spm, tf_config):
    """Exercise ``generate_text`` with a validator, without one, with bad
    lines mixed in, and with a ``max_invalid`` budget (raising and
    boolean-returning validators)."""
    tf_config.gen_lines = 10
    spm.return_value = Mock()

    def _good_preds(lo, hi):
        return [PredString(json.dumps({"foo": i})) for i in range(lo, hi)]

    def _mixed_batches():
        # 3 good lines, 3 non-JSON lines, 4 good lines.
        return [
            _good_preds(0, 3),
            [PredString("nope"), PredString("foo"), PredString("bar")],
            _good_preds(6, 10),
        ]

    # Validator present and every line parses -> all valid.
    predict.side_effect = [[p] for p in _good_preds(0, 10)]
    records = [
        r.as_dict()
        for r in generate_text(tf_config, line_validator=json.loads, parallelism=1)
    ]
    assert len(records) == 10
    assert records[0] == {
        "valid": True,
        "text": '{"foo": 0}',
        "explain": None,
        "delimiter": ",",
    }

    # No validator -> validity is unknown (None).
    predict.side_effect = [[p] for p in _good_preds(0, 10)]
    records = [r.as_dict() for r in generate_text(tf_config, parallelism=1)]
    assert len(records) == 10
    assert records[0] == {
        "valid": None,
        "text": '{"foo": 0}',
        "explain": None,
        "delimiter": ",",
    }

    # Validator back in with some bad JSON strings mixed in.
    predict.side_effect = _mixed_batches()
    records = []
    try:
        for r in generate_text(tf_config, line_validator=json.loads, parallelism=1):
            records.append(r.as_dict())
    except RuntimeError:
        pass
    assert len(records) == 10
    assert not records[4]["valid"]

    # max_invalid budget exceeded -> generation stops with RuntimeError.
    predict.side_effect = _mixed_batches()
    records = []
    try:
        for r in generate_text(
            tf_config, line_validator=json.loads, max_invalid=2, parallelism=1
        ):
            records.append(r.as_dict())
    except RuntimeError as err:
        assert "Maximum number" in str(err)
    assert len(records) == 6
    assert not records[4]["valid"]

    # Same budget, but the validator returns a bool instead of raising.
    def _bool_validator(line):
        try:
            json.loads(line)
        except Exception:
            return False
        else:
            return True

    predict.side_effect = _mixed_batches()
    records = []
    try:
        for r in generate_text(
            tf_config, line_validator=_bool_validator, max_invalid=2, parallelism=1
        ):
            records.append(r.as_dict())
    except RuntimeError as err:
        assert "Maximum number" in str(err)
    assert len(records) == 6
    assert not records[4]["valid"]