def test_electra_get_pretrained(model_name, ctx):
    """Download a pretrained ELECTRA checkpoint and exercise every load path.

    NOTE(review): reformatted from a whitespace-mangled source; every token
    is unchanged, only layout and comments were added.

    Parameters
    ----------
    model_name
        Name of a pretrained ELECTRA model (injected by pytest parametrize).
    ctx
        Device context manager fixture — presumably an mxnet Context; verify
        against the fixture definition.
    """
    # Sanity-check that the pretrained registry is populated before downloading.
    assert len(list_pretrained_electra()) > 0
    # Download into a throwaway directory so repeated runs stay clean.
    with tempfile.TemporaryDirectory() as root, ctx:
        cfg, tokenizer, backbone_params_path, (disc_params_path, gen_params_path) =\
            get_pretrained_electra(model_name, root=root,
                                   load_backbone=True, load_disc=True, load_gen=True)
        # Tokenizer vocab and model config must agree on vocabulary size.
        assert cfg.MODEL.vocab_size == len(tokenizer.vocab)
        # Path 1: bare backbone restored from the backbone checkpoint.
        electra_model = ElectraModel.from_cfg(cfg)
        electra_model.load_parameters(backbone_params_path)
        # Path 2a: full discriminator restored from the discriminator checkpoint.
        electra_disc_model = ElectraDiscriminator(cfg)
        electra_disc_model.load_parameters(disc_params_path)
        # Path 2b: fresh discriminator whose backbone alone is restored from the
        # backbone checkpoint (intentional re-instantiation — different load path).
        electra_disc_model = ElectraDiscriminator(cfg)
        electra_disc_model.backbone_model.load_parameters(backbone_params_path)
        gen_cfg = get_generator_cfg(cfg)
        # Path 3a: generator built from the derived generator config and restored
        # from the generator checkpoint, then tied to the discriminator embeddings.
        electra_gen_model = ElectraGenerator(gen_cfg)
        electra_gen_model.load_parameters(gen_params_path)
        electra_gen_model.tie_embeddings(
            electra_disc_model.backbone_model.word_embed.collect_params(),
            electra_disc_model.backbone_model.token_type_embed.collect_params(
            ),
            electra_gen_model.backbone_model.token_pos_embed.collect_params() if False else electra_disc_model.backbone_model.token_pos_embed.collect_params(),
            electra_disc_model.backbone_model.embed_layer_norm.collect_params(
            ))
        # Path 3b: generator built from the *discriminator* config with only its
        # backbone restored (intentional re-instantiation — different load path).
        electra_gen_model = ElectraGenerator(cfg)
        electra_gen_model.backbone_model.load_parameters(backbone_params_path)
def test_list_pretrained_electra():
    """The registry of pretrained ELECTRA models must not be empty."""
    available_models = list_pretrained_electra()
    assert len(available_models) > 0
contextual_embedding, pooled_out = electra_model( inputs, token_types, valid_length) electra_model_tn = ElectraModel.from_cfg(cfg_tn) electra_model_tn.share_parameters(electra_model.collect_params()) electra_model_tn.hybridize() contextual_embedding_tn, pooled_out_tn = electra_model_tn( inputs.T, token_types.T, valid_length) assert_allclose(contextual_embedding.asnumpy(), np.swapaxes(contextual_embedding_tn.asnumpy(), 0, 1), 1E-4, 1E-4) assert_allclose(pooled_out.asnumpy(), pooled_out_tn.asnumpy(), 1E-4, 1E-4) @pytest.mark.remote_required @pytest.mark.parametrize('model_name', list_pretrained_electra()) def test_electra_get_pretrained(model_name, ctx): assert len(list_pretrained_electra()) > 0 with tempfile.TemporaryDirectory() as root, ctx: cfg, tokenizer, backbone_params_path, (disc_params_path, gen_params_path) =\ get_pretrained_electra(model_name, root=root, load_backbone=True, load_disc=True, load_gen=True) assert cfg.MODEL.vocab_size == len(tokenizer.vocab) electra_model = ElectraModel.from_cfg(cfg) electra_model.load_parameters(backbone_params_path) electra_disc_model = ElectraDiscriminator(cfg) electra_disc_model.load_parameters(disc_params_path) electra_disc_model = ElectraDiscriminator(cfg) electra_disc_model.backbone_model.load_parameters(backbone_params_path)