Example #1
0
def bert_tagging() -> cfg.ExperimentConfig:
    """Builds the experiment config for a BERT tagging task.

    Wires a tagging task (train + non-training validation data) to a
    trainer using AdamW with polynomial learning-rate decay and warmup.
    """
    optimizer_settings = {
        'optimizer': {
            'type': 'adamw',
            'adamw': {
                'weight_decay_rate': 0.01,
                # LayerNorm parameters and biases are conventionally
                # excluded from weight decay for BERT-style models.
                'exclude_from_weight_decay': [
                    'LayerNorm', 'layer_norm', 'bias'
                ],
            }
        },
        'learning_rate': {
            'type': 'polynomial',
            'polynomial': {
                'initial_learning_rate': 8e-5,
                'end_learning_rate': 0.0,
            }
        },
        'warmup': {
            'type': 'polynomial'
        }
    }
    task_config = tagging.TaggingConfig(
        train_data=tagging_dataloader.TaggingDataConfig(),
        validation_data=tagging_dataloader.TaggingDataConfig(
            is_training=False, drop_remainder=False))
    return cfg.ExperimentConfig(
        task=task_config,
        trainer=cfg.TrainerConfig(
            optimizer_config=optimization.OptimizationConfig(
                optimizer_settings)),
        restrictions=[
            'task.train_data.is_training != None',
            'task.validation_data.is_training != None',
        ])
Example #2
0
    def test_predict(self):
        """Runs predict on a fake dataset and checks the result structure."""
        config = tagging.TaggingConfig(
            model=tagging.ModelConfig(encoder=self._encoder_config),
            train_data=self._train_data_config,
            class_names=["O", "B-PER", "I-PER"])
        task = tagging.TaggingTask(config)
        model = task.build_model()

        seq_length = 16
        num_examples = 100
        data_path = os.path.join(self.get_temp_dir(), "test.tf_record")
        _create_fake_dataset(data_path,
                             seq_length=seq_length,
                             num_labels=len(config.class_names),
                             num_examples=num_examples)

        data_config = tagging_dataloader.TaggingDataConfig(
            input_path=data_path,
            seq_length=seq_length,
            is_training=False,
            global_batch_size=16,
            drop_remainder=False,
            include_sentence_id=True)

        predictions = tagging.predict(task, data_config, model)
        # One result per example; each is a (sentence_id, part, ids) triple.
        self.assertLen(predictions, num_examples)
        self.assertLen(predictions[0], 3)
Example #3
0
    def test_load_dataset(self, include_sentence_id):
        """Loads a fake dataset and verifies feature keys and tensor shapes."""
        seq_length = 16
        batch_size = 10
        data_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
        _create_fake_dataset(data_path, seq_length, include_sentence_id)
        config = tagging_dataloader.TaggingDataConfig(
            input_path=data_path,
            seq_length=seq_length,
            global_batch_size=batch_size,
            include_sentence_id=include_sentence_id)

        loader = tagging_dataloader.TaggingDataLoader(config)
        features, labels = next(iter(loader.load()))

        expected_keys = ['input_word_ids', 'input_mask', 'input_type_ids']
        if include_sentence_id:
            expected_keys.extend(['sentence_id', 'sub_sentence_id'])
        self.assertCountEqual(expected_keys, features.keys())

        # All token-level features and the labels are (batch, seq_length).
        for key in ('input_word_ids', 'input_mask', 'input_type_ids'):
            self.assertEqual(features[key].shape, (batch_size, seq_length))
        self.assertEqual(labels.shape, (batch_size, seq_length))
        if include_sentence_id:
            # Sentence-id features are per-example scalars.
            self.assertEqual(features['sentence_id'].shape, (batch_size, ))
            self.assertEqual(features['sub_sentence_id'].shape, (batch_size, ))
def write_tagging(task, model, input_file, output_file, predict_batch_size,
                  seq_length):
  """Makes tagging predictions and writes them to the output file.

  Writes one predicted label per line; consecutive sentences are
  separated by a blank line. Prediction results are expected to arrive
  ordered by sentence id.
  """
  data_config = tagging_dataloader.TaggingDataConfig(
      input_path=input_file,
      is_training=False,
      seq_length=seq_length,
      global_batch_size=predict_batch_size,
      drop_remainder=False,
      include_sentence_id=True)
  predictions = tagging.predict(task, data_config, model)
  class_names = task.task_config.class_names
  prev_sentence_id = -1

  with tf.io.gfile.GFile(output_file, 'w') as writer:
    for sentence_id, _, predict_ids in predictions:
      # Either another part of the current sentence, or the next one.
      assert sentence_id in (prev_sentence_id, prev_sentence_id + 1)

      # Blank separator line whenever a new sentence starts (except first).
      if prev_sentence_id not in (-1, sentence_id):
        writer.write('\n')

      writer.write('\n'.join(class_names[label_id]
                             for label_id in predict_ids))
      writer.write('\n')
      prev_sentence_id = sentence_id
Example #5
0
def _infer(model, task, test_data_path, train_with_additional_labels,
           batch_size):
    """Computes the predicted label sequence using the trained model."""
    data_config = tagging_dataloader.TaggingDataConfig(
        input_path=test_data_path,
        seq_length=128,
        global_batch_size=batch_size,
        is_training=False,
        include_sentence_id=True,
        drop_remainder=False)
    raw_predictions = _predict(task, data_config, model)

    # Re-assemble sentences that were split into parts: part 0 starts a
    # new sentence, later parts are appended to the previous entry.
    per_sentence_probabilities = []
    for _, part_id, probabilities in raw_predictions:
        if part_id == 0:
            per_sentence_probabilities.append(probabilities)
        else:
            per_sentence_probabilities[-1].extend(probabilities)

    decoded_sequences = []
    for index, probabilities in enumerate(per_sentence_probabilities):
        assert not np.isnan(probabilities).any(), (
            "There was an error during decoding. Try reducing the batch size."
            " First error in sentence %d" % index)
        if FLAGS.viterbi_decoding:
            decoded_sequences.append(
                _viterbi_decoding(probabilities,
                                  train_with_additional_labels))
        else:
            decoded_sequences.append(_greedy_decoding(probabilities))

    return decoded_sequences
Example #6
0
    def test_load_dataset(self):
        """Loads a fake dataset and verifies feature names and shapes."""
        seq_length = 16
        batch_size = 10
        data_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
        _create_fake_dataset(data_path, seq_length)
        config = tagging_dataloader.TaggingDataConfig(
            input_path=data_path,
            seq_length=seq_length,
            global_batch_size=batch_size)

        loader = tagging_dataloader.TaggingDataLoader(config)
        features, labels = next(iter(loader.load()))

        self.assertCountEqual(
            ['input_word_ids', 'input_mask', 'input_type_ids'],
            features.keys())
        # All token-level features and the labels are (batch, seq_length).
        for key in ('input_word_ids', 'input_mask', 'input_type_ids'):
            self.assertEqual(features[key].shape, (batch_size, seq_length))
        self.assertEqual(labels.shape, (batch_size, seq_length))
Example #7
0
 def setUp(self):
     """Creates a tiny one-layer encoder config and a dummy data config."""
     super().setUp()
     self._encoder_config = encoders.EncoderConfig(
         bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1))
     self._train_data_config = tagging_dataloader.TaggingDataConfig(
         input_path="dummy", seq_length=128, global_batch_size=1)
Example #8
0
def train():
    """Trains a tagging model as configured by command-line FLAGS.

    Builds a distribution strategy (TPU when --tpu_address is set, the
    default strategy otherwise), constructs the tagging task and model,
    compiles it with the requested optimizer, and runs Keras `fit`,
    checkpointing once per epoch. Optionally reduces the learning rate
    when the validation loss plateaus.

    Raises:
        NotImplementedError: If plateau LR reduction is requested on TPU
            (the validation set cannot be evaluated there).
        ValueError: If plateau LR reduction is requested without a
            validation set, or an unsupported optimizer name is given.
    """
    if FLAGS.tpu_address is not None:
        if FLAGS.plateau_lr_reduction != 1.0:
            raise NotImplementedError(
                "Learning rate reduction cannot be used on TPUs, because the"
                " validation set cannot be evaluated.")
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu=FLAGS.tpu_address)
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.TPUStrategy(resolver)
    else:
        if (FLAGS.plateau_lr_reduction != 1.0
                and FLAGS.validation_data_path is None):
            raise ValueError(
                "In order to reduce the learning rate on plateaus, a validation"
                " set must be specified.")
        strategy = tf.distribute.get_strategy()

    model_config = ModelSetupConfig(
        size=ModelSize[FLAGS.size.upper()],
        case_sensitive=FLAGS.case_sensitive,
        pretrained=FLAGS.pretrained,
        train_with_additional_labels=FLAGS.train_with_additional_labels)

    with strategy.scope():
        train_data_config = tagging_dataloader.TaggingDataConfig(
            input_path=FLAGS.train_data_path,
            seq_length=128,
            global_batch_size=FLAGS.batch_size)
        if FLAGS.validation_data_path is not None:
            validation_data_config = tagging_dataloader.TaggingDataConfig(
                input_path=FLAGS.validation_data_path,
                seq_length=128,
                global_batch_size=FLAGS.batch_size,
                is_training=False)
        else:
            validation_data_config = None

        label_list = LABELS
        if model_config.train_with_additional_labels:
            label_list = LABELS + ADDITIONAL_LABELS
        tagging_config = get_tagging_config(
            model_config,
            label_list=label_list,
            train_data_config=train_data_config,
            validation_data_config=validation_data_config)
        task = ConfigurableTrainingTaggingTask(tagging_config)
        model = task.build_model(FLAGS.train_last_layer_only)
        if FLAGS.optimizer == "sgd":
            # `learning_rate` is the canonical keyword; the previous `lr`
            # alias is deprecated (removed in newer Keras releases) and was
            # inconsistent with the Adam branch below.
            optimizer = tf.keras.optimizers.SGD(
                learning_rate=FLAGS.learning_rate)
        elif FLAGS.optimizer == "adam":
            optimizer = tf.keras.optimizers.Adam(
                learning_rate=FLAGS.learning_rate)
        else:
            raise ValueError("Only SGD and Adam are supported optimizers.")

        iterations_per_epoch = FLAGS.train_size // FLAGS.batch_size
        model.compile(
            optimizer=optimizer,
            metrics=[
                tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")
            ],
            steps_per_execution=iterations_per_epoch)
        # Route Keras train/test steps through the task so its custom loss
        # and masking logic are used instead of the default step.
        model.train_step = functools.partial(task.train_step,
                                             model=model,
                                             optimizer=model.optimizer)
        dataset_train = task.build_inputs(tagging_config.train_data)

        checkpoint = ModelCheckpoint(FLAGS.save_path + "/model_{epoch:02d}",
                                     verbose=1,
                                     save_best_only=False,
                                     save_weights_only=True,
                                     period=1)
        callbacks = [checkpoint]

        additional_fit_parameters = {}
        if FLAGS.plateau_lr_reduction != 1.0:
            dataset_validation = task.build_inputs(
                tagging_config.validation_data)
            reduce_lr = ReduceLROnPlateau(monitor="val_loss",
                                          factor=FLAGS.plateau_lr_reduction,
                                          patience=FLAGS.plateau_patience,
                                          verbose=1)
            callbacks.append(reduce_lr)
            additional_fit_parameters["validation_data"] = dataset_validation
            model.test_step = functools.partial(task.validation_step,
                                                model=model)

        model.fit(dataset_train,
                  epochs=FLAGS.epochs,
                  steps_per_epoch=iterations_per_epoch,
                  callbacks=callbacks,
                  **additional_fit_parameters)