Exemplo n.º 1
0
def test_train_batch_sp_tok(train_df, tmp_path):
    """Train with a SentencePiece tokenizer and check both generation paths.

    After training on ``train_df``, lines are generated once through the
    batcher (``generate_all_batch_lines`` + ``batches_to_df``) and once
    through a record factory. Each path must yield exactly ``_tok_gen_count``
    rows. The factory output must also keep the training frame's columns and
    report every generated record as valid.
    """
    expected_rows = _tok_gen_count
    invalid_cap = 5000

    config = TensorFlowConfig(
        epochs=5,
        field_delimiter=",",
        checkpoint_dir=tmp_path,
        input_data_path=PATH_HOLDER,
        learning_rate=.01,
    )
    batcher = DataFrameBatch(
        df=train_df,
        config=config,
        tokenizer=SentencePieceTokenizerTrainer(vocab_size=10000, config=config),
    )
    batcher.create_training_data()
    batcher.train_all_batches()

    # Path 1: generate straight through the batcher API.
    batcher.generate_all_batch_lines(num_lines=expected_rows, max_invalid=invalid_cap)
    batch_df = batcher.batches_to_df()
    assert batch_df.shape[0] == expected_rows

    # Path 2: generate through a RecordFactory built from the same batcher.
    factory = batcher.create_record_factory(num_lines=expected_rows, max_invalid=invalid_cap)
    factory_df = factory.generate_all(output="df")
    assert factory_df.shape[0] == expected_rows
    assert list(factory_df.columns) == list(train_df.columns)
    assert factory.summary["valid_count"] == expected_rows
Exemplo n.º 2
0
def test_epoch_callback(train_df, tmp_path):
    """Verify that ``epoch_callback`` is invoked with per-epoch training state.

    The callback appends one ``epoch,accuracy,loss,batch`` row to a file under
    ``tmp_path`` on every call. After training, exactly 20 rows are expected
    (presumably 4 batches x 5 epochs -- TODO confirm against DataFrameBatch).
    Row ``i`` must report epoch ``i % 5`` and batch ``i // 5``. Accuracy and
    loss only have to parse as floats.
    """
    dump_path = tmp_path / 'callback_dump.txt'
    n_epochs = 5

    def epoch_callback(s: EpochState):
        # Append, not overwrite, so every call's row accumulates in one file.
        with open(dump_path, 'a') as sink:
            sink.write(f'{s.epoch},{s.accuracy},{s.loss},{s.batch}\n')

    config = TensorFlowConfig(
        epochs=n_epochs,
        field_delimiter=",",
        checkpoint_dir=tmp_path,
        input_data_path=PATH_HOLDER,
        learning_rate=.01,
        epoch_callback=epoch_callback,
    )
    tokenizer = SentencePieceTokenizerTrainer(vocab_size=10000, config=config)
    batcher = DataFrameBatch(
        batch_size=4,
        df=train_df,
        config=config,
        tokenizer=tokenizer,
    )
    batcher.create_training_data()
    batcher.train_all_batches()

    with open(dump_path, 'r') as dump:
        recorded = dump.readlines()

    assert len(recorded) == 20
    for row_idx, raw in enumerate(recorded):
        fields = raw.strip().split(',')
        assert len(fields) == 4
        batch_idx, epoch_idx = divmod(row_idx, n_epochs)
        assert int(fields[0]) == epoch_idx
        assert int(fields[3]) == batch_idx
        # Accuracy and loss just need to be numeric; float() raises otherwise.
        float(fields[1])
        float(fields[2])
    os.remove(dump_path)
Exemplo n.º 3
0
def _create_default_tokenizer(store: BaseConfig) -> SentencePieceTokenizerTrainer:
    """Build a SentencePiece tokenizer trainer configured from ``store``.

    Vocab size, character coverage, pretrain sentence count and max line
    length are copied from ``store``. ``store`` itself is also handed to the
    trainer as its ``config``.
    """
    # Attribute reads happen in the same order as the constructor arguments.
    sp_settings = {
        "vocab_size": store.vocab_size,
        "character_coverage": store.character_coverage,
        "pretrain_sentence_count": store.pretrain_sentence_count,
        "max_line_len": store.max_line_len,
    }
    return SentencePieceTokenizerTrainer(**sp_settings, config=store)
Exemplo n.º 4
0
def test_train_batch_sp_tok(train_df, tmp_path):
    """Train with a SentencePiece tokenizer, then generate a fixed number of lines.

    The data frame built from the generated batches must contain exactly
    ``line_count`` rows.
    """
    line_count = 100

    config = TensorFlowConfig(
        epochs=5,
        field_delimiter=",",
        checkpoint_dir=tmp_path,
        input_data_path=PATH_HOLDER,
        learning_rate=.01,
    )
    sp_tokenizer = SentencePieceTokenizerTrainer(vocab_size=10000, config=config)
    batcher = DataFrameBatch(df=train_df, config=config, tokenizer=sp_tokenizer)

    batcher.create_training_data()
    batcher.train_all_batches()

    batcher.generate_all_batch_lines(num_lines=line_count, max_invalid=5000)
    generated = batcher.batches_to_df()
    assert generated.shape[0] == line_count
Exemplo n.º 5
0
def _create_default_tokenizer(store: BaseConfig) -> BaseTokenizerTrainer:
    """
    Choose a tokenizer trainer based on ``store.vocab_size``.

    A vocab size of exactly 0 selects a CharTokenizerTrainer. Any other value
    selects a SentencePieceTokenizerTrainer configured from the vocab size,
    character coverage, pretrain sentence count and max line length on
    ``store``. The choice is logged at INFO level.
    """
    if store.vocab_size == 0:
        logging.info("Loading CharTokenizerTrainer")
        return CharTokenizerTrainer(config=store)

    logging.info("Loading SentencePieceTokenizerTrainer")
    return SentencePieceTokenizerTrainer(
        vocab_size=store.vocab_size,
        character_coverage=store.character_coverage,
        pretrain_sentence_count=store.pretrain_sentence_count,
        max_line_len=store.max_line_len,
        config=store,
    )