def test_albert_get_pretrained(model_name):
    """Request the pretrained ALBERT artifacts for ``model_name`` and load them.

    The test asserts that at least one pretrained name is listed, asks
    ``get_pretrained_albert`` for the config, tokenizer and both parameter
    paths (using a temporary directory as ``root``), checks that the config's
    vocabulary size matches the tokenizer vocabulary, and then loads the
    parameter files into freshly built models.

    NOTE(review): a function with the same name is defined again later in
    this file; that later definition rebinds the module-level name.
    """
    assert len(list_pretrained_albert()) > 0
    with tempfile.TemporaryDirectory() as download_dir:
        fetched = get_pretrained_albert(model_name, load_backbone=True,
                                        load_mlm=True, root=download_dir)
        cfg, tokenizer, backbone_file, mlm_file = fetched
        assert cfg.MODEL.vocab_size == len(tokenizer.vocab)

        # Backbone parameters go into a standalone backbone model.
        backbone = AlbertModel.from_cfg(cfg)
        backbone.load_parameters(backbone_file)

        # MLM parameters are only loaded when a path was returned.
        mlm_model = AlbertForMLM(cfg)
        if mlm_file is not None:
            mlm_model.load_parameters(mlm_file)

        # Just load the backbone
        mlm_model = AlbertForMLM(cfg)
        mlm_model.backbone_model.load_parameters(backbone_file)
def test_list_pretrained_albert():
    """Check that ``list_pretrained_albert()`` reports at least one entry.

    NOTE(review): every statement after the first assert reads names that
    this function never assigns (``contextual_embeddings_tn``,
    ``contextual_embeddings``, ``pooled_out_tn``, ``pooled_out``,
    ``sop_score``, ``sop_score_tn``, ``mlm_scores``, ``mlm_scores_tn``,
    ``batch_size``, ``num_mask``, ``cfg``). Unless they exist at module
    level, running this body raises ``NameError``. They look like the tail
    of a different test -- TODO confirm and move them back where they belong.

    NOTE(review): a function with the same name is defined again later in
    this file; that later definition rebinds the module-level name.
    """
    assert len(list_pretrained_albert()) > 0
    # Compare the "_tn" contextual embeddings, with axes 0 and 1 swapped,
    # against the non-"_tn" ones (presumably a layout conversion -- confirm).
    # The two trailing 1E-4 arguments are presumably rtol/atol -- verify
    # against the ``assert_allclose`` that this file imports.
    assert_allclose(np.swapaxes(contextual_embeddings_tn.asnumpy(), 0, 1),
                    contextual_embeddings.asnumpy(), 1E-4, 1E-4)
    # The remaining outputs are compared directly, without any axis swap.
    assert_allclose(pooled_out_tn.asnumpy(), pooled_out.asnumpy(), 1E-4, 1E-4)
    assert_allclose(sop_score.asnumpy(), sop_score_tn.asnumpy(), 1E-4, 1E-4)
    assert_allclose(mlm_scores.asnumpy(), mlm_scores_tn.asnumpy(), 1E-4, 1E-4)
    # Shape checks: one score per vocabulary entry for each masked position,
    # and two scores per batch element for ``sop_score``.
    assert mlm_scores.shape == (batch_size, num_mask, cfg.MODEL.vocab_size)
    assert sop_score.shape == (batch_size, 2)


def test_list_pretrained_albert():
    """Check that ``list_pretrained_albert()`` reports at least one entry."""
    available_names = list_pretrained_albert()
    assert len(available_names) > 0


@pytest.mark.slow
@pytest.mark.remote_required
@pytest.mark.parametrize('model_name', list_pretrained_albert())
def test_albert_get_pretrained(model_name):
    """Request the pretrained ALBERT artifacts for ``model_name`` and load them.

    Parametrized over every name returned by ``list_pretrained_albert()`` and
    marked ``slow`` and ``remote_required``. For each name the test asks
    ``get_pretrained_albert`` for the config, tokenizer, backbone parameter
    path and MLM parameter path (with a temporary directory as ``root``),
    checks that the config's vocabulary size equals the tokenizer vocabulary
    size, and loads the returned parameter files into newly built models.
    """
    assert len(list_pretrained_albert()) > 0
    with tempfile.TemporaryDirectory() as cache_dir:
        pretrained = get_pretrained_albert(model_name, load_backbone=True,
                                           load_mlm=True, root=cache_dir)
        cfg, tokenizer, backbone_ckpt, mlm_ckpt = pretrained
        assert cfg.MODEL.vocab_size == len(tokenizer.vocab)

        # Standalone backbone built from the config, then loaded from disk.
        backbone_net = AlbertModel.from_cfg(cfg)
        backbone_net.load_parameters(backbone_ckpt)

        # The MLM parameter path may be None, in which case nothing is loaded.
        mlm_net = AlbertForMLM(cfg)
        if mlm_ckpt is not None:
            mlm_net.load_parameters(mlm_ckpt)

        # Just load the backbone
        mlm_net = AlbertForMLM(cfg)
        mlm_net.backbone_model.load_parameters(backbone_ckpt)