Example #1
 def setUp(self):
     super(ModelBuilderTest, self).setUp()
     self.pretrainer_config = params.PretrainerModelParams(
         encoder=encoders.EncoderConfig(type='mobilebert'))
Example #2
class PretrainerConfig(base_config.Config):
  """Pretrainer configuration."""
  encoder: encoders.EncoderConfig = encoders.EncoderConfig()
  cls_heads: List[ClsHeadConfig] = dataclasses.field(default_factory=list)
  mlm_activation: str = "gelu"
  mlm_initializer_range: float = 0.02
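
For reference, a `PretrainerConfig` like the one above is normally built by picking an encoder through `EncoderConfig`, as Examples #1 and #10 do. Below is a minimal sketch, assuming the Model Garden module paths `official.nlp.configs.bert` and `official.nlp.configs.encoders`; the head dimensions are illustrative only.

from official.nlp.configs import bert
from official.nlp.configs import encoders

# Pretrainer config with a MobileBERT encoder and a next-sentence head,
# mirroring the patterns used in Examples #1 and #10.
pretrainer_config = bert.PretrainerConfig(
    encoder=encoders.EncoderConfig(type='mobilebert'),
    cls_heads=[
        bert.ClsHeadConfig(inner_dim=256, num_classes=2, name='next_sentence')
    ],
    mlm_activation='gelu')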
Example #3
 def get_model_config(self, num_classes):
     return sentence_prediction.ModelConfig(encoder=encoders.EncoderConfig(
         bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1)),
                                            num_classes=num_classes)
Example #4
 def setUp(self):
   super(TaggingTest, self).setUp()
   self._encoder_config = encoders.EncoderConfig(
       bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1))
   self._train_data_config = tagging_data_loader.TaggingDataConfig(
       input_path="dummy", seq_length=128, global_batch_size=1)
Example #5
    def test_distribution_strategy(self, distribution_strategy):
        max_seq_length = 128
        batch_size = 8
        input_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
        _create_fake_dataset(input_path,
                             seq_length=60,
                             num_masked_tokens=20,
                             max_seq_length=max_seq_length,
                             num_examples=batch_size)
        # Dynamic padding: bucket sequences by length instead of padding to max.
        data_config = pretrain_dynamic_dataloader.BertPretrainDataConfig(
            is_training=False,
            input_path=input_path,
            seq_bucket_lengths=[64, 128],
            global_batch_size=batch_size)
        dataloader = pretrain_dynamic_dataloader.PretrainingDynamicDataLoader(
            data_config)
        distributed_ds = orbit.utils.make_distributed_dataset(
            distribution_strategy, dataloader.load)
        train_iter = iter(distributed_ds)
        with distribution_strategy.scope():
            config = masked_lm.MaskedLMConfig(
                init_checkpoint=self.get_temp_dir(),
                model=bert.PretrainerConfig(
                    encoders.EncoderConfig(bert=encoders.BertEncoderConfig(
                        vocab_size=30522, num_layers=1)),
                    cls_heads=[
                        bert.ClsHeadConfig(inner_dim=10,
                                           num_classes=2,
                                           name='next_sentence')
                    ]),
                train_data=data_config)
            task = masked_lm.MaskedLMTask(config)
            model = task.build_model()
            metrics = task.build_metrics()

        @tf.function
        def step_fn(features):
            return task.validation_step(features, model, metrics=metrics)

        distributed_outputs = distribution_strategy.run(
            step_fn, args=(next(train_iter), ))
        local_results = tf.nest.map_structure(
            distribution_strategy.experimental_local_results,
            distributed_outputs)
        logging.info('Dynamic padding:  local_results= %s', str(local_results))
        dynamic_metrics = {}
        for metric in metrics:
            dynamic_metrics[metric.name] = metric.result()

        # Static padding: the standard dataloader pads every example to max_seq_length.
        data_config = pretrain_dataloader.BertPretrainDataConfig(
            is_training=False,
            input_path=input_path,
            seq_length=max_seq_length,
            max_predictions_per_seq=20,
            global_batch_size=batch_size)
        dataloader = pretrain_dataloader.BertPretrainDataLoader(data_config)
        distributed_ds = orbit.utils.make_distributed_dataset(
            distribution_strategy, dataloader.load)
        train_iter = iter(distributed_ds)
        with distribution_strategy.scope():
            metrics = task.build_metrics()

        @tf.function
        def step_fn_b(features):
            return task.validation_step(features, model, metrics=metrics)

        distributed_outputs = distribution_strategy.run(
            step_fn_b, args=(next(train_iter), ))
        local_results = tf.nest.map_structure(
            distribution_strategy.experimental_local_results,
            distributed_outputs)
        logging.info('Static padding:  local_results= %s', str(local_results))
        static_metrics = {}
        for metric in metrics:
            static_metrics[metric.name] = metric.result()
        for key in static_metrics:
            # We still need to investigate the differences in the loss values.
            if key != 'next_sentence_loss':
                self.assertEqual(dynamic_metrics[key], static_metrics[key])
Example #6
class ModelConfig(base_config.Config):
    """A base span labeler configuration."""
    encoder: encoders.EncoderConfig = encoders.EncoderConfig()
Example #7
class ModelConfig(base_config.Config):
    """A classifier/regressor configuration."""
    num_classes: int = 0
    use_encoder_pooler: bool = False
    encoder: encoders.EncoderConfig = encoders.EncoderConfig()
Example #8
class ModelConfig(base_config.Config):
    """A base span labeler configuration."""
    encoder: encoders.EncoderConfig = encoders.EncoderConfig()
    head_dropout: float = 0.1
    head_initializer_range: float = 0.02
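
For context, an `EncoderConfig` like the ones above is typically materialized into a Keras encoder network with `encoders.build_encoder`, as Example #10 does for its teacher model. A minimal sketch, reusing the one-layer BERT settings from the tests above:

from official.nlp.configs import encoders

# Small BERT encoder config, as used in the test examples above.
encoder_config = encoders.EncoderConfig(
    bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1))
# Build the actual encoder network from the config.
encoder_network = encoders.build_encoder(encoder_config)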
Example #9
 def get_model_config(self):
   return dual_encoder.ModelConfig(
       max_sequence_length=32,
       encoder=encoders.EncoderConfig(
           bert=encoders.BertEncoderConfig(vocab_size=30522, num_layers=1)))
Example #10
  def prepare_config(self, teacher_block_num, student_block_num,
                     transfer_teacher_layers):
    # Use a small model for testing.
    task_config = distillation.BertDistillationTaskConfig(
        teacher_model=bert.PretrainerConfig(
            encoder=encoders.EncoderConfig(
                type='mobilebert',
                mobilebert=encoders.MobileBertEncoderConfig(
                    num_blocks=teacher_block_num)),
            cls_heads=[
                bert.ClsHeadConfig(
                    inner_dim=256,
                    num_classes=2,
                    dropout_rate=0.1,
                    name='next_sentence')
            ],
            mlm_activation='gelu'),
        student_model=bert.PretrainerConfig(
            encoder=encoders.EncoderConfig(
                type='mobilebert',
                mobilebert=encoders.MobileBertEncoderConfig(
                    num_blocks=student_block_num)),
            cls_heads=[
                bert.ClsHeadConfig(
                    inner_dim=256,
                    num_classes=2,
                    dropout_rate=0.1,
                    name='next_sentence')
            ],
            mlm_activation='relu'),
        train_data=pretrain_dataloader.BertPretrainDataConfig(
            input_path='dummy',
            max_predictions_per_seq=76,
            seq_length=512,
            global_batch_size=10),
        validation_data=pretrain_dataloader.BertPretrainDataConfig(
            input_path='dummy',
            max_predictions_per_seq=76,
            seq_length=512,
            global_batch_size=10))

    # set only 1 step for each stage
    progressive_config = distillation.BertDistillationProgressiveConfig()
    progressive_config.layer_wise_distill_config.transfer_teacher_layers = (
        transfer_teacher_layers)
    progressive_config.layer_wise_distill_config.num_steps = 1
    progressive_config.pretrain_distill_config.num_steps = 1

    optimization_config = optimization.OptimizationConfig(
        optimizer=optimization.OptimizerConfig(
            type='lamb',
            lamb=optimization.LAMBConfig(
                weight_decay_rate=0.0001,
                exclude_from_weight_decay=[
                    'LayerNorm', 'layer_norm', 'bias', 'no_norm'
                ])),
        learning_rate=optimization.LrConfig(
            type='polynomial',
            polynomial=optimization.PolynomialLrConfig(
                initial_learning_rate=1.5e-3,
                decay_steps=10000,
                end_learning_rate=1.5e-3)),
        warmup=optimization.WarmupConfig(
            type='linear',
            linear=optimization.LinearWarmupConfig(warmup_learning_rate=0)))

    exp_config = cfg.ExperimentConfig(
        task=task_config,
        trainer=prog_trainer_lib.ProgressiveTrainerConfig(
            progressive=progressive_config,
            optimizer_config=optimization_config))

    # Create a teacher model checkpoint.
    teacher_encoder = encoders.build_encoder(task_config.teacher_model.encoder)
    pretrainer_config = task_config.teacher_model
    if pretrainer_config.cls_heads:
      teacher_cls_heads = [
          layers.ClassificationHead(**cfg.as_dict())
          for cfg in pretrainer_config.cls_heads
      ]
    else:
      teacher_cls_heads = []

    masked_lm = layers.MobileBertMaskedLM(
        embedding_table=teacher_encoder.get_embedding_table(),
        activation=tf_utils.get_activation(pretrainer_config.mlm_activation),
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=pretrainer_config.mlm_initializer_range),
        name='cls/predictions')
    teacher_pretrainer = models.BertPretrainerV2(
        encoder_network=teacher_encoder,
        classification_heads=teacher_cls_heads,
        customized_masked_lm=masked_lm)

    # The model variables will be created after the forward call.
    _ = teacher_pretrainer(teacher_pretrainer.inputs)
    teacher_pretrainer_ckpt = tf.train.Checkpoint(
        **teacher_pretrainer.checkpoint_items)
    teacher_ckpt_path = os.path.join(self.get_temp_dir(), 'teacher_model.ckpt')
    teacher_pretrainer_ckpt.save(teacher_ckpt_path)
    exp_config.task.teacher_model_init_checkpoint = self.get_temp_dir()

    return exp_config
Example #11
    def test_sentence_prediction(self, use_v2_feature_names):
        if use_v2_feature_names:
            input_word_ids_field = "input_word_ids"
            input_type_ids_field = "input_type_ids"
        else:
            input_word_ids_field = "input_ids"
            input_type_ids_field = "segment_ids"

        config = sentence_prediction.SentencePredictionConfig(
            model=sentence_prediction.ModelConfig(
                encoder=encoders.EncoderConfig(bert=encoders.BertEncoderConfig(
                    vocab_size=30522, num_layers=1)),
                num_classes=2))
        task = sentence_prediction.SentencePredictionTask(config)
        model = task.build_model()
        # Export with inputs_only=True: the serving signature takes only word ids.
        params = serving_modules.SentencePrediction.Params(
            inputs_only=True,
            parse_sequence_length=10,
            use_v2_feature_names=use_v2_feature_names)
        export_module = serving_modules.SentencePrediction(params=params,
                                                           model=model)
        functions = export_module.get_inference_signatures({
            "serve": "serving_default",
            "serve_examples": "serving_examples"
        })
        self.assertSameElements(functions.keys(),
                                ["serving_default", "serving_examples"])
        dummy_ids = tf.ones((10, 10), dtype=tf.int32)
        outputs = functions["serving_default"](dummy_ids)
        self.assertEqual(outputs["outputs"].shape, (10, 2))

        # Re-export with inputs_only=False: the signature takes word ids, mask and type ids.
        params = serving_modules.SentencePrediction.Params(
            inputs_only=False,
            parse_sequence_length=10,
            use_v2_feature_names=use_v2_feature_names)
        export_module = serving_modules.SentencePrediction(params=params,
                                                           model=model)
        functions = export_module.get_inference_signatures({
            "serve": "serving_default",
            "serve_examples": "serving_examples"
        })
        outputs = functions["serving_default"](input_word_ids=dummy_ids,
                                               input_mask=dummy_ids,
                                               input_type_ids=dummy_ids)
        self.assertEqual(outputs["outputs"].shape, (10, 2))

        dummy_ids = tf.ones((10, ), dtype=tf.int32)
        examples = _create_fake_serialized_examples({
            input_word_ids_field: dummy_ids,
            "input_mask": dummy_ids,
            input_type_ids_field: dummy_ids
        })
        outputs = functions["serving_examples"](examples)
        self.assertEqual(outputs["outputs"].shape, (10, 2))

        with self.assertRaises(ValueError):
            _ = export_module.get_inference_signatures({"foo": None})
Example #12
def load_model_config_file(model_config_file: str) -> Dict[str, Any]:
    """Loads bert config json file or `encoders.EncoderConfig` in yaml file."""
    if not model_config_file:
        # model_config_file may be empty when using tf.hub.
        return {}

    try:
        encoder_config = encoders.EncoderConfig()
        encoder_config = hyperparams.override_params_dict(encoder_config,
                                                          model_config_file,
                                                          is_strict=True)
        logging.info('Load encoder_config yaml file from %s.',
                     model_config_file)
        return encoder_config.as_dict()
    except KeyError:
        pass

    logging.info('Load bert config json file from %s', model_config_file)
    with tf.io.gfile.GFile(model_config_file, 'r') as reader:
        text = reader.read()
        config = json.loads(text)

    def get_value(key1, key2):
        if key1 in config and key2 in config:
            raise ValueError('Unexpected that both %s and %s are in config.' %
                             (key1, key2))

        return config[key1] if key1 in config else config[key2]

    def get_value_or_none(key):
        return config[key] if key in config else None

    # Support both legacy bert_config attributes and the new config attributes.
    return {
        'bert': {
            'attention_dropout_rate':
                get_value('attention_dropout_rate',
                          'attention_probs_dropout_prob'),
            'dropout_rate': get_value('dropout_rate', 'hidden_dropout_prob'),
            'hidden_activation': get_value('hidden_activation', 'hidden_act'),
            'hidden_size': config['hidden_size'],
            'embedding_size': get_value_or_none('embedding_size'),
            'initializer_range': config['initializer_range'],
            'intermediate_size': config['intermediate_size'],
            'max_position_embeddings': config['max_position_embeddings'],
            'num_attention_heads': config['num_attention_heads'],
            'num_layers': get_value('num_layers', 'num_hidden_layers'),
            'type_vocab_size': config['type_vocab_size'],
            'vocab_size': config['vocab_size'],
        }
    }
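
A hedged usage sketch for the loader above: the returned dict can be applied onto a fresh `EncoderConfig` (mirroring the `hyperparams.override_params_dict` call inside the function) and then built into an encoder network with `encoders.build_encoder`, as in Example #10. The file path below is a placeholder.

# Placeholder path; the file may be either a legacy BERT config JSON or an
# `encoders.EncoderConfig` YAML, both handled by load_model_config_file.
cfg_dict = load_model_config_file('/tmp/bert_config.json')

encoder_config = encoders.EncoderConfig()
if cfg_dict:
    # Apply the loaded values onto the default config.
    encoder_config.override(cfg_dict, is_strict=True)
encoder_network = encoders.build_encoder(encoder_config)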