Example #1
    def _dataset_fn(ctx=None):
        """Returns a tf.data.Dataset for BERT pretraining.

        input_file_pattern, seq_length, max_predictions_per_seq and
        batch_size are captured from the enclosing function's scope.
        """
        del ctx  # This variant ignores the tf.distribute.InputContext.

        # Expand each comma-separated glob pattern into a flat file list.
        input_files = []
        for input_pattern in input_file_pattern.split(','):
            input_files.extend(tf.io.gfile.glob(input_pattern))

        train_dataset = input_pipeline.create_pretrain_dataset(
            input_files, seq_length, max_predictions_per_seq, batch_size)
        return train_dataset
Example #2
import tensorflow as tf
# input_pipeline is assumed to be the BERT input-pipeline module from the
# TensorFlow Model Garden (official/nlp/bert/input_pipeline.py).
from official.nlp.bert import input_pipeline


def get_pretrain_input_data(input_file_pattern, seq_length,
                            max_predictions_per_seq, batch_size):
    """Returns the pretraining dataset built from a comma-separated file pattern."""

    # Expand each comma-separated glob pattern into a flat file list.
    input_files = []
    for input_pattern in input_file_pattern.split(','):
        input_files.extend(tf.io.gfile.glob(input_pattern))

    train_dataset = input_pipeline.create_pretrain_dataset(
        input_files, seq_length, max_predictions_per_seq, batch_size)
    return train_dataset
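For context, a minimal sketch of a call site follows; the file pattern and hyperparameter values are illustrative placeholders, not taken from the examples above.

# Usage sketch: the path and values here are hypothetical.
train_dataset = get_pretrain_input_data(
    input_file_pattern='/tmp/pretrain/*.tfrecord',
    seq_length=128,
    max_predictions_per_seq=20,
    batch_size=32)
# Peek at one batch of pretraining features.
for batch in train_dataset.take(1):
    print(batch)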
Example #3
    def _dataset_fn(ctx=None):
        """Returns tf.data.Dataset for distributed BERT pretraining."""
        # Expand each comma-separated glob pattern into a flat file list.
        input_files = []
        for input_pattern in input_file_pattern.split(','):
            input_files.extend(tf.io.gfile.glob(input_pattern))

        # Forward the tf.distribute.InputContext so each replica reads
        # only its own shard of the input files.
        train_dataset = input_pipeline.create_pretrain_dataset(
            input_files,
            seq_length,
            max_predictions_per_seq,
            batch_size,
            is_training=True,
            input_pipeline_context=ctx)
        return train_dataset
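The ctx parameter only matters once the function is handed to a tf.distribute strategy, which calls it once per worker with a tf.distribute.InputContext. A minimal sketch of that wiring, assuming _dataset_fn is defined inside a closure as in Example #3 and a recent TF 2.x release:

strategy = tf.distribute.MirroredStrategy()
# The strategy invokes _dataset_fn(ctx) once per worker, supplying the
# tf.distribute.InputContext used for sharding inside the pipeline.
dist_dataset = strategy.distribute_datasets_from_function(_dataset_fn)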