Example #1
# Imports assumed from the TensorFlow Model Garden (official/nlp):
import tensorflow as tf

from official.modeling import tf_utils
from official.nlp.configs import electra
from official.nlp.configs import encoders
from official.nlp.modeling import layers
from official.nlp.modeling import models

def _build_pretrainer(
        config: electra.ElectraPretrainerConfig) -> models.ElectraPretrainer:
    """Instantiates ElectraPretrainer from the config."""
    generator_encoder_cfg = config.generator_encoder
    discriminator_encoder_cfg = config.discriminator_encoder
    # Build the discriminator first; when config.tie_embeddings is set, its
    # embedding layer is shared with the generator for easier model
    # serialization.
    discriminator_network = encoders.build_encoder(discriminator_encoder_cfg)
    if config.tie_embeddings:
        embedding_layer = discriminator_network.get_embedding_layer()
        generator_network = encoders.build_encoder(
            generator_encoder_cfg, embedding_layer=embedding_layer)
    else:
        generator_network = encoders.build_encoder(generator_encoder_cfg)

    # Resolve the oneof encoder config to the concrete encoder settings.
    generator_encoder_cfg = generator_encoder_cfg.get()
    return models.ElectraPretrainer(
        generator_network=generator_network,
        discriminator_network=discriminator_network,
        vocab_size=generator_encoder_cfg.vocab_size,
        num_classes=config.num_classes,
        sequence_length=config.sequence_length,
        num_token_predictions=config.num_masked_tokens,
        mlm_activation=tf_utils.get_activation(
            generator_encoder_cfg.hidden_activation),
        mlm_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=generator_encoder_cfg.initializer_range),
        classification_heads=[
            layers.ClassificationHead(**cfg.as_dict())
            for cfg in config.cls_heads
        ],
        disallow_correct=config.disallow_correct)
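
A minimal usage sketch for the builder above, assuming the Model Garden's
encoders.EncoderConfig / encoders.BertEncoderConfig oneof configs; the tiny
layer counts are illustrative only:

# Hypothetical small configs, for illustration only.
config = electra.ElectraPretrainerConfig(
    generator_encoder=encoders.EncoderConfig(
        bert=encoders.BertEncoderConfig(num_layers=1)),
    discriminator_encoder=encoders.EncoderConfig(
        bert=encoders.BertEncoderConfig(num_layers=2)))
pretrainer = _build_pretrainer(config)
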
Example #2
# Imports assumed from an older TensorFlow Model Garden layout (official/nlp):
from typing import Optional

import tensorflow as tf

from official.modeling import tf_utils
from official.nlp.configs import encoders
from official.nlp.configs.bert import BertPretrainerConfig
from official.nlp.modeling import layers
from official.nlp.modeling import networks
from official.nlp.modeling.models import bert_pretrainer

def instantiate_from_cfg(config: BertPretrainerConfig,
                         encoder_network: Optional[tf.keras.Model] = None):
    """Instantiates a BertPretrainer from the config."""
    encoder_cfg = config.encoder
    if encoder_network is None:
        encoder_network = networks.TransformerEncoder(
            vocab_size=encoder_cfg.vocab_size,
            hidden_size=encoder_cfg.hidden_size,
            num_layers=encoder_cfg.num_layers,
            num_attention_heads=encoder_cfg.num_attention_heads,
            intermediate_size=encoder_cfg.intermediate_size,
            activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
            dropout_rate=encoder_cfg.dropout_rate,
            attention_dropout_rate=encoder_cfg.attention_dropout_rate,
            max_sequence_length=encoder_cfg.max_position_embeddings,
            type_vocab_size=encoder_cfg.type_vocab_size,
            initializer=tf.keras.initializers.TruncatedNormal(
                stddev=encoder_cfg.initializer_range))
    if config.cls_heads:
        classification_heads = [
            layers.ClassificationHead(**cfg.as_dict())
            for cfg in config.cls_heads
        ]
    else:
        classification_heads = []
    return bert_pretrainer.BertPretrainerV2(
        config.num_masked_tokens,
        mlm_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        encoder_network=encoder_network,
        classification_heads=classification_heads)
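
A minimal usage sketch, assuming the flat encoders.TransformerEncoderConfig
from the same era of the Model Garden configs; field values are illustrative:

# Hypothetical tiny config, for illustration only.
config = BertPretrainerConfig(
    num_masked_tokens=20,
    encoder=encoders.TransformerEncoderConfig(num_layers=2))
pretrainer = instantiate_from_cfg(config)
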
Example #3
 # Method of a Model Garden-style pretraining task; assumes the same imports
 # as the examples above (tf, tf_utils, encoders, layers, models) and a
 # task_config whose model field holds the pretrainer config.
 def build_model(self, params=None):
     """Builds a BertPretrainerV2 from params or the task's model config."""
     config = params or self.task_config.model
     encoder_cfg = config.encoder
     encoder_network = encoders.build_encoder(encoder_cfg)
     cls_heads = [
         layers.ClassificationHead(**cfg.as_dict())
         for cfg in config.cls_heads
     ] if config.cls_heads else []
     return models.BertPretrainerV2(
         mlm_activation=tf_utils.get_activation(config.mlm_activation),
         mlm_initializer=tf.keras.initializers.TruncatedNormal(
             stddev=config.mlm_initializer_range),
         encoder_network=encoder_network,
         classification_heads=cls_heads)
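
A minimal call sketch, assuming this method sits on a task class whose
task_config.model carries the mlm_activation and mlm_initializer_range
fields used above; SomePretrainingTask and alternate_model_config are
hypothetical stand-ins:

task = SomePretrainingTask(task_config)  # hypothetical task class
model = task.build_model()               # falls back to task_config.model
model = task.build_model(params=alternate_model_config)  # explicit override
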
Example #4
# Imports assumed from the TensorFlow Model Garden (official/nlp):
from typing import List

from official.nlp.configs.bert import ClsHeadConfig
from official.nlp.modeling import layers

def instantiate_classification_heads_from_cfgs(
        cls_head_configs: List[ClsHeadConfig]
) -> List[layers.ClassificationHead]:
    """Builds one ClassificationHead per config, or [] if none are given."""
    return [
        layers.ClassificationHead(**cfg.as_dict()) for cfg in cls_head_configs
    ] if cls_head_configs else []
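
A minimal usage sketch, assuming ClsHeadConfig exposes inner_dim, num_classes,
and name fields as in the Model Garden's bert configs; values are illustrative:

heads = instantiate_classification_heads_from_cfgs([
    ClsHeadConfig(inner_dim=768, num_classes=2, name='next_sentence'),
])
assert len(heads) == 1
assert instantiate_classification_heads_from_cfgs([]) == []
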