Exemplo n.º 1
0
def get_dynamic_datasets(
        dataset_name='wmt17_translate/de-en',
        eval_dataset_name=None,
        reverse_translation=True,
        shard_idx=0,
        shard_count=1,
        data_dir=None,
        vocab_path=None,
        target_vocab_size=2**15,  # 32768
        max_corpus_chars=10**7,
        batch_size=256,
        max_length=256,
        max_eval_length=256,
        paracrawl_size=0,
        is_scores_path=None,
        num_buckets=100):
    """Build a dynamic training-data manager plus eval and predict batches.

    The raw WMT splits are loaded with unshuffled training files, a tokenizer
    is loaded from (or trained on) the training split, both splits are mapped
    through `tokenizer.TokenizeOp`, and the training split is handed to
    `build_dynamic_data` instead of the static `preprocess_wmt_data` pipeline.

    Args:
      dataset_name: Training dataset name, forwarded to `raw_wmt_datasets`.
      eval_dataset_name: Optional separate eval dataset name, forwarded to
        `raw_wmt_datasets`.
      reverse_translation: Forwarded to `raw_wmt_datasets`.
      shard_idx: Shard index, forwarded to `raw_wmt_datasets`.
      shard_count: Number of shards, forwarded to `raw_wmt_datasets`.
      data_dir: Optional data directory, forwarded to `raw_wmt_datasets`.
      vocab_path: Tokenizer model path. When None, defaults to the
        user-expanded `~/wmt_sentencepiece_model`.
      target_vocab_size: `vocab_size` passed to the tokenizer loader/trainer.
      max_corpus_chars: Character budget passed to the tokenizer
        loader/trainer.
      batch_size: Batch size used for training, eval and predict data.
      max_length: Maximum length passed to `build_dynamic_data`.
      max_eval_length: Maximum length for the eval and predict batches.
      paracrawl_size: Forwarded to `raw_wmt_datasets`.
      is_scores_path: Optional path forwarded to `build_dynamic_data`.
      num_buckets: Bucket count forwarded to `build_dynamic_data`.

    Returns:
      A tuple `(train_data_manager, eval_batches, predict_batches,
      sp_tokenizer)`. `predict_batches` is built exactly like `eval_batches`
      except with `drop_remainder=False`.
    """
    vocab_path = (os.path.expanduser('~/wmt_sentencepiece_model')
                  if vocab_path is None else vocab_path)

    raw_train, raw_eval, _unused = raw_wmt_datasets(
        dataset_name=dataset_name,
        eval_dataset_name=eval_dataset_name,
        reverse_translation=reverse_translation,
        shard_idx=shard_idx,
        shard_count=shard_count,
        data_dir=data_dir,
        paracrawl_size=paracrawl_size,
        shuffle_train_files=False)

    sp_tokenizer = tokenizer.load_or_train_tokenizer(
        raw_train,
        vocab_path=vocab_path,
        vocab_size=target_vocab_size,
        max_corpus_chars=max_corpus_chars)

    def _tokenized(split):
        # Each split gets its own TokenizeOp instance.
        return split.map(tokenizer.TokenizeOp(sp_tokenizer),
                         num_parallel_calls=AUTOTUNE)

    train_split = _tokenized(raw_train)
    eval_split = _tokenized(raw_eval)

    train_data_manager = build_dynamic_data(
        train_split,
        batch_size=batch_size,
        max_length=max_length,
        is_scores_path=is_scores_path,
        num_buckets=num_buckets)

    # Eval and predict pipelines share every setting except drop_remainder.
    eval_kwargs = dict(shuffle=False,
                       pack_examples=False,
                       batch_size=batch_size,
                       max_length=max_eval_length)
    eval_batches = preprocess_wmt_data(eval_split, **eval_kwargs)
    predict_batches = preprocess_wmt_data(eval_split,
                                          drop_remainder=False,
                                          **eval_kwargs)

    return train_data_manager, eval_batches, predict_batches, sp_tokenizer
Exemplo n.º 2
0
def get_wmt_is_datasets(
        n_devices,
        dataset_name='wmt17_translate/de-en',
        reverse_translation=True,
        shard_idx=0,
        shard_count=1,
        data_dir=None,
        vocab_path=None,
        target_vocab_size=2**15,  # 32768
        max_corpus_chars=10**7,
        batch_size=256,
        max_length=256,
        paracrawl_size=0):
    """Load the WMT training split as one ordered, unpacked pass of batches.

    The training files are read with `shuffle_train_files=False` and batched
    with `shuffle=False`, `num_epochs=1`, `pack_examples=False` and
    `drop_remainder=False`. The resulting batches therefore follow the input
    order and include the final partial batch. No eval split is requested
    (`eval_dataset_name=None`), and only the training batches are returned.

    Args:
      n_devices: Number of devices. Must be positive, and `batch_size` must
        be a multiple of it.
      dataset_name: Training dataset name, forwarded to `raw_wmt_datasets`.
      reverse_translation: Forwarded to `raw_wmt_datasets`.
      shard_idx: Shard index, forwarded to `raw_wmt_datasets`.
      shard_count: Number of shards, forwarded to `raw_wmt_datasets`.
      data_dir: Optional data directory, forwarded to `raw_wmt_datasets`.
      vocab_path: Tokenizer model path. When None, defaults to the
        user-expanded `~/wmt_sentencepiece_model`.
      target_vocab_size: `vocab_size` passed to the tokenizer loader/trainer.
      max_corpus_chars: Character budget passed to the tokenizer
        loader/trainer.
      batch_size: Batch size of the returned training batches.
      max_length: Maximum length passed to `preprocess_wmt_data`.
      paracrawl_size: Forwarded to `raw_wmt_datasets`.

    Returns:
      A tuple `(train_batches, sp_tokenizer)`.

    Raises:
      ValueError: If `n_devices` is not positive, or if `batch_size` is not a
        multiple of `n_devices`.
    """
    # Validate first: n_devices == 0 would otherwise raise ZeroDivisionError
    # below, and a negative value would silently pass the modulo check.
    if n_devices <= 0:
        raise ValueError('n_devices must be a positive integer, got %r' %
                         (n_devices,))
    if batch_size % n_devices:
        raise ValueError("Batch size %d isn't divided evenly by n_devices %d" %
                         (batch_size, n_devices))
    if vocab_path is None:
        vocab_path = os.path.expanduser('~/wmt_sentencepiece_model')

    train_data, _, _ = raw_wmt_datasets(
        dataset_name=dataset_name,
        eval_dataset_name=None,
        reverse_translation=reverse_translation,
        shard_idx=shard_idx,
        shard_count=shard_count,
        data_dir=data_dir,
        paracrawl_size=paracrawl_size,
        shuffle_train_files=False)

    # Load an existing tokenizer model or train one on the training split.
    sp_tokenizer = tokenizer.load_or_train_tokenizer(
        train_data,
        vocab_path=vocab_path,
        vocab_size=target_vocab_size,
        max_corpus_chars=max_corpus_chars)

    # Encode strings with the loaded/trained tokenizer.
    train_data = train_data.map(tokenizer.TokenizeOp(sp_tokenizer),
                                num_parallel_calls=AUTOTUNE)

    train_batches = preprocess_wmt_data(train_data,
                                        shuffle=False,
                                        num_epochs=1,
                                        pack_examples=False,
                                        batch_size=batch_size,
                                        max_length=max_length,
                                        drop_remainder=False)
    # NOTE: drop_remainder=False keeps the final partial batch rather than
    # truncating the data. That batch may hold fewer than `batch_size`
    # examples and need not divide evenly by `n_devices`, so callers that
    # split batches across devices must handle it.
    # NOTE(review): confirm how callers pad or handle that last batch.

    return train_batches, sp_tokenizer
Exemplo n.º 3
0
def get_wmt_datasets(
        dataset_name='wmt17_translate/de-en',
        eval_dataset_name=None,
        reverse_translation=True,
        shard_idx=0,
        shard_count=1,
        data_dir=None,
        vocab_path=None,
        target_vocab_size=2**15,  # 32768
        max_corpus_chars=10**7,
        batch_size=256,
        pack_examples=True,
        max_length=256,
        max_eval_length=256,
        paracrawl_size=0,
        is_scores_path=None,
        num_to_keep=-1,
        pseudo_path=None,
        shuffle_repeat_train=True,
        repeat_count=-1,
        newscommentary_size=None):
    """Load WMT data and return batched train, eval and predict datasets.

    Args:
      dataset_name: Training dataset name, forwarded to `raw_wmt_datasets`.
        If it contains the substring 'pseudo', the training split is NOT
        mapped through `tokenizer.TokenizeOp`.
      eval_dataset_name: Optional separate eval dataset name, forwarded to
        `raw_wmt_datasets`.
      reverse_translation: Forwarded to `raw_wmt_datasets`.
      shard_idx: Shard index, forwarded to `raw_wmt_datasets`.
      shard_count: Number of shards, forwarded to `raw_wmt_datasets`.
      data_dir: Optional data directory, forwarded to `raw_wmt_datasets`.
      vocab_path: Tokenizer model path. When None, defaults to the
        user-expanded `~/wmt_sentencepiece_model`.
      target_vocab_size: `vocab_size` passed to the tokenizer loader/trainer.
      max_corpus_chars: Character budget passed to the tokenizer
        loader/trainer.
      batch_size: Batch size for training, eval and predict data.
      pack_examples: Whether training examples are packed (eval/predict
        never are).
      max_length: Maximum length for training data.
      max_eval_length: Maximum length for eval and predict data.
      paracrawl_size: Forwarded to `raw_wmt_datasets`.
      is_scores_path: Optional path forwarded to `preprocess_wmt_data` for
        the training split. When set, training input files are not shuffled.
      num_to_keep: Forwarded to `preprocess_wmt_data` for the training split.
      pseudo_path: Forwarded to `raw_wmt_datasets`.
      shuffle_repeat_train: Whether to shuffle training data. Also enables
        shuffling of training input files when `is_scores_path` is None.
      repeat_count: Passed as `num_epochs` for the training split
        (NOTE(review): -1 presumably means repeat indefinitely; confirm in
        `preprocess_wmt_data`).
      newscommentary_size: Forwarded to `raw_wmt_datasets`.

    Returns:
      A tuple `(train_ds, eval_ds, predict_ds, sp_tokenizer)`. `predict_ds`
      is built exactly like `eval_ds` except with `drop_remainder=False`.
    """
    vocab_path = (os.path.expanduser('~/wmt_sentencepiece_model')
                  if vocab_path is None else vocab_path)

    # With no is_scores_path there is no data selection, so the training
    # input files may be shuffled. With a scores path they must keep their
    # order, so file shuffling is disabled.
    shuffle_files = shuffle_repeat_train if is_scores_path is None else False

    raw_train, raw_eval, _unused = raw_wmt_datasets(
        dataset_name=dataset_name,
        eval_dataset_name=eval_dataset_name,
        reverse_translation=reverse_translation,
        shard_idx=shard_idx,
        shard_count=shard_count,
        data_dir=data_dir,
        paracrawl_size=paracrawl_size,
        shuffle_train_files=shuffle_files,
        pseudo_path=pseudo_path,
        newscommentary_size=newscommentary_size)

    sp_tokenizer = tokenizer.load_or_train_tokenizer(
        raw_train,
        vocab_path=vocab_path,
        vocab_size=target_vocab_size,
        max_corpus_chars=max_corpus_chars)

    def _tokenized(split):
        # Each split gets its own TokenizeOp instance.
        return split.map(tokenizer.TokenizeOp(sp_tokenizer),
                         num_parallel_calls=AUTOTUNE)

    # Pseudo-references are currently stored pre-tokenized in pickle files,
    # so they are not tokenized again here. Writing them to a tfrecord
    # instead is planned.
    train_split = (raw_train if 'pseudo' in dataset_name
                   else _tokenized(raw_train))
    eval_split = _tokenized(raw_eval)

    train_ds = preprocess_wmt_data(
        train_split,
        shuffle=shuffle_repeat_train,
        num_epochs=repeat_count,
        pack_examples=pack_examples,
        batch_size=batch_size,
        max_length=max_length,
        is_scores_path=is_scores_path,
        num_to_keep=num_to_keep)

    # Eval and predict pipelines share every setting except drop_remainder.
    eval_kwargs = dict(shuffle=False,
                       pack_examples=False,
                       batch_size=batch_size,
                       max_length=max_eval_length)
    eval_ds = preprocess_wmt_data(eval_split, **eval_kwargs)
    predict_ds = preprocess_wmt_data(eval_split,
                                     drop_remainder=False,
                                     **eval_kwargs)

    return train_ds, eval_ds, predict_ds, sp_tokenizer