Code Example #1
    def test_network_invocation(self):
        config = bert.BertPretrainerConfig(
            encoder=encoders.TransformerEncoderConfig(vocab_size=10,
                                                      num_layers=1))
        _ = bert.instantiate_bertpretrainer_from_cfg(config)

        # Invokes with classification heads.
        config = bert.BertPretrainerConfig(
            encoder=encoders.TransformerEncoderConfig(vocab_size=10,
                                                      num_layers=1),
            cls_heads=[
                bert.ClsHeadConfig(inner_dim=10,
                                   num_classes=2,
                                   name="next_sentence")
            ])
        _ = bert.instantiate_bertpretrainer_from_cfg(config)

        with self.assertRaises(ValueError):
            config = bert.BertPretrainerConfig(
                encoder=encoders.TransformerEncoderConfig(vocab_size=10,
                                                          num_layers=1),
                cls_heads=[
                    bert.ClsHeadConfig(inner_dim=10,
                                       num_classes=2,
                                       name="next_sentence"),
                    bert.ClsHeadConfig(inner_dim=10,
                                       num_classes=2,
                                       name="next_sentence")
                ])
            _ = bert.instantiate_bertpretrainer_from_cfg(config)
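
This test drives instantiate_bertpretrainer_from_cfg three ways: with a bare encoder config, with a single classification head, and with two heads that share the name "next_sentence", which is expected to raise a ValueError. The snippets on this page omit their imports; a minimal standalone version of the first case, assuming the TensorFlow Model Garden config modules (the exact module paths may differ between versions), could look like this:

from official.nlp.configs import bert
from official.nlp.configs import encoders

# Build a one-layer toy encoder config and instantiate the pretrainer.
config = bert.BertPretrainerConfig(
    encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1))
pretrainer = bert.instantiate_bertpretrainer_from_cfg(config)
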
Code Example #2
    def build_model(self):
        if self._hub_module:
            encoder_from_hub = utils.get_encoder_from_hub(self._hub_module)
            return bert.instantiate_bertpretrainer_from_cfg(
                self.task_config.model, encoder_network=encoder_from_hub)
        else:
            return bert.instantiate_bertpretrainer_from_cfg(
                self.task_config.model)
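
When a TF-Hub module is configured, the encoder is built from it with the utils.get_encoder_from_hub helper and handed to the pretrainer as encoder_network; otherwise the pretrainer builds its own encoder from the config. Code Example #6 below performs essentially the same hub-backed construction inline, without the helper.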
Code Example #3
    def test_task(self):
        # Saves a checkpoint.
        pretrain_cfg = bert.BertPretrainerConfig(encoder=self._encoder_config,
                                                 num_masked_tokens=20,
                                                 cls_heads=[
                                                     bert.ClsHeadConfig(
                                                         inner_dim=10,
                                                         num_classes=3,
                                                         name="next_sentence")
                                                 ])
        pretrain_model = bert.instantiate_bertpretrainer_from_cfg(pretrain_cfg)
        ckpt = tf.train.Checkpoint(model=pretrain_model,
                                   **pretrain_model.checkpoint_items)
        saved_path = ckpt.save(self.get_temp_dir())

        config = question_answering.QuestionAnsweringConfig(
            init_checkpoint=saved_path,
            network=self._encoder_config,
            train_data=self._train_data_config)
        task = question_answering.QuestionAnsweringTask(config)
        model = task.build_model()
        metrics = task.build_metrics()
        dataset = task.build_inputs(config.train_data)

        iterator = iter(dataset)
        optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
        task.train_step(next(iterator), model, optimizer, metrics=metrics)
        task.validation_step(next(iterator), model, metrics=metrics)
        task.initialize(model)
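
The warm-start pattern here: a pretrainer is built and saved through tf.train.Checkpoint together with its checkpoint_items, the resulting path is passed as init_checkpoint to QuestionAnsweringConfig, and task.initialize(model), presumably, restores the pretrained weights from that checkpoint. Note that in this test the train and validation steps run before initialize is exercised.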
Code Example #4
    def test_task(self):
        config = sentence_prediction.SentencePredictionConfig(
            init_checkpoint=self.get_temp_dir(),
            model=self.get_model_config(2),
            train_data=self._train_data_config)
        task = sentence_prediction.SentencePredictionTask(config)
        model = task.build_model()
        metrics = task.build_metrics()
        dataset = task.build_inputs(config.train_data)

        iterator = iter(dataset)
        optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
        task.train_step(next(iterator), model, optimizer, metrics=metrics)
        task.validation_step(next(iterator), model, metrics=metrics)

        # Saves a checkpoint.
        pretrain_cfg = bert.BertPretrainerConfig(
            encoder=encoders.TransformerEncoderConfig(vocab_size=30522,
                                                      num_layers=1),
            num_masked_tokens=20,
            cls_heads=[
                bert.ClsHeadConfig(inner_dim=10,
                                   num_classes=3,
                                   name="next_sentence")
            ])
        pretrain_model = bert.instantiate_bertpretrainer_from_cfg(pretrain_cfg)
        ckpt = tf.train.Checkpoint(model=pretrain_model,
                                   **pretrain_model.checkpoint_items)
        ckpt.save(config.init_checkpoint)
        task.initialize(model)
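
The same warm-start pattern as Code Example #3, inverted for a sentence-prediction task: init_checkpoint is set to a directory up front, one train step and one validation step are run, and only then is the pretraining checkpoint saved into that directory so that task.initialize(model) can load it.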
Code Example #5
File: bert_test.py (Project: tpsgrp/python-app)
    def test_checkpoint_items(self):
        config = bert.BertPretrainerConfig(
            encoder=encoders.TransformerEncoderConfig(vocab_size=10,
                                                      num_layers=1),
            cls_heads=[
                bert.ClsHeadConfig(inner_dim=10,
                                   num_classes=2,
                                   name="next_sentence")
            ])
        encoder = bert.instantiate_bertpretrainer_from_cfg(config)
        self.assertSameElements(encoder.checkpoint_items.keys(),
                                ["encoder", "next_sentence.pooler_dense"])
Code Example #6
    def build_model(self):
        if self._hub_module:
            input_word_ids = tf.keras.layers.Input(shape=(None,),
                                                   dtype=tf.int32,
                                                   name='input_word_ids')
            input_mask = tf.keras.layers.Input(shape=(None,),
                                               dtype=tf.int32,
                                               name='input_mask')
            input_type_ids = tf.keras.layers.Input(shape=(None,),
                                                   dtype=tf.int32,
                                                   name='input_type_ids')
            bert_model = hub.KerasLayer(self._hub_module, trainable=True)
            pooled_output, sequence_output = bert_model(
                [input_word_ids, input_mask, input_type_ids])
            encoder_from_hub = tf.keras.Model(
                inputs=[input_word_ids, input_mask, input_type_ids],
                outputs=[sequence_output, pooled_output])
            return bert.instantiate_bertpretrainer_from_cfg(
                self.task_config.network, encoder_network=encoder_from_hub)
        else:
            return bert.instantiate_bertpretrainer_from_cfg(
                self.task_config.network)
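
This inline variant assumes the hub layer returns (pooled_output, sequence_output) for the three int32 inputs, then wraps everything in a functional tf.keras.Model whose outputs are reordered to [sequence_output, pooled_output], presumably to match the output order the pretrainer expects from an encoder network. Also note the config is read from self.task_config.network here, while Code Example #2 reads self.task_config.model; the field name evidently differs between versions of the task config.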
Code Example #7
    def test_task(self, version_2_with_negative, tokenization):
        # Saves a checkpoint.
        pretrain_cfg = bert.BertPretrainerConfig(encoder=self._encoder_config,
                                                 num_masked_tokens=20,
                                                 cls_heads=[
                                                     bert.ClsHeadConfig(
                                                         inner_dim=10,
                                                         num_classes=3,
                                                         name="next_sentence")
                                                 ])
        pretrain_model = bert.instantiate_bertpretrainer_from_cfg(pretrain_cfg)
        ckpt = tf.train.Checkpoint(model=pretrain_model,
                                   **pretrain_model.checkpoint_items)
        saved_path = ckpt.save(self.get_temp_dir())

        config = question_answering.QuestionAnsweringConfig(
            init_checkpoint=saved_path,
            model=self._encoder_config,
            train_data=self._train_data_config,
            validation_data=self._get_validation_data_config(
                version_2_with_negative))
        self._run_task(config)
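
A parameterized variant of Code Example #3: the version_2_with_negative and tokenization arguments suggest the test runs across SQuAD-style v1/v2 and tokenizer settings, with the shared build-and-train logic factored into self._run_task. The encoder config is passed as model= here, where Code Example #3 passes it as network=; as with Examples #2 and #6, the field name varies across versions.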
Code Example #8
    def build_model(self):
        return bert.instantiate_bertpretrainer_from_cfg(self.task_config.model)
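
The minimal case: with no hub module involved, build_model simply instantiates the pretrainer from the task's model config. A short sketch of driving a model built this way, stitched together from the calls shown in Code Examples #3 and #4 (task and config stand for an already constructed task and its config):

import tensorflow as tf

model = task.build_model()
metrics = task.build_metrics()
iterator = iter(task.build_inputs(config.train_data))
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

# One optimization step and one evaluation step, as in the tests above.
task.train_step(next(iterator), model, optimizer, metrics=metrics)
task.validation_step(next(iterator), model, metrics=metrics)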