示例#1
0
 def _dataset_fn(ctx=None):
   """Builds the classifier input dataset for one input pipeline.

   Args:
     ctx: Optional input context. When truthy, its
       `get_per_replica_batch_size` turns the enclosing scope's
       `global_batch_size` into the batch size handed to the dataset builder;
       otherwise `global_batch_size` is used unchanged.

   Returns:
     The value returned by `input_pipeline.create_classifier_dataset`
     (presumably a `tf.data.Dataset` -- confirm against that helper).
   """
   if ctx:
     effective_batch_size = ctx.get_per_replica_batch_size(global_batch_size)
   else:
     effective_batch_size = global_batch_size
   # `ctx` is forwarded as well; the dataset builder receives it as
   # `input_pipeline_context`.
   return input_pipeline.create_classifier_dataset(
       input_file_pattern,
       max_seq_length,
       effective_batch_size,
       is_training=is_training,
       input_pipeline_context=ctx)
示例#2
0
 def _dataset_fn(ctx=None):
     """Builds the classifier input dataset from the globbed input files.

     Args:
       ctx: Optional input context. When truthy, its
         `get_per_replica_batch_size` turns the enclosing scope's
         `global_batch_size` into the batch size handed to the dataset
         builder; otherwise `global_batch_size` is used unchanged.

     Returns:
       The value returned by `input_pipeline.create_classifier_dataset`
       (presumably a `tf.data.Dataset` -- confirm against that helper).
     """
     if ctx:
         effective_batch_size = ctx.get_per_replica_batch_size(
             global_batch_size)
     else:
         effective_batch_size = global_batch_size
     # Expand the pattern into concrete file names before building the dataset.
     matched_files = tf.io.gfile.glob(input_file_pattern)
     return input_pipeline.create_classifier_dataset(
         matched_files,
         max_seq_length,
         effective_batch_size,
         is_training=is_training,
         input_pipeline_context=ctx,
         label_type=label_type,
         include_sample_weights=include_sample_weights)
示例#3
0
def predict_customized(strategy, input_meta_data, bert_config, eval_data_path,
                       num_steps):
    """Runs a restored classifier over the eval data and logs its accuracy.

    Builds the classifier dataset from `eval_data_path`, restores the latest
    checkpoint found under `FLAGS.model_dir`, runs `num_steps` distributed
    prediction steps, and accumulates accuracy over the examples whose
    `is_real_example` feature equals 1 (padding examples are excluded from
    both the numerator and the denominator).

    Args:
        strategy: Distribution strategy used to distribute the dataset and to
            run the per-replica prediction step.
        input_meta_data: Mapping providing 'max_seq_length' and 'num_labels'.
        bert_config: Configuration forwarded to
            `bert_models.classifier_model`.
        eval_data_path: Input location forwarded to
            `input_pipeline.create_classifier_dataset`.
        num_steps: Number of prediction steps (batches) to run.

    Returns:
        A list holding every item yielded by `get_raw_results` for each step's
        predictions, in step order.
    """
    max_seq_length = input_meta_data['max_seq_length']
    num_classes = input_meta_data['num_labels']
    predict_dataset = input_pipeline.create_classifier_dataset(
        eval_data_path,
        max_seq_length,
        FLAGS.eval_batch_size,
        is_training=False)
    predict_iterator = iter(
        strategy.experimental_distribute_dataset(predict_dataset))
    with strategy.scope():
        # Prediction is forced to run in float32.
        tf.keras.mixed_precision.experimental.set_policy('float32')
        classifier_model, _ = (bert_models.classifier_model(
            bert_config, tf.float32, num_classes, max_seq_length))
    checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
    logging.info('Restoring checkpoints from %s', checkpoint_path)
    checkpoint = tf.train.Checkpoint(model=classifier_model)
    # Only the model is restored here, so objects saved in the checkpoint but
    # absent from this one must not be reported as errors.
    checkpoint.restore(checkpoint_path).expect_partial()

    @tf.function
    def predict_step(iterator):
        """Predicts on distributed devices."""
        def _replicated_step(inputs_d):
            """Replicated prediction calculation."""
            inputs, label = inputs_d
            # The whole feature dict is handed to the model; it must also
            # carry 'is_real_example', which is used below as the mask.
            logits = classifier_model(inputs, training=False)
            return dict(logits=logits,
                        label_ids=label,
                        mask=inputs["is_real_example"])

        outputs = strategy.experimental_run_v2(_replicated_step,
                                               args=(next(iterator), ))
        return tf.nest.map_structure(strategy.experimental_local_results,
                                     outputs)

    correct = 0
    total = 0
    all_results = []
    for _ in range(num_steps):
        predictions = predict_step(predict_iterator)
        all_results.extend(get_raw_results(predictions))
        if len(all_results) % 100 == 0:
            logging.info('Made predictions for %d records.', len(all_results))

        # Stack the per-replica outputs directly. Going through
        # np.array(list) first would require every replica to return the
        # same batch size.
        # NOTE(review): assumes label_ids and is_real_example are rank-1 per
        # replica -- confirm against the input pipeline.
        step_logits = np.vstack(
            [np.asarray(part) for part in predictions['logits']])
        step_labels = np.hstack(
            [np.asarray(part) for part in predictions['label_ids']])
        step_masks = np.hstack(
            [np.asarray(part) for part in predictions['mask']])

        is_real = np.equal(step_masks, 1)
        is_hit = np.equal(np.argmax(step_logits, axis=-1), step_labels)
        # Count only real examples as correct, so that the numerator matches
        # the denominator and padding cannot inflate the accuracy.
        correct += int(np.count_nonzero(np.logical_and(is_hit, is_real)))
        total += int(np.count_nonzero(is_real))

    if total:
        accuracy = float(correct) / float(total)
    else:
        # No real example was evaluated (e.g. num_steps == 0); avoid a
        # ZeroDivisionError and report zero accuracy instead.
        logging.warning('No real examples were evaluated; accuracy is '
                        'reported as 0.')
        accuracy = 0.0
    logging.info("Train step: %d  /  acc = %d/%d = %f", num_steps, correct,
                 total, accuracy)
    return all_results