Exemplo n.º 1
0
def test_albert_get_pretrained(model_name):
    """Check that the pretrained ALBERT artifacts for ``model_name`` load.

    The config, tokenizer and parameter paths are obtained from
    ``get_pretrained_albert`` with a temporary directory as ``root``. The
    test then verifies that:

    * the config's vocabulary size equals the tokenizer's vocabulary length,
    * the backbone parameters load into an ``AlbertModel``,
    * the MLM parameters (when a path is returned) load into ``AlbertForMLM``,
    * the backbone parameters alone load into ``AlbertForMLM.backbone_model``.
    """
    num_available = len(list_pretrained_albert())
    assert num_available > 0

    with tempfile.TemporaryDirectory() as root:
        pretrained = get_pretrained_albert(
            model_name, load_backbone=True, load_mlm=True, root=root)
        cfg, tokenizer, backbone_params_path, mlm_params_path = pretrained
        assert cfg.MODEL.vocab_size == len(tokenizer.vocab)

        # Plain backbone model.
        backbone = AlbertModel.from_cfg(cfg)
        backbone.load_parameters(backbone_params_path)

        # Full MLM model; the MLM path may be None, in which case nothing
        # is loaded into it.
        full_mlm_model = AlbertForMLM(cfg)
        if mlm_params_path is not None:
            full_mlm_model.load_parameters(mlm_params_path)

        # Fresh MLM model with only its backbone populated.
        backbone_only_mlm_model = AlbertForMLM(cfg)
        backbone_only_mlm_model.backbone_model.load_parameters(backbone_params_path)
Exemplo n.º 2
0
def test_albert_for_mlm_model(compute_layout):
    """Check that AlbertForMLM gives matching outputs with the 'TN' layout.

    A reference model is built from the test config with
    ``MODEL.compute_layout`` set to ``compute_layout``. A second model is
    built from a clone of that config with ``MODEL.layout = 'TN'`` and shares
    the reference model's parameters. Both models are run on the same random
    batch (the 'TN' model receives the transposed inputs and token types),
    and their outputs are compared within a 1E-4 tolerance.
    """
    num_mask = 16
    seq_length = 64
    batch_size = 3
    tol = 1E-4

    # Reference model with the requested compute layout.
    base_cfg = get_test_cfg()
    base_cfg.defrost()
    base_cfg.MODEL.compute_layout = compute_layout
    base_cfg.freeze()
    ref_model = AlbertForMLM(backbone_cfg=base_cfg)
    ref_model.initialize()
    ref_model.hybridize()

    # 'TN' layout model reusing the reference model's parameters.
    tn_cfg = base_cfg.clone()
    tn_cfg.defrost()
    tn_cfg.MODEL.layout = 'TN'
    tn_cfg.freeze()
    tn_model = AlbertForMLM(backbone_cfg=tn_cfg)
    tn_model.share_parameters(ref_model.collect_params())
    tn_model.hybridize()

    # Random batch. Masked positions are drawn from [0, seq_length // 2) and
    # valid lengths from [seq_length // 2, seq_length).
    batch_shape = (batch_size, seq_length)
    inputs = mx.np.random.randint(0, base_cfg.MODEL.vocab_size, batch_shape)
    token_types = mx.np.random.randint(0, base_cfg.MODEL.num_token_types, batch_shape)
    valid_length = mx.np.random.randint(seq_length // 2, seq_length, (batch_size,))
    masked_positions = mx.np.random.randint(0, seq_length // 2, (batch_size, num_mask))

    contextual_embeddings, pooled_out, mlm_scores = ref_model(
        inputs, token_types, valid_length, masked_positions)
    contextual_embeddings_tn, pooled_out_tn, mlm_scores_tn = tn_model(
        inputs.T, token_types.T, valid_length, masked_positions)

    # The 'TN' contextual embeddings have their first two axes swapped back
    # before being compared with the reference output.
    embeddings_tn_swapped = np.swapaxes(contextual_embeddings_tn.asnumpy(), 0, 1)
    assert_allclose(embeddings_tn_swapped, contextual_embeddings.asnumpy(), tol, tol)
    assert_allclose(pooled_out_tn.asnumpy(), pooled_out.asnumpy(), tol, tol)
    assert_allclose(mlm_scores_tn.asnumpy(), mlm_scores.asnumpy(), tol, tol)

    expected_scores_shape = (batch_size, num_mask, base_cfg.MODEL.vocab_size)
    assert mlm_scores.shape == expected_scores_shape