# Imports assumed from the surrounding file (TF Model Garden BERT code).
from absl import flags
from absl import logging
import numpy as np
import tensorflow as tf

from official.nlp.bert import bert_models
from official.nlp.bert import input_pipeline

FLAGS = flags.FLAGS


def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,
                   is_training, label_type=tf.int64,
                   include_sample_weights=False):
  """Gets a closure to create the classifier dataset.

  The enclosing function is reconstructed here so the free variables used by
  _dataset_fn are bound; its signature is assumed from the TF Model Garden
  run_classifier code.
  """

  def _dataset_fn(ctx=None):
    """Returns a tf.data.Dataset for distributed BERT classification."""
    # Under a distribution strategy, ctx is a tf.distribute.InputContext and
    # each replica takes its share of the global batch; otherwise use the
    # global batch size directly.
    batch_size = ctx.get_per_replica_batch_size(
        global_batch_size) if ctx else global_batch_size
    dataset = input_pipeline.create_classifier_dataset(
        tf.io.gfile.glob(input_file_pattern),
        max_seq_length,
        batch_size,
        is_training=is_training,
        input_pipeline_context=ctx,
        label_type=label_type,
        include_sample_weights=include_sample_weights)
    return dataset

  return _dataset_fn
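def _example_distribute_dataset():
  """Minimal usage sketch, not from the original file.

  Shows how the closure from get_dataset_fn is typically consumed by a
  tf.distribute strategy; the file pattern and batch sizes below are
  hypothetical placeholders.
  """
  strategy = tf.distribute.MirroredStrategy()
  train_fn = get_dataset_fn(
      input_file_pattern='/tmp/train.tf_record',  # hypothetical path
      max_seq_length=128,
      global_batch_size=32,
      is_training=True)
  # The strategy calls train_fn once per worker with an InputContext, so each
  # replica computes its per-replica batch size and shards the input files.
  return strategy.experimental_distribute_datasets_from_function(train_fn)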
def predict_customized(strategy, input_meta_data, bert_config, eval_data_path,
                       num_steps):
  """Runs distributed prediction and logs accuracy over real examples."""
  max_seq_length = input_meta_data['max_seq_length']
  num_classes = input_meta_data['num_labels']
  predict_dataset = input_pipeline.create_classifier_dataset(
      eval_data_path,
      max_seq_length,
      FLAGS.eval_batch_size,
      is_training=False)
  predict_iterator = iter(
      strategy.experimental_distribute_dataset(predict_dataset))

  with strategy.scope():
    # Prediction runs in float32 regardless of the training-time policy.
    tf.keras.mixed_precision.experimental.set_policy('float32')
    classifier_model, _ = bert_models.classifier_model(
        bert_config, tf.float32, num_classes, max_seq_length)
    checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
    logging.info('Restoring checkpoints from %s', checkpoint_path)
    checkpoint = tf.train.Checkpoint(model=classifier_model)
    checkpoint.restore(checkpoint_path).expect_partial()

  @tf.function
  def predict_step(iterator):
    """Predicts on distributed devices."""

    def _replicated_step(inputs_d):
      """Replicated prediction calculation."""
      inputs, label = inputs_d
      logits = classifier_model(inputs, training=False)
      return dict(
          logits=logits, label_ids=label, mask=inputs['is_real_example'])

    outputs = strategy.experimental_run_v2(
        _replicated_step, args=(next(iterator),))
    return tf.nest.map_structure(strategy.experimental_local_results, outputs)

  correct = 0
  total = 0
  all_results = []
  for _ in range(num_steps):
    predictions = predict_step(predict_iterator)
    merged_logits = []
    merged_labels = []
    merged_masks = []
    # get_raw_results is a helper defined elsewhere in the original file.
    for result in get_raw_results(predictions):
      all_results.append(result)
    if len(all_results) % 100 == 0:
      logging.info('Made predictions for %d records.', len(all_results))

    for logits, label_ids, mask in zip(predictions['logits'],
                                       predictions['label_ids'],
                                       predictions['mask']):
      merged_logits.append(logits)
      merged_labels.append(label_ids)
      merged_masks.append(mask)
    merged_logits = np.vstack(np.array(merged_logits))
    merged_labels = np.hstack(np.array(merged_labels))
    merged_masks = np.hstack(np.array(merged_masks))
    # Score only real examples; padding added to fill the final batch has
    # is_real_example == 0 and must not count toward accuracy.
    real_index = np.where(np.equal(merged_masks, 1))
    correct += np.sum(
        np.equal(
            np.argmax(merged_logits[real_index], axis=-1),
            merged_labels[real_index]))
    total += np.shape(real_index)[-1]
  accuracy = float(correct) / float(total)
  logging.info('Eval steps: %d / acc = %d/%d = %f', num_steps, correct, total,
               accuracy)
  return all_results
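def _example_predict():
  """Minimal usage sketch, not from the original file.

  Wires up predict_customized end to end. The paths and step count are
  hypothetical placeholders, and loading the BERT config assumes the
  BertConfig helper from official.nlp.bert.configs.
  """
  import json

  from official.nlp.bert import configs

  strategy = tf.distribute.MirroredStrategy()
  with tf.io.gfile.GFile('/tmp/input_meta_data.json', 'rb') as reader:
    input_meta_data = json.loads(reader.read().decode('utf-8'))
  bert_config = configs.BertConfig.from_json_file('/tmp/bert_config.json')
  # num_steps should cover the eval set: ceil(num_examples / eval_batch_size).
  return predict_customized(
      strategy,
      input_meta_data,
      bert_config,
      eval_data_path='/tmp/eval.tf_record',  # hypothetical path
      num_steps=100)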