def bert_tagging() -> cfg.ExperimentConfig:
  """BERT tagging task."""
  optimizer_settings = {
      'optimizer': {
          'type': 'adamw',
          'adamw': {
              'weight_decay_rate': 0.01,
              'exclude_from_weight_decay': [
                  'LayerNorm', 'layer_norm', 'bias'
              ],
          }
      },
      'learning_rate': {
          'type': 'polynomial',
          'polynomial': {
              'initial_learning_rate': 8e-5,
              'end_learning_rate': 0.0,
          }
      },
      'warmup': {
          'type': 'polynomial'
      }
  }
  task_settings = tagging.TaggingConfig(
      train_data=tagging_dataloader.TaggingDataConfig(),
      validation_data=tagging_dataloader.TaggingDataConfig(
          is_training=False, drop_remainder=False))
  return cfg.ExperimentConfig(
      task=task_settings,
      trainer=cfg.TrainerConfig(
          optimizer_config=optimization.OptimizationConfig(
              optimizer_settings)),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None',
      ])
def test_predict(self):
  """Checks that predict returns one 3-tuple per fake example."""
  config = tagging.TaggingConfig(
      model=tagging.ModelConfig(encoder=self._encoder_config),
      train_data=self._train_data_config,
      class_names=["O", "B-PER", "I-PER"])
  task = tagging.TaggingTask(config)
  model = task.build_model()

  test_data_path = os.path.join(self.get_temp_dir(), "test.tf_record")
  seq_length = 16
  num_examples = 100
  _create_fake_dataset(
      test_data_path,
      seq_length=seq_length,
      num_labels=len(config.class_names),
      num_examples=num_examples)

  data_config = tagging_dataloader.TaggingDataConfig(
      input_path=test_data_path,
      seq_length=seq_length,
      is_training=False,
      global_batch_size=16,
      drop_remainder=False,
      include_sentence_id=True)

  outputs = tagging.predict(task, data_config, model)
  self.assertLen(outputs, num_examples)
  self.assertLen(outputs[0], 3)
def test_load_dataset(self, include_sentence_id):
  """Verifies feature keys and tensor shapes of the loaded batch."""
  seq_length = 16
  batch_size = 10
  train_data_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
  _create_fake_dataset(train_data_path, seq_length, include_sentence_id)
  data_config = tagging_dataloader.TaggingDataConfig(
      input_path=train_data_path,
      seq_length=seq_length,
      global_batch_size=batch_size,
      include_sentence_id=include_sentence_id)

  loader = tagging_dataloader.TaggingDataLoader(data_config)
  features, labels = next(iter(loader.load()))

  expected_keys = ['input_word_ids', 'input_mask', 'input_type_ids']
  if include_sentence_id:
    expected_keys.extend(['sentence_id', 'sub_sentence_id'])
  self.assertCountEqual(expected_keys, features.keys())

  batched_shape = (batch_size, seq_length)
  self.assertEqual(features['input_word_ids'].shape, batched_shape)
  self.assertEqual(features['input_mask'].shape, batched_shape)
  self.assertEqual(features['input_type_ids'].shape, batched_shape)
  self.assertEqual(labels.shape, batched_shape)
  if include_sentence_id:
    self.assertEqual(features['sentence_id'].shape, (batch_size, ))
    self.assertEqual(features['sub_sentence_id'].shape, (batch_size, ))
def write_tagging(task, model, input_file, output_file, predict_batch_size,
                  seq_length):
  """Makes tagging predictions and writes to output file."""
  data_config = tagging_dataloader.TaggingDataConfig(
      input_path=input_file,
      is_training=False,
      seq_length=seq_length,
      global_batch_size=predict_batch_size,
      drop_remainder=False,
      include_sentence_id=True)
  predictions = tagging.predict(task, data_config, model)
  class_names = task.task_config.class_names
  previous_sentence_id = -1
  with tf.io.gfile.GFile(output_file, 'w') as writer:
    for sentence_id, _, predict_ids in predictions:
      token_labels = [class_names[label_id] for label_id in predict_ids]
      # Results must be ordered: same sentence as before, or the next one.
      assert sentence_id in (previous_sentence_id, previous_sentence_id + 1)
      # A blank line separates consecutive sentences in the output file.
      if sentence_id != previous_sentence_id and previous_sentence_id != -1:
        writer.write('\n')
      writer.write('\n'.join(token_labels))
      writer.write('\n')
      previous_sentence_id = sentence_id
def _infer(model, task, test_data_path, train_with_additional_labels,
           batch_size):
  """Computes the predicted label sequence using the trained model."""
  test_data_config = tagging_dataloader.TaggingDataConfig(
      input_path=test_data_path,
      seq_length=128,
      global_batch_size=batch_size,
      is_training=False,
      include_sentence_id=True,
      drop_remainder=False)
  raw_predictions = _predict(task, test_data_config, model)

  # Stitch per-part probabilities back into whole sentences: part_id == 0
  # opens a new sentence, later parts extend the most recent one.
  sentence_probabilities = []
  for _, part_id, part_probabilities in raw_predictions:
    if part_id == 0:
      sentence_probabilities.append(part_probabilities)
    else:
      sentence_probabilities[-1].extend(part_probabilities)

  decoded_sentences = []
  for sentence_index, probabilities in enumerate(sentence_probabilities):
    assert not np.isnan(probabilities).any(), (
        "There was an error during decoding. Try reducing the batch size."
        " First error in sentence %d" % sentence_index)
    if FLAGS.viterbi_decoding:
      labels = _viterbi_decoding(probabilities,
                                 train_with_additional_labels)
    else:
      labels = _greedy_decoding(probabilities)
    decoded_sentences.append(labels)
  return decoded_sentences
def test_load_dataset(self):
  """Verifies feature keys and tensor shapes of the loaded batch."""
  seq_length = 16
  batch_size = 10
  train_data_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
  _create_fake_dataset(train_data_path, seq_length)
  data_config = tagging_dataloader.TaggingDataConfig(
      input_path=train_data_path,
      seq_length=seq_length,
      global_batch_size=batch_size)

  loader = tagging_dataloader.TaggingDataLoader(data_config)
  features, labels = next(iter(loader.load()))

  self.assertCountEqual(
      ['input_word_ids', 'input_mask', 'input_type_ids'], features.keys())
  batched_shape = (batch_size, seq_length)
  self.assertEqual(features['input_word_ids'].shape, batched_shape)
  self.assertEqual(features['input_mask'].shape, batched_shape)
  self.assertEqual(features['input_type_ids'].shape, batched_shape)
  self.assertEqual(labels.shape, batched_shape)
def setUp(self):
  """Builds a tiny BERT encoder config and a dummy training data config."""
  super(TaggingTest, self).setUp()
  bert_config = encoders.BertEncoderConfig(vocab_size=30522, num_layers=1)
  self._encoder_config = encoders.EncoderConfig(bert=bert_config)
  self._train_data_config = tagging_dataloader.TaggingDataConfig(
      input_path="dummy", seq_length=128, global_batch_size=1)
def train():
  """Trains a model.

  Builds a distribution strategy (TPU when --tpu_address is set, otherwise
  the default strategy), constructs the tagging task and model from flags,
  and runs Keras `fit` with per-epoch checkpointing and optional
  reduce-on-plateau learning-rate scheduling.

  Raises:
    NotImplementedError: If plateau LR reduction is requested on TPU (the
      validation set cannot be evaluated there).
    ValueError: If plateau LR reduction is requested without a validation
      set, or if an unsupported optimizer name is given.
  """
  if FLAGS.tpu_address is not None:
    if FLAGS.plateau_lr_reduction != 1.0:
      raise NotImplementedError(
          "Learning rate reduction cannot be used on TPUs, because the"
          " validation set cannot be evaluated.")
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        tpu=FLAGS.tpu_address)
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)
  else:
    if (FLAGS.plateau_lr_reduction != 1.0 and
        FLAGS.validation_data_path is None):
      raise ValueError(
          "In order to reduce the learning rate on plateaus, a validation"
          " set must be specified.")
    strategy = tf.distribute.get_strategy()

  model_config = ModelSetupConfig(
      size=ModelSize[FLAGS.size.upper()],
      case_sensitive=FLAGS.case_sensitive,
      pretrained=FLAGS.pretrained,
      train_with_additional_labels=FLAGS.train_with_additional_labels)
  with strategy.scope():
    train_data_config = tagging_dataloader.TaggingDataConfig(
        input_path=FLAGS.train_data_path,
        seq_length=128,
        global_batch_size=FLAGS.batch_size)
    if FLAGS.validation_data_path is not None:
      validation_data_config = tagging_dataloader.TaggingDataConfig(
          input_path=FLAGS.validation_data_path,
          seq_length=128,
          global_batch_size=FLAGS.batch_size,
          is_training=False)
    else:
      validation_data_config = None

    label_list = LABELS
    if model_config.train_with_additional_labels:
      label_list = LABELS + ADDITIONAL_LABELS
    tagging_config = get_tagging_config(
        model_config,
        label_list=label_list,
        train_data_config=train_data_config,
        validation_data_config=validation_data_config)
    task = ConfigurableTrainingTaggingTask(tagging_config)
    model = task.build_model(FLAGS.train_last_layer_only)

    if FLAGS.optimizer == "sgd":
      # BUG FIX: SGD was built with the deprecated `lr` keyword while the
      # Adam branch used `learning_rate`; use the canonical keyword for
      # both (the `lr` alias was removed in newer Keras releases).
      optimizer = tf.keras.optimizers.SGD(learning_rate=FLAGS.learning_rate)
    elif FLAGS.optimizer == "adam":
      optimizer = tf.keras.optimizers.Adam(learning_rate=FLAGS.learning_rate)
    else:
      raise ValueError("Only SGD and Adam are supported optimizers.")

    iterations_per_epoch = FLAGS.train_size // FLAGS.batch_size
    model.compile(
        optimizer=optimizer,
        metrics=[
            tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")
        ],
        steps_per_execution=iterations_per_epoch)
    # Delegate training/validation steps to the task implementation.
    model.train_step = functools.partial(
        task.train_step, model=model, optimizer=model.optimizer)
    dataset_train = task.build_inputs(tagging_config.train_data)

    # NOTE(review): `period` is deprecated in ModelCheckpoint in favor of
    # `save_freq`; kept as-is to preserve the current checkpoint cadence.
    checkpoint = ModelCheckpoint(
        FLAGS.save_path + "/model_{epoch:02d}",
        verbose=1,
        save_best_only=False,
        save_weights_only=True,
        period=1)
    callbacks = [checkpoint]
    additional_fit_parameters = {}
    if FLAGS.plateau_lr_reduction != 1.0:
      dataset_validation = task.build_inputs(tagging_config.validation_data)
      reduce_lr = ReduceLROnPlateau(
          monitor="val_loss",
          factor=FLAGS.plateau_lr_reduction,
          patience=FLAGS.plateau_patience,
          verbose=1)
      callbacks.append(reduce_lr)
      additional_fit_parameters["validation_data"] = dataset_validation
      model.test_step = functools.partial(task.validation_step, model=model)

    model.fit(
        dataset_train,
        epochs=FLAGS.epochs,
        steps_per_epoch=iterations_per_epoch,
        callbacks=callbacks,
        **additional_fit_parameters)