def test_generate_text(_open, pickle, prepare, predict, spm,
                       global_local_config):
    """Exercise ``generate_text`` with a validator, without one, and with
    a handful of invalid (non-JSON) lines mixed into the stream."""
    global_local_config.gen_lines = 10
    predict.side_effect = [
        _pred_string(json.dumps({"foo": n})) for n in range(10)
    ]

    sp_mock = Mock()
    spm.return_value = sp_mock

    # Validated run: every line parses as JSON, so all records are valid.
    results = [
        rec.as_dict()
        for rec in generate_text(global_local_config, line_validator=json.loads)
    ]

    assert len(results) == 10
    assert results[0] == {
        "valid": True,
        "text": '{"foo": 0}',
        "explain": None,
        "delimiter": ",",
    }

    # Unvalidated run: "valid" stays None since no validator was supplied.
    predict.side_effect = [
        _pred_string(json.dumps({"foo": n})) for n in range(10)
    ]
    results = [rec.as_dict() for rec in generate_text(global_local_config)]
    assert len(results) == 10
    assert results[0] == {
        "valid": None,
        "text": '{"foo": 0}',
        "explain": None,
        "delimiter": ",",
    }

    # Validator restored, with a few bad json strings in the middle.
    bad_lines = [_pred_string(s) for s in ("nope", "foo", "bar")]
    predict.side_effect = (
        [_pred_string(json.dumps({"foo": n})) for n in range(3)]
        + bad_lines
        + [_pred_string(json.dumps({"foo": n})) for n in range(6, 10)]
    )
    results = [
        rec.as_dict()
        for rec in generate_text(global_local_config, line_validator=json.loads)
    ]
    assert len(results) == 10
    assert not results[4]["valid"]
예제 #2
0
    def generate_batch_lines(
        self,
        batch_idx: int,
        max_invalid=MAX_INVALID,
        raise_on_exceed_invalid: bool = False,
        num_lines: int = None,
    ) -> bool:
        """Generate lines for a single batch. Lines generated are added
        to the underlying ``Batch`` object for each batch. The lines
        can be accessed after generation and re-assembled into a DataFrame.

        Args:
            batch_idx: The batch number
            max_invalid: The max number of invalid lines that can be generated, if
                this is exceeded, generation will stop
            raise_on_exceed_invalid: If true and if the number of lines generated exceeds the ``max_invalid``
                amount, we will re-raise the error thrown by the generation module which will interrupt
                the running process. Otherwise, we will not raise the caught exception and just return ``False``
                indicating that the batch failed to generate all lines.
            num_lines: The number of lines to generate, if ``None``, then we use the number from the
                batch's config

        Returns:
            ``True`` if the batch generated exactly ``num_lines`` valid lines,
            ``False`` otherwise.

        Raises:
            ValueError: If ``batch_idx`` is not a known batch.
        """
        try:
            batch = self.batches[batch_idx]
        except KeyError:  # pragma: no cover
            raise ValueError("invalid batch index")
        batch: Batch
        batch.reset_gen_data()
        validator = batch.get_validator()
        if num_lines is None:
            num_lines = batch.config.gen_lines
        # Progress bars: one tracks valid lines, the other invalid ones.
        t = tqdm(total=num_lines, desc="Valid record count ")
        t2 = tqdm(total=max_invalid, desc="Invalid record count ")
        line: gen_text
        try:
            for line in generate_text(batch.config,
                                      line_validator=validator,
                                      max_invalid=max_invalid,
                                      num_lines=num_lines):
                # valid is None when no validator ran; treat that as valid
                if line.valid is None or line.valid is True:
                    batch.add_valid_data(line)
                    t.update(1)
                else:
                    t2.update(1)
                    batch.gen_data_invalid.append(line)
        except RuntimeError:
            if raise_on_exceed_invalid:
                raise
            return False
        finally:
            # Always close the progress bars — previously they leaked when
            # generation aborted with RuntimeError, leaving stale bars that
            # corrupt subsequent terminal output.
            t.close()
            t2.close()
        return batch.gen_data_count == num_lines
예제 #3
0
 def _generate(
     self, model_dir: Path, count: int, file_name: str, seed, validator
 ) -> str:
     """Generate ``count`` records from a trained model directory and write
     them to disk.

     Args:
         model_dir: Directory containing the trained model artifacts.
         count: Number of lines/records to generate.
         file_name: Base name (without extension) for the output file.
         seed: Optional seed. Must be a ``dict`` of field values in batch
             mode, or a ``str`` start string for a simple model.
         validator: Optional line validator, used only on the simple
             (non-batch) path.

     Returns:
         The path of the written output file — ``<file_name>.csv`` in batch
         mode, ``<file_name>.txt`` otherwise.

     Raises:
         TypeError: If ``seed`` has the wrong type for the detected mode.
     """
     batch_mode = is_model_dir_batch_mode(model_dir)
     if batch_mode:
         if seed is not None and not isinstance(seed, dict):
             raise TypeError("Seed must be a dict in batch mode")
         out_fname = f"{file_name}.csv"
         # Re-hydrate batch metadata from the checkpoint directory
         batcher = DataFrameBatch(mode="read", checkpoint_dir=str(model_dir))
         batcher.generate_all_batch_lines(
             num_lines=count,
             # Scale the invalid-line budget with the requested count,
             # but never below the module default
             max_invalid=max(count, MAX_INVALID),
             parallelism=1,
             seed_fields=seed
         )
         out_df = batcher.batches_to_df()
         out_df.to_csv(out_fname, index=False)
         return out_fname
     else:
         out = []
         # The model data will be inside of a single directory when a simple model is used. If it
         # was archived correctly, there should only be a single directory inside the archive
         actual_dir = next(model_dir.glob("*"))
         config = config_from_model_dir(actual_dir)
         if seed is not None and not isinstance(seed, str):
             raise TypeError("seed must be a string")
         for data in generate_text(
             config,
             num_lines=count,
             line_validator=validator,
             max_invalid=max(count, MAX_INVALID),
             parallelism=1,
             start_string=seed
         ):
             # Keep lines the validator accepted (True) or that were never
             # validated (None); drop explicitly invalid ones.
             if data.valid or data.valid is None:
                 out.append(data.text)
         out_fname = file_name + ".txt"
         with open(out_fname, "w") as fout:
             for line in out:
                 fout.write(line + "\n")
         return out_fname
예제 #4
0
def start():
    """Stream generated records to stdout, validating each line."""
    records = generate_text(
        config,
        line_validator=validate_record,
        parallelism=PARALLELISM,
    )
    for record in records:
        print(record)
예제 #5
0
    def generate_batch_lines(
        self,
        batch_idx: int,
        max_invalid=MAX_INVALID,
        raise_on_exceed_invalid: bool = False,
        num_lines: int = None,
        seed_fields: dict = None,
        parallelism: int = 0,
    ) -> bool:
        """Generate lines for a single batch. Lines generated are added
        to the underlying ``Batch`` object for each batch. The lines
        can be accessed after generation and re-assembled into a DataFrame.

        Args:
            batch_idx: The batch number
            max_invalid: The max number of invalid lines that can be generated, if
                this is exceeded, generation will stop
            raise_on_exceed_invalid: If true and if the number of lines generated exceeds the ``max_invalid``
                amount, we will re-raise the error thrown by the generation module which will interrupt
                the running process. Otherwise, we will not raise the caught exception and just return ``False``
                indicating that the batch failed to generate all lines.
            num_lines: The number of lines to generate, if ``None``, then we use the number from the
                batch's config
            seed_fields: A dictionary that maps field/column names to initial seed values for those columns. This seed
                will only apply to the first batch that gets trained and generated. Additionally, the fields provided
                in the mapping MUST exist at the front of the first batch.
            parallelism: The number of concurrent workers to use. ``1`` (the default) disables parallelization,
                while a non-positive value means "number of CPUs + x" (i.e., use ``0`` for using as many workers
                as there are CPUs). A floating-point value is interpreted as a fraction of the available CPUs,
                rounded down.

        Returns:
            ``True`` if at least ``num_lines`` valid lines were generated,
            ``False`` otherwise.

        Raises:
            ValueError: If ``batch_idx`` is not a known batch.
        """
        try:
            batch = self.batches[batch_idx]
        except KeyError:  # pragma: no cover
            raise ValueError("invalid batch index")

        seed_string = None

        # If we are on batch 0 and we have seed values, we want to validate that
        # the seed values line up properly with the first N columns.
        if batch_idx == 0 and seed_fields is not None:
            seed_string = self._validate_batch_seed_values(batch, seed_fields)

        batch: Batch
        batch.reset_gen_data()
        validator = batch.get_validator()
        if num_lines is None:
            num_lines = batch.config.gen_lines
        # Progress bars: one tracks valid lines, the other invalid ones.
        t = tqdm(total=num_lines, desc="Valid record count ")
        t2 = tqdm(total=max_invalid, desc="Invalid record count ")
        line: GenText
        try:
            for line in generate_text(
                batch.config,
                line_validator=validator,
                max_invalid=max_invalid,
                num_lines=num_lines,
                start_string=seed_string,
                parallelism=parallelism,
            ):
                # valid is None when no validator ran; treat that as valid
                if line.valid is None or line.valid is True:
                    batch.add_valid_data(line)
                    t.update(1)
                else:
                    t2.update(1)
                    batch.gen_data_invalid.append(line)
        except TooManyInvalidError:
            if raise_on_exceed_invalid:
                raise
            return False
        finally:
            # Always close the progress bars — previously they leaked when
            # generation aborted with TooManyInvalidError, leaving stale bars
            # that corrupt subsequent terminal output.
            t.close()
            t2.close()
        return batch.gen_data_count >= num_lines
예제 #6
0
    def _get_record(self) -> IteratorType[dict]:
        """Yield fully assembled records, one ``dict`` per generated row.

        Builds one low-level line generator per batch (plus a special
        seeding generator for batch 0 when ``self._seed_fields`` is a
        list), then on every ``next()`` stitches one valid line from each
        batch into a single record and applies the optional whole-record
        validator.

        Yields:
            dict: One complete record spanning all batches' columns.

        Raises:
            RuntimeError: When ``self.invalid_count`` exceeds
                ``self.max_invalid``.
        """
        # our actual batch line generators
        generators = []

        # if we have a list of seed fields, we do special
        # handling to create the proper generator
        seed_generator = None  # assume no seeds to start
        if isinstance(self._seed_fields, list):
            seed_generator = SeedingGenerator(
                self._batches[0].config,
                seed_list=self._seed_fields,
                line_validator=self._batches[0].get_validator(),
                max_invalid=self.max_invalid * 10000)
            generators.append((self._batches[0], seed_generator))

        for idx, batch in self._batches.items():
            start_string = None
            if idx == 0 and seed_generator:
                # We've already added the first batch's generator to the list
                # so we just continue on to the next one
                continue
            if idx == 0:
                # In the event we have seeds that aren't a list, (i.e. static seeds)
                start_string = self._seed_fields
            generators.append((
                batch,
                # We seed the low level API with much higher limits on
                # valid / invalid generation because we will enforce
                # those limits in this high level instance.
                generate_text(
                    batch.config,
                    line_validator=batch.get_validator(),
                    max_invalid=self.max_invalid * 10000,
                    num_lines=self.num_lines * 10000,
                    start_string=start_string,
                    parallelism=self._parallelism,
                ),
            ))

        # At this point, we've created our list of generators. Below here
        # is what gets run on every next() call, which tries to construct
        # a full record from all the underlying batches.

        # keep looping as long as our target line count is less than
        # our total line count
        while self.valid_count < self.num_lines:
            # loop over each batch line generator and attempt
            # to construct a full line, we'll only count a
            # full line once we get through each generator
            if self.invalid_count >= self.max_invalid:
                raise RuntimeError(
                    "Invalid record count exceeded during generation")

            seed_cache = None
            if seed_generator:
                # If we're using a seeding generator (from a list of seeds)
                # we cache the next seed we are about to use to generate
                # the next record.
                seed_cache = seed_generator.settings.start_string[0]

            record = {}
            batch: Batch
            for batch, gen in generators:
                while True:
                    line = next(gen)  # type:  GenText
                    if line.valid is False:
                        # Count the invalid line, and bail out entirely if
                        # we've blown our invalid budget mid-record.
                        self.invalid_count += 1
                        if self.invalid_count > self.max_invalid:
                            raise RuntimeError(
                                "Invalid record count exceeded during generation"
                            )
                        continue
                    # Map this batch's headers onto the line's values and
                    # merge them into the record under construction.
                    partial_rec = dict(
                        zip(batch.headers, line.values_as_list()))
                    record.update(partial_rec)
                    break

            # Do a final validation, if configured, on the fully constructed
            # record, if this validation fails, we'll still increment our
            # invalid count.

            valid = True  # assume we have a valid record

            if self.validator is not None:
                try:
                    _valid = self.validator(record)
                    # Only an explicit False marks the record invalid; a
                    # validator returning None is treated as a pass.
                    if _valid is False:
                        valid = False
                except Exception:
                    valid = False

            if not valid:
                self.invalid_count += 1
                # Re-queue the consumed seed so the failed record's seed is
                # retried on the next pass.
                # NOTE(review): this is a truthiness check, so a falsy seed
                # (e.g. "") would never be re-queued — confirm seeds are
                # always non-empty.
                if seed_cache:
                    seed_generator.settings.start_string.insert(0, seed_cache)
                continue  # back to the while start

            self.valid_count += 1
            yield record
예제 #7
0
def test_generate_text(_open, pickle, prepare, predict, spm, tf_config):
    """End-to-end checks of ``generate_text``: with a validator, without
    one, with invalid lines mixed in, and with ``max_invalid`` limits
    using both an exception-raising and a bool-returning validator."""
    tf_config.gen_lines = 10
    predict.side_effect = [[PredString(json.dumps({"foo": i}))] for i in range(0, 10)]
    out = []

    tokenizer = Mock()
    spm.return_value = tokenizer

    # Validated run: every line parses as JSON, so all records are valid.
    for rec in generate_text(tf_config, line_validator=json.loads, parallelism=1):
        out.append(rec.as_dict())

    assert len(out) == 10
    assert out[0] == {
        "valid": True,
        "text": '{"foo": 0}',
        "explain": None,
        "delimiter": ",",
    }

    # now with no validator
    predict.side_effect = [[PredString(json.dumps({"foo": i}))] for i in range(0, 10)]
    out = []
    for rec in generate_text(tf_config, parallelism=1):
        out.append(rec.as_dict())
    assert len(out) == 10
    assert out[0] == {
        "valid": None,  # no validator ran, so validity is unknown
        "text": '{"foo": 0}',
        "explain": None,
        "delimiter": ",",
    }

    # add validator back in, with a few bad json strings
    predict.side_effect = [
        [PredString(json.dumps({"foo": i})) for i in range(0, 3)],
        [PredString("nope"), PredString("foo"), PredString("bar")],
        [PredString(json.dumps({"foo": i})) for i in range(6, 10)],
    ]
    out = []
    try:
        for rec in generate_text(tf_config, line_validator=json.loads, parallelism=1):
            out.append(rec.as_dict())
    except RuntimeError:
        pass
    assert len(out) == 10
    assert not out[4]["valid"]

    # assert max invalid
    predict.side_effect = [
        [PredString(json.dumps({"foo": i})) for i in range(0, 3)],
        [PredString("nope"), PredString("foo"), PredString("bar")],
        [PredString(json.dumps({"foo": i})) for i in range(6, 10)],
    ]
    out = []
    try:
        # max_invalid=2 should abort generation on the third bad line
        for rec in generate_text(tf_config, line_validator=json.loads, max_invalid=2, parallelism=1):
            out.append(rec.as_dict())
    except RuntimeError as err:
        assert "Maximum number" in str(err)
    assert len(out) == 6
    assert not out[4]["valid"]

    # max invalid, validator returns a bool
    def _val(line):
        # Bool-returning variant of json.loads-as-validator: False instead
        # of raising on bad JSON.
        try:
            json.loads(line)
        except Exception:
            return False
        else:
            return True

    predict.side_effect = [
        [PredString(json.dumps({"foo": i})) for i in range(0, 3)],
        [PredString("nope"), PredString("foo"), PredString("bar")],
        [PredString(json.dumps({"foo": i})) for i in range(6, 10)],
    ]
    out = []
    try:
        for rec in generate_text(tf_config, line_validator=_val, max_invalid=2, parallelism=1):
            out.append(rec.as_dict())
    except RuntimeError as err:
        assert "Maximum number" in str(err)
    assert len(out) == 6
    assert not out[4]["valid"]