# Example #1
def get_pretraining_model(model_name, ctx_l, max_seq_length=512):
    """Build a hybridized BertForPretrain model for the given pretrained name.

    Parameters
    ----------
    model_name
        Name of the pretrained BERT checkpoint to fetch the config and
        tokenizer from (backbone and MLM weights are not loaded).
    ctx_l
        Context (or list of contexts) to initialize the parameters on.
    max_seq_length
        Maximum sequence length written into ``cfg.MODEL.max_length``.

    Returns
    -------
    cfg
        The merged, frozen model configuration.
    tokenizer
        Tokenizer associated with the pretrained checkpoint.
    model
        An initialized, hybridized ``BertForPretrain`` instance.
    """
    # Only the config and tokenizer are needed; skip downloading weights.
    base_cfg, tokenizer, _, _ = get_pretrained_bert(
        model_name, load_backbone=False, load_mlm=False)
    merged_cfg = BertModel.get_cfg().clone_merge(base_cfg)

    # Override the maximum sequence length before re-freezing the config.
    merged_cfg.defrost()
    merged_cfg.MODEL.max_length = max_seq_length
    merged_cfg.freeze()

    pretrain_model = BertForPretrain(merged_cfg)
    pretrain_model.initialize(ctx=ctx_l)
    pretrain_model.hybridize()
    return merged_cfg, tokenizer, pretrain_model
# Example #2
def test_bert_small_cfg(compute_layout, ctx):
    """Check NT/TN layout consistency for BertModel, BertForMLM, BertForPretrain.

    Builds a tiny BERT config, runs each model in the default 'NT'
    (batch-major) layout and in a cloned 'TN' (time-major) layout with shared
    parameters, and asserts the outputs agree after transposing the
    time/batch axes. Also exercises the FP16 backbone path on GPU.

    Parameters
    ----------
    compute_layout
        Value written into ``cfg.MODEL.compute_layout``.
    ctx
        MXNet context to run under; FP16 verification only runs on 'gpu'.
    """
    with ctx:
        # Tiny configuration so the test stays fast.
        cfg = BertModel.get_cfg()
        cfg.defrost()
        cfg.MODEL.vocab_size = 100
        cfg.MODEL.units = 12 * 4
        cfg.MODEL.hidden_size = 64
        cfg.MODEL.num_layers = 2
        cfg.MODEL.num_heads = 2
        cfg.MODEL.compute_layout = compute_layout
        cfg.freeze()

        # Generate TN (time-major) layout variant of the same config.
        cfg_tn = cfg.clone()
        cfg_tn.defrost()
        cfg_tn.MODEL.layout = 'TN'
        cfg_tn.freeze()

        # Sample data
        batch_size = 4
        sequence_length = 8
        num_mask = 3
        inputs = mx.np.random.randint(0, 10, (batch_size, sequence_length))
        token_types = mx.np.random.randint(0, 2, (batch_size, sequence_length))
        valid_length = mx.np.random.randint(3, sequence_length, (batch_size,))
        masked_positions = mx.np.random.randint(0, 3, (batch_size, num_mask))

        # Test for BertModel: TN model shares parameters with the NT model,
        # so outputs must match once the time/batch axes are swapped back.
        bert_model = BertModel.from_cfg(cfg)
        bert_model.initialize()
        bert_model.hybridize()
        contextual_embedding, pooled_out = bert_model(inputs, token_types, valid_length)
        bert_model_tn = BertModel.from_cfg(cfg_tn)
        bert_model_tn.share_parameters(bert_model.collect_params())
        bert_model_tn.hybridize()
        contextual_embedding_tn, pooled_out_tn = bert_model_tn(inputs.T, token_types.T, valid_length)
        assert_allclose(contextual_embedding.asnumpy(),
                        mx.np.swapaxes(contextual_embedding_tn, 0, 1).asnumpy(),
                        1E-4, 1E-4)
        assert_allclose(pooled_out.asnumpy(), pooled_out_tn.asnumpy(), 1E-4, 1E-4)

        # Test for BertForMLM
        bert_mlm_model = BertForMLM(cfg)
        bert_mlm_model.initialize()
        bert_mlm_model.hybridize()
        contextual_embedding, pooled_out, mlm_score = bert_mlm_model(inputs, token_types,
                                                                     valid_length, masked_positions)
        bert_mlm_model_tn = BertForMLM(cfg_tn)
        bert_mlm_model_tn.share_parameters(bert_mlm_model.collect_params())
        bert_mlm_model_tn.hybridize()
        contextual_embedding_tn, pooled_out_tn, mlm_score_tn =\
            bert_mlm_model_tn(inputs.T, token_types.T, valid_length, masked_positions)
        assert_allclose(contextual_embedding.asnumpy(),
                        mx.np.swapaxes(contextual_embedding_tn, 0, 1).asnumpy(),
                        1E-4, 1E-4)
        assert_allclose(pooled_out.asnumpy(), pooled_out_tn.asnumpy(), 1E-3, 1E-3)
        assert_allclose(mlm_score.asnumpy(), mlm_score_tn.asnumpy(), 1E-3, 1E-3)

        # Test for BertForPretrain
        bert_pretrain_model = BertForPretrain(cfg)
        bert_pretrain_model.initialize()
        bert_pretrain_model.hybridize()
        contextual_embedding, pooled_out, nsp_score, mlm_scores =\
            bert_pretrain_model(inputs, token_types, valid_length, masked_positions)
        bert_pretrain_model_tn = BertForPretrain(cfg_tn)
        bert_pretrain_model_tn.share_parameters(bert_pretrain_model.collect_params())
        bert_pretrain_model_tn.hybridize()
        contextual_embedding_tn, pooled_out_tn, nsp_score_tn, mlm_scores_tn = \
            bert_pretrain_model_tn(inputs.T, token_types.T, valid_length, masked_positions)
        assert_allclose(contextual_embedding.asnumpy(),
                        mx.np.swapaxes(contextual_embedding_tn, 0, 1).asnumpy(),
                        1E-3, 1E-3)
        assert_allclose(pooled_out.asnumpy(), pooled_out_tn.asnumpy(), 1E-3, 1E-3)
        assert_allclose(nsp_score.asnumpy(), nsp_score_tn.asnumpy(), 1E-3, 1E-3)
        # BUGFIX: compare the pretrain model's MLM scores (mlm_scores /
        # mlm_scores_tn); the previous assertion re-checked the stale
        # mlm_score variables from the BertForMLM section above.
        assert_allclose(mlm_scores.asnumpy(), mlm_scores_tn.asnumpy(), 1E-3, 1E-3)

        # Test BertModel FP16 (only meaningful on GPU contexts)
        device_type = ctx.device_type
        if device_type == 'gpu':
            verify_backbone_fp16(model_cls=BertModel, cfg=cfg, ctx=ctx,
                                 inputs=[inputs, token_types, valid_length])