Example #1
def train(store: BaseConfig,
          tokenizer_trainer: Optional[BaseTokenizerTrainer] = None):
    """Train a Synthetic Model.  This is a facade entrypoint that implements the engine
    specific training operation based on the provided configuration.

    Args:
        store: A subclass instance of ``BaseConfig``. This config is responsible for
            providing the actual training entrypoint for a specific training routine.

        tokenizer_trainer: An optional subclass instance of ``BaseTokenizerTrainer``. If provided,
            this tokenizer trainer will be used to pre-process and create an annotated dataset for training.
            If not provided, a default tokenizer will be used.
    """
    if tokenizer_trainer is None:
        tokenizer_trainer = _create_default_tokenizer(store)
    tokenizer_trainer.create_annotated_training_data()
    tokenizer_trainer.train()
    tokenizer = tokenizer_from_model_dir(store.checkpoint_dir)
    params = TrainingParams(tokenizer_trainer=tokenizer_trainer,
                            tokenizer=tokenizer,
                            config=store)
    train_fn = store.get_training_callable()
    store.save_model_params()
    store.gpu_check()
    train_fn(params)
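
A hedged usage sketch for this entrypoint. The module paths follow gretel-synthetics' layout, but treat them, the ``TensorFlowConfig`` parameters, and the file paths as assumptions to adapt to your own setup:

# Usage sketch -- module paths, config class, and file paths are
# assumptions, not taken from the example above; adjust to your install.
from gretel_synthetics.config import TensorFlowConfig
from gretel_synthetics.tokenizers import SentencePieceTokenizerTrainer
from gretel_synthetics.train import train

config = TensorFlowConfig(
    input_data_path="training_data.csv",  # hypothetical training file
    checkpoint_dir="checkpoints",
    field_delimiter=",",
)

# Either let train() create the default tokenizer...
train(config)

# ...or supply a tokenizer trainer explicitly.
train(config, tokenizer_trainer=SentencePieceTokenizerTrainer(config=config))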
Example #2
def test_sp_field_delim(input_data_path, tmpdir):
    config = SimpleConfig(input_data_path=input_data_path,
                          checkpoint_dir=tmpdir,
                          field_delimiter=",")
    trainer = tok.SentencePieceTokenizerTrainer(config=config)
    line_iter = trainer.create_annotated_training_data()

    line_one = next(line_iter)
    assert line_one == "Once upon a midnight dreary<d> while I pondered<d> weak and weary<d><n>\n"

    trainer.train()
    tokenizer = tok.SentencePieceTokenizer.load(tmpdir)

    ids = [
        40, 53, 7, 5, 10, 35, 9, 13, 15, 12, 16, 15, 21, 19, 14, 5, 12, 24, 30,
        6, 4, 51, 41, 8, 7, 5, 23, 5, 35, 12, 47, 12, 4, 48, 61, 9, 27, 48, 24,
        6, 4, 3
    ]
    assert tokenizer.encode_to_ids(
        "Once upon a midnight dreary<d> while I pondered<d> weak and weary<d><n>\n"
    ) == ids
    assert tokenizer.decode_from_ids(
        ids
    ) == "Once upon a midnight dreary, while I pondered, weak and weary,<n>"

    # Check the factory
    assert isinstance(tok.tokenizer_from_model_dir(tmpdir),
                      tok.SentencePieceTokenizer)
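
The expected strings in this test come from the trainer's annotation step: each field delimiter is rewritten as the ``<d>`` token and every line is terminated with ``<n>``. A minimal self-contained sketch of that transformation (an illustration only, not the library's actual code):

# Sketch of the annotation rule the assertions above rely on; this is
# an illustration, not gretel-synthetics' implementation.
def annotate(line: str, field_delimiter: str = ",") -> str:
    return line.rstrip("\n").replace(field_delimiter, "<d>") + "<n>\n"

assert annotate(
    "Once upon a midnight dreary, while I pondered, weak and weary,"
) == "Once upon a midnight dreary<d> while I pondered<d> weak and weary<d><n>\n"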
Example #3
    def __init__(
        self,
        config: BaseConfig,
        *,
        seed_list: List[str],
        line_validator: Optional[Callable] = None,
        max_invalid: int = 1000,
    ):
        generator_class = config.get_generator_class()
        tokenizer = tokenizer_from_model_dir(config.checkpoint_dir)

        self.settings = Settings(
            config=config,
            start_string=seed_list,
            line_validator=line_validator,
            max_invalid=max_invalid,
            tokenizer=tokenizer,
        )

        self._generator = generator_class(self.settings).generate_next(None)
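
Passing ``None`` to ``generate_next`` here (where Example #5 passes a line count) suggests an unbounded generator that the caller drains lazily, one record at a time. A self-contained sketch of that pattern with illustrative names (the class that owns this ``__init__`` is not shown above):

# Illustrative pattern only: build an unbounded generator once in
# __init__, then pull records on demand with next().
class LazyRecordSource:
    def __init__(self):
        self._generator = self._make_records()

    @staticmethod
    def _make_records():
        n = 0
        while True:  # unbounded, like generate_next(None)
            yield f"record-{n}"
            n += 1

source = LazyRecordSource()
assert next(source._generator) == "record-0"
assert next(source._generator) == "record-1"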
Example #4
def test_single_char(input_data_path, tmpdir):
    # NOTE: Here the line delim should not matter for this char tokenizer
    config = SimpleConfig(input_data_path=input_data_path,
                          checkpoint_dir=tmpdir,
                          field_delimiter=",")
    trainer = tok.CharTokenizerTrainer(config=config)

    # We need this for batch mode, so verify it can be copied
    deepcopy(trainer)

    line_iter = trainer.create_annotated_training_data()

    # Assert that we didn't do any annotation
    line_one = next(line_iter)
    assert line_one == L1

    # Let's train the tokenizer, and now reload it back in
    trainer.train()
    tokenizer = tok.CharTokenizer.load(tmpdir)
    assert tokenizer.total_vocab_size == 32

    # NOTE: this is because we default to using this token as a delim
    # in the main config, but this tokenizer doesn't do anything with it anyway
    assert tokenizer.field_delimiter == ","
    assert tokenizer.field_delimiter_token == "<d>"

    l1_ids = [
        6, 21, 11, 13, 1, 28, 23, 22, 21, 1, 9, 1, 20, 17, 12, 21, 17, 15, 16,
        27, 1, 12, 25, 13, 9, 25, 31, 2, 1, 30, 16, 17, 19, 13, 1, 5, 1, 23,
        22, 21, 12, 13, 25, 13, 12, 2, 1, 30, 13, 9, 18, 1, 9, 21, 12, 1, 30,
        13, 9, 25, 31, 2, 0
    ]
    assert tokenizer.encode_to_ids(L1) == l1_ids
    assert tokenizer.decode_from_ids(l1_ids) == L1

    # Check the factory
    assert isinstance(tok.tokenizer_from_model_dir(tmpdir), tok.CharTokenizer)
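
For intuition, a character tokenizer's vocabulary is simply the set of distinct symbols observed in the training data (plus any special tokens), which is where a small ``total_vocab_size`` such as 32 comes from. A minimal self-contained sketch of the idea, not the library's ``CharTokenizer``:

# Toy character tokenizer -- illustrates the concept only.
from typing import List

class TinyCharTokenizer:
    def __init__(self, corpus: str):
        # The vocabulary is the sorted set of characters in the corpus.
        self.vocab = sorted(set(corpus))
        self._char2id = {c: i for i, c in enumerate(self.vocab)}

    @property
    def total_vocab_size(self) -> int:
        return len(self.vocab)

    def encode_to_ids(self, text: str) -> List[int]:
        return [self._char2id[c] for c in text]

    def decode_from_ids(self, ids: List[int]) -> str:
        return "".join(self.vocab[i] for i in ids)

toy = TinyCharTokenizer("abcabc\n")
assert toy.decode_from_ids(toy.encode_to_ids("cab")) == "cab"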
Example #5
def generate_text(config: BaseConfig,
                  start_string: Optional[str] = None,
                  line_validator: Optional[Callable] = None,
                  max_invalid: int = 1000,
                  num_lines: Optional[int] = None,
                  parallelism: int = 0) -> Iterator[GenText]:
    """A generator that will load a model and start creating records.

    Args:
        config: A configuration object, which you must have created previously
        start_string:  A prefix string that is used to seed the record generation.
            By default we use a newline, but you may substitute any initial value here
            which will influence how the generator predicts what to generate. If you
            are working with a field delimiter, and you want to seed more than one column
            value, then you MUST utilize the field delimiter specified in your config.
            An example would be "foo,bar,baz,". Also, if using a field delimiter, the string
            MUST end with the delimiter value.
        line_validator: An optional callback validator function that will take
            the raw string value from the generator as a single argument. This validator
            can execute arbitrary code with the raw string value. The validator function
            may return a bool to indicate line validity. This boolean value will be set
            on the yielded ``gen_text`` object. Additionally, if the validator throws an
            exception, the ``gen_text`` object will be set with a failed validation. If the
            validator returns None, we will assume successful validation.
        max_invalid: If using a ``line_validator``, this is the maximum number of invalid
            lines to generate. If the number of invalid lines exceeds this value, a ``RuntimeError``
            will be raised.
        num_lines: If not ``None``, this will override the ``gen_lines`` value that is provided in the ``config``.
        parallelism: The number of concurrent workers to use. ``1`` disables parallelization,
            while a non-positive value means "number of CPUs + x"; the default of ``0`` therefore uses
            as many workers as there are CPUs. A floating-point value is interpreted as a fraction of
            the available CPUs, rounded down.

    Simple validator example::

        def my_validator(raw_line: str):
            parts = raw_line.split(',')
            if len(parts) != 5:
                raise Exception('record does not have 5 fields')

    NOTE:
        ``gen_lines`` from the ``config`` is important for this function. If a line validator is not provided,
        each line will count towards the number of total generated lines. When the total lines generated is >=
        ``gen_lines`` we stop. If a line validator is provided, only *valid* lines will count towards
        the total number of lines generated. When the total number of valid lines generated is >= ``gen_lines``,
        we stop.

    NOTE:
        ``gen_chars`` controls the maximum number of characters a single
        generated line can have. If a newline character has not been generated before reaching
        this number, then the line will be returned. For example if ``gen_chars`` is 180 and a
        newline has not been generated, once 180 chars have been created, the line will be returned
        no matter what. If this value is 0, generation of each line continues until
        a newline is observed.

    Yields:
        A ``GenText`` object for each record that is generated. The generator
        will stop after the max number of lines is reached (based on your config).

    Raises:
        A ``RuntimeError`` if the number of invalid lines generated exceeds ``max_invalid``

    """

    generator_class = config.get_generator_class()
    tokenizer = tokenizer_from_model_dir(config.checkpoint_dir)

    settings = Settings(config=config,
                        start_string=start_string,
                        line_validator=line_validator,
                        max_invalid=max_invalid,
                        generator=generator_class,
                        tokenizer=tokenizer)

    if num_lines is not None:
        _line_count = num_lines
    else:
        _line_count = config.gen_lines

    num_workers = get_num_workers(parallelism, _line_count, chunk_size=5)
    if num_workers == 1:
        gen = generator_class(settings)
        yield from gen.generate_next(_line_count)
    else:
        yield from generate_parallel(settings,
                                     _line_count,
                                     num_workers,
                                     chunk_size=5)
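
A hedged end-to-end usage sketch. The module paths and ``TensorFlowConfig`` follow gretel-synthetics' layout, and ``line.valid`` / ``line.text`` follow this library's ``GenText`` fields, but treat all of them as assumptions to verify against your version:

# Usage sketch -- imports, config class, and paths are assumptions to
# adapt; the validator mirrors the docstring example above.
from gretel_synthetics.config import TensorFlowConfig
from gretel_synthetics.generate import generate_text

config = TensorFlowConfig(
    input_data_path="training_data.csv",  # hypothetical path
    checkpoint_dir="checkpoints",         # a previously trained model
    field_delimiter=",",
)

def my_validator(raw_line: str):
    # Raising an exception (or returning False) marks the line invalid.
    if len(raw_line.split(",")) != 5:
        raise Exception("record does not have 5 fields")

for line in generate_text(config, line_validator=my_validator, num_lines=100):
    if line.valid:
        print(line.text)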