Example #1
0
def get_pretraining_model(model_name,
                          ctx_l,
                          max_seq_length=128,
                          hidden_dropout_prob=0.1,
                          attention_dropout_prob=0.1,
                          generator_units_scale=None,
                          generator_layers_scale=None):
    """Build an ELECTRA pretraining model (generator + discriminator).

    The generator shares its embeddings with the discriminator but has its
    own (typically smaller) transformer backbone.

    Parameters
    ----------
    model_name
        Name of the pretrained ELECTRA variant whose config/tokenizer to use.
    ctx_l
        Context (or list of contexts) to initialize the parameters on.
    max_seq_length
        Maximum input sequence length written into the config.
    hidden_dropout_prob
        Dropout probability for hidden states.
    attention_dropout_prob
        Dropout probability for attention weights.
    generator_units_scale
        Optional override of the generator width scale; ``None`` keeps the
        value from the pretrained config.
    generator_layers_scale
        Optional override of the generator depth scale; ``None`` keeps the
        value from the pretrained config.

    Returns
    -------
    cfg, tokenizer, model
        The frozen config, the tokenizer, and the initialized, hybridized
        ``ElectraForPretrain`` model.
    """
    cfg, tokenizer, _, _ = get_pretrained_electra(model_name,
                                                  load_backbone=False)
    cfg = ElectraModel.get_cfg().clone_merge(cfg)
    cfg.defrost()
    cfg.MODEL.hidden_dropout_prob = hidden_dropout_prob
    cfg.MODEL.attention_dropout_prob = attention_dropout_prob
    cfg.MODEL.max_length = max_seq_length
    # Keep the original generator sizes unless explicitly overridden.
    # `is not None` (rather than truthiness) so an explicitly passed value
    # is never silently dropped.
    if generator_layers_scale is not None:
        cfg.MODEL.generator_layers_scale = generator_layers_scale
    if generator_units_scale is not None:
        cfg.MODEL.generator_units_scale = generator_units_scale
    cfg.freeze()

    model = ElectraForPretrain(cfg,
                               uniform_generator=False,
                               tied_generator=False,
                               tied_embeddings=True,
                               disallow_correct=False,
                               weight_initializer=TruncNorm(stdev=0.02))
    model.initialize(ctx=ctx_l)
    model.hybridize()
    return cfg, tokenizer, model
Example #2
0
def test_electra_get_pretrained(model_name, ctx):
    """Smoke-test loading of pretrained ELECTRA parameter files.

    Fetches the backbone, discriminator, and generator weights into a
    temporary directory and verifies each file loads into a freshly built
    model, including tying generator embeddings to the discriminator's.
    """
    assert len(list_pretrained_electra()) > 0
    with tempfile.TemporaryDirectory() as root, ctx:
        cfg, tokenizer, backbone_params_path, (disc_params_path, gen_params_path) = \
            get_pretrained_electra(model_name, root=root,
                                   load_backbone=True, load_disc=True, load_gen=True)
        assert cfg.MODEL.vocab_size == len(tokenizer.vocab)

        # Backbone weights load into the plain ElectraModel.
        backbone = ElectraModel.from_cfg(cfg)
        backbone.load_parameters(backbone_params_path)

        # Discriminator weights load directly into a discriminator...
        disc = ElectraDiscriminator(cfg)
        disc.load_parameters(disc_params_path)
        # ...and backbone-only weights load into a fresh discriminator's backbone.
        disc = ElectraDiscriminator(cfg)
        disc.backbone_model.load_parameters(backbone_params_path)

        # Generator weights load into a generator built from the derived
        # generator config, whose embeddings can then be tied to the
        # discriminator backbone's embedding parameters.
        gen = ElectraGenerator(get_generator_cfg(cfg))
        gen.load_parameters(gen_params_path)
        disc_backbone = disc.backbone_model
        gen.tie_embeddings(
            disc_backbone.word_embed.collect_params(),
            disc_backbone.token_type_embed.collect_params(),
            disc_backbone.token_pos_embed.collect_params(),
            disc_backbone.embed_layer_norm.collect_params())

        # A generator built from the full config can also host backbone weights.
        gen = ElectraGenerator(cfg)
        gen.backbone_model.load_parameters(backbone_params_path)