Code example #1
    def test_network_invocation(self):
        config = bert.BertPretrainerConfig(
            encoder=encoders.TransformerEncoderConfig(vocab_size=10,
                                                      num_layers=1))
        _ = bert.instantiate_bertpretrainer_from_cfg(config)

        # Invokes with classification heads.
        config = bert.BertPretrainerConfig(
            encoder=encoders.TransformerEncoderConfig(vocab_size=10,
                                                      num_layers=1),
            cls_heads=[
                bert.ClsHeadConfig(inner_dim=10,
                                   num_classes=2,
                                   name="next_sentence")
            ])
        _ = bert.instantiate_bertpretrainer_from_cfg(config)

        with self.assertRaises(ValueError):
            config = bert.BertPretrainerConfig(
                encoder=encoders.TransformerEncoderConfig(vocab_size=10,
                                                          num_layers=1),
                cls_heads=[
                    bert.ClsHeadConfig(inner_dim=10,
                                       num_classes=2,
                                       name="next_sentence"),
                    bert.ClsHeadConfig(inner_dim=10,
                                       num_classes=2,
                                       name="next_sentence")
                ])
            _ = bert.instantiate_bertpretrainer_from_cfg(config)
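
The ValueError in the last block comes from the duplicated head name: classification heads appear to be keyed by their name, so each entry in cls_heads needs a unique one. A minimal sketch of a well-formed two-head config (the "sentiment" head is a made-up illustration, not part of the test):

config = bert.BertPretrainerConfig(
    encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1),
    cls_heads=[
        bert.ClsHeadConfig(inner_dim=10, num_classes=2, name="next_sentence"),
        # Any second head is fine as long as its name is unique.
        bert.ClsHeadConfig(inner_dim=10, num_classes=3, name="sentiment"),
    ])
_ = bert.instantiate_bertpretrainer_from_cfg(config)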
Code example #2
    def test_task(self):
        # Saves a checkpoint.
        pretrain_cfg = bert.BertPretrainerConfig(encoder=self._encoder_config,
                                                 num_masked_tokens=20,
                                                 cls_heads=[
                                                     bert.ClsHeadConfig(
                                                         inner_dim=10,
                                                         num_classes=3,
                                                         name="next_sentence")
                                                 ])
        pretrain_model = bert.instantiate_bertpretrainer_from_cfg(pretrain_cfg)
        ckpt = tf.train.Checkpoint(model=pretrain_model,
                                   **pretrain_model.checkpoint_items)
        saved_path = ckpt.save(self.get_temp_dir())

        config = question_answering.QuestionAnsweringConfig(
            init_checkpoint=saved_path,
            network=self._encoder_config,
            train_data=self._train_data_config)
        task = question_answering.QuestionAnsweringTask(config)
        model = task.build_model()
        metrics = task.build_metrics()
        dataset = task.build_inputs(config.train_data)

        iterator = iter(dataset)
        optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
        task.train_step(next(iterator), model, optimizer, metrics=metrics)
        task.validation_step(next(iterator), model, metrics=metrics)
        task.initialize(model)
Code example #3
  def test_task(self):
    config = sentence_prediction.SentencePredictionConfig(
        init_checkpoint=self.get_temp_dir(),
        model=self.get_model_config(2),
        train_data=self._train_data_config)
    task = sentence_prediction.SentencePredictionTask(config)
    model = task.build_model()
    metrics = task.build_metrics()
    dataset = task.build_inputs(config.train_data)

    iterator = iter(dataset)
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
    task.train_step(next(iterator), model, optimizer, metrics=metrics)
    task.validation_step(next(iterator), model, metrics=metrics)

    # Saves a checkpoint.
    pretrain_cfg = bert.BertPretrainerConfig(
        encoder=encoders.TransformerEncoderConfig(
            vocab_size=30522, num_layers=1),
        cls_heads=[
            bert.ClsHeadConfig(
                inner_dim=10, num_classes=3, name="next_sentence")
        ])
    pretrain_model = bert.instantiate_pretrainer_from_cfg(pretrain_cfg)
    ckpt = tf.train.Checkpoint(
        model=pretrain_model, **pretrain_model.checkpoint_items)
    ckpt.save(config.init_checkpoint)
    task.initialize(model)
Code example #4
    def test_task(self):
        config = masked_lm.MaskedLMConfig(
            init_checkpoint=self.get_temp_dir(),
            model=bert.BertPretrainerConfig(
                encoders.TransformerEncoderConfig(vocab_size=30522,
                                                  num_layers=1),
                num_masked_tokens=20,
                cls_heads=[
                    bert.ClsHeadConfig(inner_dim=10,
                                       num_classes=2,
                                       name="next_sentence")
                ]),
            train_data=pretrain_dataloader.BertPretrainDataConfig(
                input_path="dummy",
                max_predictions_per_seq=20,
                seq_length=128,
                global_batch_size=1))
        task = masked_lm.MaskedLMTask(config)
        model = task.build_model()
        metrics = task.build_metrics()
        dataset = task.build_inputs(config.train_data)

        iterator = iter(dataset)
        optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
        task.train_step(next(iterator), model, optimizer, metrics=metrics)
        task.validation_step(next(iterator), model, metrics=metrics)

        # Saves a checkpoint.
        ckpt = tf.train.Checkpoint(model=model, **model.checkpoint_items)
        ckpt.save(config.init_checkpoint)
        task.initialize(model)
Code example #5
File: masked_lm.py Project: yaosi912/models
class MaskedLMConfig(cfg.TaskConfig):
  """The model config."""
  network: bert.BertPretrainerConfig = bert.BertPretrainerConfig(cls_heads=[
      bert.ClsHeadConfig(
          inner_dim=768, num_classes=2, dropout_rate=0.1, name='next_sentence')
  ])
  train_data: cfg.DataConfig = cfg.DataConfig()
  validation_data: cfg.DataConfig = cfg.DataConfig()
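
Because MaskedLMConfig is a dataclass, its defaults can be overridden field by field at construction time. A minimal sketch, assuming the encoders and pretrain_dataloader modules shown in the other examples on this page:

config = MaskedLMConfig(
    network=bert.BertPretrainerConfig(
        encoder=encoders.TransformerEncoderConfig(vocab_size=30522, num_layers=1)),
    train_data=pretrain_dataloader.BertPretrainDataConfig(
        input_path="dummy", seq_length=128, global_batch_size=1))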
Code example #6
 def get_model_config(self, num_classes):
     return bert.BertPretrainerConfig(
         encoder=encoders.TransformerEncoderConfig(vocab_size=30522,
                                                   num_layers=1),
         num_masked_tokens=0,
         cls_heads=[
             bert.ClsHeadConfig(inner_dim=10,
                                num_classes=num_classes,
                                name="sentence_prediction")
         ])
Code example #7
File: bert_test.py Project: tpsgrp/python-app
 def test_checkpoint_items(self):
     config = bert.BertPretrainerConfig(
         encoder=encoders.TransformerEncoderConfig(vocab_size=10,
                                                   num_layers=1),
         cls_heads=[
             bert.ClsHeadConfig(inner_dim=10,
                                num_classes=2,
                                name="next_sentence")
         ])
     encoder = bert.instantiate_bertpretrainer_from_cfg(config)
     self.assertSameElements(encoder.checkpoint_items.keys(),
                             ["encoder", "next_sentence.pooler_dense"])
Code example #8
 def setUp(self):
     super(SentencePredictionTaskTest, self).setUp()
     self._network_config = bert.BertPretrainerConfig(
         encoder=encoders.TransformerEncoderConfig(vocab_size=30522,
                                                   num_layers=1),
         num_masked_tokens=0,
         cls_heads=[
             bert.ClsHeadConfig(inner_dim=10,
                                num_classes=3,
                                name="sentence_prediction")
         ])
     self._train_data_config = bert.SentencePredictionDataConfig(
         input_path="dummy", seq_length=128, global_batch_size=1)
Code example #9
File: sentence_prediction.py Project: zzf2014/models
class SentencePredictionConfig(cfg.TaskConfig):
    """The model config."""
    # At most one of `pretrain_checkpoint_dir` and `hub_module_url` can
    # be specified.
    pretrain_checkpoint_dir: str = ''
    hub_module_url: str = ''
    network: bert.BertPretrainerConfig = bert.BertPretrainerConfig(
        num_masked_tokens=0,
        cls_heads=[
            bert.ClsHeadConfig(inner_dim=768,
                               num_classes=3,
                               dropout_rate=0.1,
                               name='sentence_prediction')
        ])
    train_data: cfg.DataConfig = cfg.DataConfig()
    validation_data: cfg.DataConfig = cfg.DataConfig()
Code example #10
 def test_task_with_hub(self):
     hub_module_url = self._export_bert_tfhub()
     config = sentence_prediction.SentencePredictionConfig(
         hub_module_url=hub_module_url,
         network=bert.BertPretrainerConfig(
             encoders.TransformerEncoderConfig(vocab_size=30522,
                                               num_layers=1),
             num_masked_tokens=0,
             cls_heads=[
                 bert.ClsHeadConfig(inner_dim=10,
                                    num_classes=3,
                                    name="sentence_prediction")
             ]),
         train_data=bert.BertSentencePredictionDataConfig(
             input_path="dummy", seq_length=128, global_batch_size=10))
     self._run_task(config)
Code example #11
class SentencePredictionConfig(cfg.TaskConfig):
    """The model config."""
    # At most one of `init_checkpoint` and `hub_module_url` can
    # be specified.
    init_checkpoint: str = ''
    hub_module_url: str = ''
    metric_type: str = 'accuracy'
    network: bert.BertPretrainerConfig = bert.BertPretrainerConfig(
        num_masked_tokens=0,  # No masked language modeling head.
        cls_heads=[
            bert.ClsHeadConfig(inner_dim=768,
                               num_classes=3,
                               dropout_rate=0.1,
                               name='sentence_prediction')
        ])
    train_data: cfg.DataConfig = cfg.DataConfig()
    validation_data: cfg.DataConfig = cfg.DataConfig()
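
Since at most one of init_checkpoint and hub_module_url may be set, a typical fine-tuning config fills in exactly one of them. A minimal sketch that warm-starts from a saved checkpoint, reusing the data config style from code example #8 (the checkpoint path is a placeholder):

config = SentencePredictionConfig(
    init_checkpoint="/tmp/pretrain-1",  # placeholder: a prefix returned by ckpt.save()
    train_data=bert.SentencePredictionDataConfig(
        input_path="dummy", seq_length=128, global_batch_size=1))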
Code example #12
  def test_task(self, version_2_with_negative, tokenization):
    # Saves a checkpoint.
    pretrain_cfg = bert.BertPretrainerConfig(
        encoder=self._encoder_config,
        cls_heads=[
            bert.ClsHeadConfig(
                inner_dim=10, num_classes=3, name="next_sentence")
        ])
    pretrain_model = bert.instantiate_pretrainer_from_cfg(pretrain_cfg)
    ckpt = tf.train.Checkpoint(
        model=pretrain_model, **pretrain_model.checkpoint_items)
    saved_path = ckpt.save(self.get_temp_dir())

    config = question_answering.QuestionAnsweringConfig(
        init_checkpoint=saved_path,
        model=question_answering.ModelConfig(encoder=self._encoder_config),
        train_data=self._train_data_config,
        validation_data=self._get_validation_data_config(
            version_2_with_negative))
    self._run_task(config)
Code example #13
    def test_task(self):
        config = sentence_prediction.SentencePredictionConfig(
            network=bert.BertPretrainerConfig(
                encoders.TransformerEncoderConfig(vocab_size=30522,
                                                  num_layers=1),
                num_masked_tokens=0,
                cls_heads=[
                    bert.ClsHeadConfig(inner_dim=10,
                                       num_classes=3,
                                       name="sentence_prediction")
                ]),
            train_data=bert.BertSentencePredictionDataConfig(
                input_path="dummy", seq_length=128, global_batch_size=1))
        task = sentence_prediction.SentencePredictionTask(config)
        model = task.build_model()
        metrics = task.build_metrics()
        dataset = task.build_inputs(config.train_data)

        iterator = iter(dataset)
        optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
        task.train_step(next(iterator), model, optimizer, metrics=metrics)
        task.validation_step(next(iterator), model, metrics=metrics)
Code example #14
File: masked_lm.py Project: xjx0524/models
import dataclasses

from official.core import base_task
from official.core import task_factory
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.configs import bert
from official.nlp.data import data_loader_factory


@dataclasses.dataclass
class MaskedLMConfig(cfg.TaskConfig):
  """The model config."""
  init_checkpoint: str = ''
  model: bert.BertPretrainerConfig = bert.BertPretrainerConfig(cls_heads=[
      bert.ClsHeadConfig(
          inner_dim=768, num_classes=2, dropout_rate=0.1, name='next_sentence')
  ])
  train_data: cfg.DataConfig = cfg.DataConfig()
  validation_data: cfg.DataConfig = cfg.DataConfig()


@task_factory.register_task_cls(MaskedLMConfig)
class MaskedLMTask(base_task.Task):
  """Mock task object for testing."""

  def build_model(self, params=None):
    params = params or self.task_config.model
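    # The excerpt is truncated here. A plausible continuation, following the
    # factory used elsewhere on this page (an assumption, not verbatim code):
    return bert.instantiate_bertpretrainer_from_cfg(params)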