from typing import List, Optional

import tensorflow as tf

# Module paths below follow the tf-models-official package layout; adjust if
# your checkout differs.
from official.modeling import tf_utils
from official.nlp.configs import electra
from official.nlp.configs import encoders
from official.nlp.modeling import layers
from official.nlp.modeling import models
from official.nlp.modeling import networks
from official.nlp.modeling.models import bert_pretrainer


def _build_pretrainer(
    config: electra.ElectraPretrainerConfig) -> models.ElectraPretrainer:
  """Instantiates ElectraPretrainer from the config."""
  generator_encoder_cfg = config.generator_encoder
  discriminator_encoder_cfg = config.discriminator_encoder
  discriminator_network = encoders.build_encoder(discriminator_encoder_cfg)
  # Reuse the discriminator's embeddings in the generator for easier model
  # serialization.
  if config.tie_embeddings:
    embedding_layer = discriminator_network.get_embedding_layer()
    generator_network = encoders.build_encoder(
        generator_encoder_cfg, embedding_layer=embedding_layer)
  else:
    generator_network = encoders.build_encoder(generator_encoder_cfg)

  generator_encoder_cfg = generator_encoder_cfg.get()
  return models.ElectraPretrainer(
      generator_network=generator_network,
      discriminator_network=discriminator_network,
      vocab_size=generator_encoder_cfg.vocab_size,
      num_classes=config.num_classes,
      sequence_length=config.sequence_length,
      num_token_predictions=config.num_masked_tokens,
      mlm_activation=tf_utils.get_activation(
          generator_encoder_cfg.hidden_activation),
      mlm_initializer=tf.keras.initializers.TruncatedNormal(
          stddev=generator_encoder_cfg.initializer_range),
      classification_heads=[
          layers.ClassificationHead(**cfg.as_dict())
          for cfg in config.cls_heads
      ],
      disallow_correct=config.disallow_correct)
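
# A minimal usage sketch for _build_pretrainer. The field values below are
# illustrative assumptions, not project defaults; only the field names that
# _build_pretrainer reads above are taken from the config. The encoder
# sub-configs (generator_encoder, discriminator_encoder) are left at their
# defaults here.
def _example_build_electra_pretrainer() -> models.ElectraPretrainer:
  config = electra.ElectraPretrainerConfig(
      num_classes=2,          # Replaced-token detection is a binary task.
      sequence_length=128,    # Assumed input length.
      num_masked_tokens=20,   # Assumed number of MLM predictions per example.
      tie_embeddings=True,    # Share embeddings across the two towers.
      disallow_correct=False)
  return _build_pretrainer(config)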

def instantiate_from_cfg(config: BertPretrainerConfig,
                         encoder_network: Optional[tf.keras.Model] = None):
  """Instantiates a BertPretrainer from the config."""
  encoder_cfg = config.encoder
  if encoder_network is None:
    encoder_network = networks.TransformerEncoder(
        vocab_size=encoder_cfg.vocab_size,
        hidden_size=encoder_cfg.hidden_size,
        num_layers=encoder_cfg.num_layers,
        num_attention_heads=encoder_cfg.num_attention_heads,
        intermediate_size=encoder_cfg.intermediate_size,
        activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
        dropout_rate=encoder_cfg.dropout_rate,
        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
        max_sequence_length=encoder_cfg.max_position_embeddings,
        type_vocab_size=encoder_cfg.type_vocab_size,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range))
  if config.cls_heads:
    classification_heads = [
        layers.ClassificationHead(**cfg.as_dict())
        for cfg in config.cls_heads
    ]
  else:
    classification_heads = []
  return bert_pretrainer.BertPretrainerV2(
      config.num_masked_tokens,
      mlm_initializer=tf.keras.initializers.TruncatedNormal(
          stddev=encoder_cfg.initializer_range),
      encoder_network=encoder_network,
      classification_heads=classification_heads)
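
# A usage sketch for instantiate_from_cfg. The num_masked_tokens value is an
# illustrative assumption. With encoder_network=None the encoder is built from
# config.encoder; passing a prebuilt encoder (e.g. one restored from a
# checkpoint) reuses it instead, which is how weights can be shared across
# models.
def _example_instantiate_bert_pretrainer(
    prebuilt_encoder: Optional[tf.keras.Model] = None):
  config = BertPretrainerConfig(num_masked_tokens=76)  # Assumed value.
  return instantiate_from_cfg(config, encoder_network=prebuilt_encoder)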

def build_model(self, params=None):
  """Builds a BertPretrainerV2 from `params` or the task's model config."""
  config = params or self.task_config.model
  encoder_cfg = config.encoder
  encoder_network = encoders.build_encoder(encoder_cfg)
  cls_heads = [
      layers.ClassificationHead(**cfg.as_dict()) for cfg in config.cls_heads
  ] if config.cls_heads else []
  return models.BertPretrainerV2(
      mlm_activation=tf_utils.get_activation(config.mlm_activation),
      mlm_initializer=tf.keras.initializers.TruncatedNormal(
          stddev=config.mlm_initializer_range),
      encoder_network=encoder_network,
      classification_heads=cls_heads)
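
# build_model is written as a task method: it is assumed to live on a task
# class whose `task_config.model` carries the default pretrainer config. The
# helper below is hypothetical and only shows the two call forms.
def _example_build_model(task, override_cfg=None):
  """Builds the task model, optionally overriding the stored config."""
  if override_cfg is not None:
    return task.build_model(params=override_cfg)  # Explicit config wins.
  return task.build_model()  # Falls back to task.task_config.model.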

def instantiate_classification_heads_from_cfgs(
    cls_head_configs: List[ClsHeadConfig]
) -> List[layers.ClassificationHead]:
  """Instantiates classification heads from a list of head configs."""
  return [
      layers.ClassificationHead(**cfg.as_dict())
      for cfg in cls_head_configs
  ] if cls_head_configs else []
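
# A usage sketch for instantiate_classification_heads_from_cfgs. The
# ClsHeadConfig field names below (inner_dim, num_classes, name) mirror the
# keyword arguments layers.ClassificationHead accepts; treat them as
# assumptions if your ClsHeadConfig differs.
def _example_classification_heads() -> List[layers.ClassificationHead]:
  head_cfgs = [
      ClsHeadConfig(inner_dim=768, num_classes=2, name='next_sentence'),
  ]
  return instantiate_classification_heads_from_cfgs(head_cfgs)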