Example #1
 def _dataset_fn(ctx=None):
     """Returns tf.data.Dataset for distributed BERT pretraining."""
     input_patterns = input_file_pattern.split(',')
     train_dataset = input_pipeline.create_pretrain_dataset(
         input_patterns,
         seq_length,
         max_predictions_per_seq,
         batch_size,
         is_training=True,
         input_pipeline_context=ctx)
     return train_dataset
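A dataset_fn like this is typically handed to a tf.distribute strategy, which calls it once per input pipeline and passes a tf.distribute.InputContext as ctx. A minimal sketch of that wiring, assuming TensorFlow 2.x (the strategy choice is illustrative, not from the original source):
 import tensorflow as tf

 strategy = tf.distribute.MirroredStrategy()
 # The strategy invokes _dataset_fn(ctx) for each input pipeline; on older
 # TF releases this method is named
 # experimental_distribute_datasets_from_function.
 train_dataset = strategy.distribute_datasets_from_function(_dataset_fn)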
Example #2
 def _dataset_fn(ctx=None):
     """Returns tf.data.Dataset for distributed BERT pretraining."""
     input_patterns = input_file_pattern.split(',')
     batch_size = ctx.get_per_replica_batch_size(global_batch_size)
     train_dataset = input_pipeline.create_pretrain_dataset(
         input_patterns,
         seq_length,
         max_predictions_per_seq,
         batch_size,
         is_training=True,
         input_pipeline_context=ctx,
         use_next_sentence_label=use_next_sentence_label)
     return train_dataset
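Example #2 differs from #1 in deriving a per-replica batch size from the global one. tf.distribute.InputContext.get_per_replica_batch_size divides the global batch size evenly across replicas (and raises if it does not divide evenly); an illustrative check with a hand-built context and an arbitrary replica count:
 import tensorflow as tf

 # InputContext is normally constructed by the strategy; built by hand here
 # only to show the arithmetic (8 replicas is an arbitrary choice).
 ctx = tf.distribute.InputContext(num_input_pipelines=1,
                                  input_pipeline_id=0,
                                  num_replicas_in_sync=8)
 assert ctx.get_per_replica_batch_size(256) == 32  # 256 / 8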
Example #3
 def _dataset_fn(ctx=None):
     """Returns tf.data.Dataset for distributed BERT pretraining."""
     input_files = []
     for input_pattern in input_file_pattern.split(','):
         input_files.extend(tf.io.gfile.glob(input_pattern))
     batch_size = ctx.get_per_replica_batch_size(global_batch_size)
     train_dataset = input_pipeline.create_pretrain_dataset(
         input_files,
         seq_length,
         max_predictions_per_seq,
         batch_size,
         is_training=True,
         input_pipeline_context=ctx)
     return train_dataset
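Example #3 expands the comma-separated patterns into a concrete file list up front instead of passing patterns through. tf.io.gfile.glob resolves one pattern to its matching paths, on local disk and remote filesystems such as GCS alike; a short sketch with a hypothetical bucket:
 import tensorflow as tf

 # Hypothetical path; the same call works for local globs.
 input_files = tf.io.gfile.glob('gs://my-bucket/pretrain/part-*.tfrecord')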
Example #4
 def _dataset_fn(ctx=None):
     """Returns tf.data.Dataset for distributed BERT pretraining."""
     input_data = [
         f'gs://{args.bucket_name}/{args.project_name}/pretrain/'
         f'pretrain_data/{args.pretrain_data}/tfrecords/{_type}/*.tfrecords'
     ]
     per_replica_batch_size = ctx.get_per_replica_batch_size(batch_size)
     dataset = input_pipeline.create_pretrain_dataset(
         input_data,
         args.max_seq_length,
         args.max_predictions_per_seq,
         per_replica_batch_size,
         is_training=is_training,
         input_pipeline_context=ctx)
     if _type == 'dev':
         # added here so that eval_steps can be arbitrarily large
         dataset = dataset.repeat()
     return dataset
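The free variables here (args, batch_size, _type, is_training) are bound by an enclosing function in the original project. A hedged sketch of such a factory, so the closure can be reused for both the train and dev splits (the factory name and argument list are assumptions, not the original code):
 def make_dataset_fn(args, batch_size, _type, is_training):
     """Binds configuration into a dataset_fn for a tf.distribute strategy."""
     def _dataset_fn(ctx=None):
         # Body as in Example #4 above: build input_data, compute the
         # per-replica batch size from ctx, create the dataset, and repeat()
         # it for the 'dev' split.
         ...
     return _dataset_fn

 # e.g. train_fn = make_dataset_fn(args, batch_size, 'train', is_training=True)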