def get_pretraining_model(model_name, ctx_l,
                          max_seq_length=128,
                          hidden_dropout_prob=0.1,
                          attention_dropout_prob=0.1,
                          generator_units_scale=None,
                          generator_layers_scale=None):
    """
    An ELECTRA pretraining model is built with a generator and a discriminator,
    in which the generator has the same embeddings as the discriminator
    but a different backbone.
    """
    cfg, tokenizer, _, _ = get_pretrained_electra(model_name, load_backbone=False)
    cfg = ElectraModel.get_cfg().clone_merge(cfg)
    cfg.defrost()
    cfg.MODEL.hidden_dropout_prob = hidden_dropout_prob
    cfg.MODEL.attention_dropout_prob = attention_dropout_prob
    cfg.MODEL.max_length = max_seq_length
    # Keep the original generator size if not designated
    if generator_layers_scale:
        cfg.MODEL.generator_layers_scale = generator_layers_scale
    if generator_units_scale:
        cfg.MODEL.generator_units_scale = generator_units_scale
    cfg.freeze()

    model = ElectraForPretrain(cfg,
                               uniform_generator=False,
                               tied_generator=False,
                               tied_embeddings=True,
                               disallow_correct=False,
                               weight_initializer=TruncNorm(stdev=0.02))
    model.initialize(ctx=ctx_l)
    model.hybridize()
    return cfg, tokenizer, model
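# Minimal usage sketch (assumptions: the names used above — ElectraModel, ElectraForPretrain,
# get_pretrained_electra, TruncNorm — are imported at the top of the enclosing module, e.g.
# from gluonnlp.models.electra and gluonnlp.initializer, and 'google_electra_small' is one of
# the available pretrained model names):
#
#   import mxnet as mx
#   cfg, tokenizer, model = get_pretraining_model('google_electra_small',
#                                                 ctx_l=[mx.cpu()],
#                                                 max_seq_length=128)
#   # `model` is an initialized, hybridized ElectraForPretrain whose generator shares its
#   # embeddings with the discriminator, ready for the replaced-token-detection objective.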
def test_electra_get_pretrained(model_name, ctx):
    assert len(list_pretrained_electra()) > 0
    with tempfile.TemporaryDirectory() as root, ctx:
        cfg, tokenizer, backbone_params_path, (disc_params_path, gen_params_path) = \
            get_pretrained_electra(model_name, root=root,
                                   load_backbone=True, load_disc=True, load_gen=True)
        assert cfg.MODEL.vocab_size == len(tokenizer.vocab)
        # The plain backbone can be restored from the backbone checkpoint.
        electra_model = ElectraModel.from_cfg(cfg)
        electra_model.load_parameters(backbone_params_path)

        # The discriminator loads either its full checkpoint or just the backbone weights.
        electra_disc_model = ElectraDiscriminator(cfg)
        electra_disc_model.load_parameters(disc_params_path)
        electra_disc_model = ElectraDiscriminator(cfg)
        electra_disc_model.backbone_model.load_parameters(backbone_params_path)

        # The generator is built from the reduced generator config and ties its
        # embedding layers to the discriminator's.
        gen_cfg = get_generator_cfg(cfg)
        electra_gen_model = ElectraGenerator(gen_cfg)
        electra_gen_model.load_parameters(gen_params_path)
        electra_gen_model.tie_embeddings(
            electra_disc_model.backbone_model.word_embed.collect_params(),
            electra_disc_model.backbone_model.token_type_embed.collect_params(),
            electra_disc_model.backbone_model.token_pos_embed.collect_params(),
            electra_disc_model.backbone_model.embed_layer_norm.collect_params())

        # A generator built from the full config can also load the backbone weights directly.
        electra_gen_model = ElectraGenerator(cfg)
        electra_gen_model.backbone_model.load_parameters(backbone_params_path)
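# Invocation sketch (assumption: `model_name` and `ctx` are supplied by the surrounding test
# module, e.g. via @pytest.mark.parametrize('model_name', list_pretrained_electra()) and an
# MXNet context fixture; downloading the pretrained checkpoints requires network access):
#
#   pytest -k test_electra_get_pretrained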