def test_generate_batch_lines_raise_on_exceed(test_data):
    """Check how ``generate_batch_lines`` reacts when text generation gives up.

    ``generate_text`` is patched to raise ``TooManyInvalidError``. By default the
    call is expected to come back with a falsy result; with
    ``raise_on_exceed_invalid=True`` the error is expected to reach the caller.
    """
    target = "gretel_synthetics.batch.generate_text"
    batches = DataFrameBatch(df=test_data, config=config_template)
    batches.create_training_data()

    # Default mode: no exception escapes, the result is falsy.
    with patch(target, side_effect=TooManyInvalidError()):
        outcome = batches.generate_batch_lines(0)
        assert not outcome

    # Opt-in mode: the generation error must propagate.
    with patch(target, side_effect=TooManyInvalidError()), pytest.raises(
        TooManyInvalidError
    ):
        assert not batches.generate_batch_lines(0, raise_on_exceed_invalid=True)
    def generate_next(
        self, num_lines: int, hard_limit: Optional[int] = None
    ) -> Iterable[GenText]:
        """
        Returns a sequence of lines.

        Every record pulled from the prediction stream is checked with the configured
        line validator (if there is one) and then yielded, whether it passed or not.
        A record is invalid when the validator returns exactly ``False`` or raises; any
        other return value (including ``None``) counts as valid. Without a validator,
        records are yielded with ``valid=None`` and count towards ``num_lines``.

        Args:
            num_lines: the number of _valid_ lines that should be generated during this call. The actual
                number of lines returned may be higher, in case of invalid lines in the generation output.
            hard_limit: if set, imposes a hard limit on the overall number of lines that are generated during
                this call, regardless of whether the requested number of valid lines was hit.

        Yields:
            A ``gen_text`` object for every line (valid or invalid) that is generated.

        Raises:
            TooManyInvalidError: once ``self.total_invalid`` exceeds ``self.settings.max_invalid``.
                The check runs after the offending line has been yielded.
        """
        valid_lines_generated = 0
        total_lines_generated = 0

        while valid_lines_generated < num_lines and (
            hard_limit is None or total_lines_generated < hard_limit
        ):
            rec = next(self._predictions).data
            total_lines_generated += 1

            validator = self.settings.line_validator
            valid = None
            explain = None
            if validator:
                # Only the validator call is guarded: an exception it raises marks the
                # record invalid. Keeping the ``yield`` out of the ``try`` ensures that
                # errors injected at the yield point are never mistaken for validation
                # failures (which would double-emit the record and inflate the counter).
                try:
                    valid = validator(rec) is not False
                except Exception as err:
                    valid = False
                    explain = str(err)

            if valid is False:
                self.total_invalid += 1

            yield GenText(
                text=rec, valid=valid, explain=explain, delimiter=self.delim
            )

            if valid is not False:
                valid_lines_generated += 1
                if self.settings.multi_seed:
                    # Drop the seed that was just used for a valid line.
                    self.settings.start_string.pop(0)

            if self.total_invalid > self.settings.max_invalid:
                raise TooManyInvalidError("Maximum number of invalid lines reached!")
def generate_parallel(settings: Settings,
                      num_lines: int,
                      num_workers: int,
                      chunk_size: int = 5):
    """
    Runs text generation in parallel mode.

    Text generation is performed with the given settings, using a given number of parallel workers
    and a total body of work that is split into chunks of at most ``chunk_size`` lines.

    Because this is a generator, nothing (including argument validation and worker pool creation)
    happens until the first item is requested.

    Args:
        settings: the settings for text generation.
        num_lines: the number of valid lines to be generated.
        num_workers: the number of parallel workers.
        chunk_size: the maximum number of lines to be assigned to a worker at once. Must be at least 1.

    Yields:
        ``gen_text`` objects.

    Raises:
        ValueError: if ``chunk_size`` is smaller than 1.
        TooManyInvalidError: if the number of invalid lines seen across all workers exceeds
            ``settings.max_invalid``.
    """

    # A non-positive chunk size would hand out tasks that can never produce a valid line, so
    # ``remaining_lines`` would never shrink and the scheduling loop below would spin forever.
    # Reject it before any worker process is created.
    if chunk_size < 1:
        raise ValueError(
            f"chunk_size must be a positive integer, got {chunk_size!r}")

    # Create a pool of workers that will instantiate a generator upon initialization.
    worker_pool = loky.ProcessPoolExecutor(max_workers=num_workers,
                                           initializer=_loky_init_worker,
                                           initargs=(settings, ))

    # How many valid lines we still need to generate
    remaining_lines = num_lines

    # This set tracks the currently outstanding invocations to _loky_worker_process_chunk.
    pending_tasks: Set[futures.Future[Tuple[int, List[GenText], int]]] = set()  # pylint: disable=unsubscriptable-object  # noqa

    # How many tasks can be pending at once. While a lower factor saves memory, it increases the
    # risk that workers sit idle because the main process is blocked on processing data and
    # therefore cannot hand out new tasks.
    max_pending_tasks = 10 * num_workers

    # How many lines to be generated have been assigned to currently active workers. This tracks
    # the nominal/target lines, and the returned number of lines may be different if workers generate
    # a lot of invalid lines.
    assigned_lines = 0

    # The _total_ number of invalid lines we have seen so far. This is used to implement a global
    # limit on the number of invalid lines, since each worker only knows the number of invalid lines
    # it has generated itself.
    total_invalid = 0

    # hard_limit is the number of lines (valid or invalid) after which a task should give up and return
    # the intermediate result, even if it has not managed to generate the requested number of lines yet.
    # This ensures that we get frequent status updates, even if a lot of invalid lines are generated.
    # We currently hardcode this to 110% of a chunk size (truncated to an integer).
    hard_limit = int(chunk_size * 1.10)

    try:
        while remaining_lines > 0:
            # If we have capacity to add new pending tasks, do so until we are at capacity or there are
            # no more lines that can be assigned to workers.
            while len(
                    pending_tasks
            ) < max_pending_tasks and assigned_lines < remaining_lines:
                next_chunk = min(chunk_size, remaining_lines - assigned_lines)
                pending_tasks.add(
                    worker_pool.submit(_loky_worker_process_chunk, next_chunk,
                                       hard_limit))
                assigned_lines += next_chunk

            # Wait for at least one worker to complete its current task (or fail with an exception).
            completed_tasks, pending_tasks = futures.wait(
                pending_tasks, return_when=futures.FIRST_COMPLETED)

            for task in completed_tasks:
                # The task is already done, so a zero timeout never blocks; a failure inside the
                # worker is re-raised here.
                requested_chunk_size, lines, num_invalid = task.result(
                    timeout=0)

                # Release the nominal chunk size, but only credit the valid lines actually produced;
                # any shortfall is handed out again on the next pass of the outer loop.
                assigned_lines -= requested_chunk_size
                remaining_lines -= len(
                    lines) - num_invalid  # Calculate number of _valid_ lines

                # Emit lines in the output
                for line in lines:
                    if line.valid is not None and not line.valid:
                        total_invalid += 1
                    if total_invalid > settings.max_invalid:
                        raise TooManyInvalidError(
                            "Maximum number of invalid lines reached!")
                    yield line

    finally:
        # Always make sure to shut down the worker pool (no need to wait for workers to terminate).
        worker_pool.shutdown(wait=False, kill_workers=True)