def test_mobilebert_get_pretrained(model_name):
    """Smoke-test downloading a pretrained MobileBERT and loading its weights.

    Verifies that the tokenizer vocabulary matches the config, that the
    backbone parameters load into ``MobileBertModel``, and that both the MLM
    parameters and the bare backbone parameters load into
    ``MobileBertForPretrain``.
    """
    with tempfile.TemporaryDirectory() as root:
        cfg, tokenizer, backbone_params_path, mlm_params_path =\
            get_pretrained_mobilebert(model_name, load_backbone=True, load_mlm=True, root=root)
        assert cfg.MODEL.vocab_size == len(tokenizer.vocab)
        mobilebert_model = MobileBertModel.from_cfg(cfg)
        mobilebert_model.load_parameters(backbone_params_path)
        # Path 1: load full MLM/pretrain parameters, if they ship with the model.
        mobilebert_pretrain_model = MobileBertForPretrain(cfg)
        if mlm_params_path is not None:
            mobilebert_pretrain_model.load_parameters(mlm_params_path)
        # Path 2: a fresh pretrain model must also accept backbone-only weights.
        mobilebert_pretrain_model = MobileBertForPretrain(cfg)
        mobilebert_pretrain_model.backbone_model.load_parameters(backbone_params_path)
# Example #2
# 0
def convert_tf_config(config_dict_path, vocab_size):
    """Convert a TensorFlow MobileBERT config JSON into a GluonNLP config.

    Parameters
    ----------
    config_dict_path : str
        Path to the TensorFlow ``config.json`` file.
    vocab_size : int
        Expected vocabulary size; asserted against the value in the JSON.

    Returns
    -------
    cfg
        A frozen ``CfgNode`` suitable for ``MobileBertModel.from_cfg``.
    """
    with open(config_dict_path, encoding='utf-8') as f:
        config_dict = json.load(f)
    assert vocab_size == config_dict['vocab_size']
    cfg = MobileBertModel.get_cfg().clone()
    cfg.defrost()
    cfg.MODEL.vocab_size = vocab_size
    cfg.MODEL.units = config_dict['hidden_size']
    cfg.MODEL.embed_size = config_dict['embedding_size']
    cfg.MODEL.inner_size = config_dict['intra_bottleneck_size']
    cfg.MODEL.hidden_size = config_dict['intermediate_size']
    cfg.MODEL.max_length = config_dict['max_position_embeddings']
    cfg.MODEL.num_heads = config_dict['num_attention_heads']
    cfg.MODEL.num_layers = config_dict['num_hidden_layers']
    # NOTE: a dead no-op read of ``cfg.MODEL.bottleneck_strategy`` was removed
    # here; the strategy is assigned below based on the TF config keys.
    cfg.MODEL.num_stacked_ffn = config_dict['num_feedforward_networks']
    cfg.MODEL.pos_embed_type = 'learned'
    cfg.MODEL.activation = config_dict['hidden_act']
    cfg.MODEL.num_token_types = config_dict['type_vocab_size']
    cfg.MODEL.hidden_dropout_prob = float(config_dict['hidden_dropout_prob'])
    cfg.MODEL.attention_dropout_prob = float(config_dict['attention_probs_dropout_prob'])
    cfg.MODEL.normalization = config_dict['normalization_type']
    cfg.MODEL.dtype = 'float32'

    # Map the TF bottleneck flags onto the GluonNLP strategy names.
    if 'use_bottleneck_attention' in config_dict.keys():
        cfg.MODEL.bottleneck_strategy = 'from_bottleneck'
    elif 'key_query_shared_bottleneck' in config_dict.keys():
        cfg.MODEL.bottleneck_strategy = 'qk_sharing'
    else:
        cfg.MODEL.bottleneck_strategy = 'from_input'

    cfg.INITIALIZER.weight = ['truncnorm', 0,
                              config_dict['initializer_range']]  # TruncNorm(0, 0.02)
    cfg.INITIALIZER.bias = ['zeros']
    cfg.VERSION = 1
    cfg.freeze()
    return cfg
def test_mobilebert_model_small_cfg(compute_layout, ctx):
    """Check NT/TN layout consistency of the MobileBERT backbone and heads.

    Builds a tiny config, runs one random batch through an NT-layout model and
    a parameter-sharing TN-layout clone, and asserts the outputs of the
    backbone, ``MobileBertForMLM`` and ``MobileBertForPretrain`` agree up to a
    transpose of the batch/time axes.
    """
    with ctx:
        cfg = MobileBertModel.get_cfg()
        cfg.defrost()
        cfg.MODEL.vocab_size = 100
        cfg.MODEL.num_layers = 2
        cfg.MODEL.hidden_size = 128
        cfg.MODEL.num_heads = 2
        cfg.MODEL.compute_layout = compute_layout
        cfg.freeze()

        # Generate TN layout
        cfg_tn = cfg.clone()
        cfg_tn.defrost()
        cfg_tn.MODEL.layout = 'TN'
        cfg_tn.freeze()

        batch_size = 4
        sequence_length = 16
        num_mask = 3
        inputs = mx.np.random.randint(0, 10, (batch_size, sequence_length))
        token_types = mx.np.random.randint(0, 2, (batch_size, sequence_length))
        valid_length = mx.np.random.randint(3, sequence_length, (batch_size, ))
        masked_positions = mx.np.random.randint(0, 3, (batch_size, num_mask))

        # Backbone: the TN model shares the NT model's parameters so their
        # outputs must match (up to transposed layout).
        mobile_bert_model = MobileBertModel.from_cfg(cfg)
        mobile_bert_model.initialize()
        mobile_bert_model.hybridize()
        mobile_bert_model_tn = MobileBertModel.from_cfg(cfg_tn)
        mobile_bert_model_tn.share_parameters(
            mobile_bert_model.collect_params())
        mobile_bert_model_tn.hybridize()
        contextual_embedding, pooled_out = mobile_bert_model(
            inputs, token_types, valid_length)
        contextual_embedding_tn, pooled_out_tn = mobile_bert_model_tn(
            inputs.T, token_types.T, valid_length)
        assert_allclose(contextual_embedding.asnumpy(),
                        np.swapaxes(contextual_embedding_tn.asnumpy(), 0, 1),
                        1E-3, 1E-3)
        assert_allclose(pooled_out.asnumpy(), pooled_out_tn.asnumpy(), 1E-3,
                        1E-3)

        # Test for MobileBertForMLM
        mobile_bert_mlm_model = MobileBertForMLM(cfg)
        mobile_bert_mlm_model.initialize()
        mobile_bert_mlm_model.hybridize()
        mobile_bert_mlm_model_tn = MobileBertForMLM(cfg_tn)
        mobile_bert_mlm_model_tn.share_parameters(
            mobile_bert_mlm_model.collect_params())
        # BUGFIX: hybridize the TN MLM model (previously this re-hybridized
        # the backbone TN model by mistake, leaving the MLM model imperative).
        mobile_bert_mlm_model_tn.hybridize()
        contextual_embedding, pooled_out, mlm_score = mobile_bert_mlm_model(
            inputs, token_types, valid_length, masked_positions)
        contextual_embedding_tn, pooled_out_tn, mlm_score_tn =\
            mobile_bert_mlm_model_tn(inputs.T, token_types.T, valid_length, masked_positions)
        assert_allclose(contextual_embedding.asnumpy(),
                        np.swapaxes(contextual_embedding_tn.asnumpy(), 0, 1),
                        1E-3, 1E-3)
        assert_allclose(pooled_out_tn.asnumpy(), pooled_out.asnumpy(), 1E-3,
                        1E-3)
        assert_allclose(mlm_score_tn.asnumpy(), mlm_score.asnumpy(), 1E-3,
                        1E-3)

        # Test for MobileBertForPretrain
        mobile_bert_pretrain_model = MobileBertForPretrain(cfg)
        mobile_bert_pretrain_model.initialize()
        mobile_bert_pretrain_model.hybridize()
        mobile_bert_pretrain_model_tn = MobileBertForPretrain(cfg_tn)
        mobile_bert_pretrain_model_tn.share_parameters(
            mobile_bert_pretrain_model.collect_params())
        mobile_bert_pretrain_model_tn.hybridize()
        contextual_embedding, pooled_out, nsp_score, mlm_score =\
            mobile_bert_pretrain_model(inputs, token_types, valid_length, masked_positions)
        contextual_embedding_tn, pooled_out_tn, nsp_score_tn, mlm_score_tn = \
            mobile_bert_pretrain_model_tn(inputs.T, token_types.T, valid_length, masked_positions)
        assert_allclose(contextual_embedding.asnumpy(),
                        np.swapaxes(contextual_embedding_tn.asnumpy(), 0, 1),
                        1E-3, 1E-3)
        assert_allclose(pooled_out.asnumpy(), pooled_out_tn.asnumpy(), 1E-3,
                        1E-3)
        assert_allclose(nsp_score.asnumpy(), nsp_score_tn.asnumpy(), 1E-3,
                        1E-3)
        assert_allclose(mlm_score.asnumpy(), mlm_score_tn.asnumpy(), 1E-3,
                        1E-3)

        # Test for fp16
        if ctx.device_type == 'gpu':
            # NOTE(review): the skip raises, so the fp16 verification below is
            # currently unreachable; it is kept for when the nan issue is fixed.
            pytest.skip('MobileBERT will have nan values in FP16 mode.')
            verify_backbone_fp16(model_cls=MobileBertModel,
                                 cfg=cfg,
                                 ctx=ctx,
                                 inputs=[inputs, token_types, valid_length])
def test_mobilebert_model_small_cfg(compute_layout):
    """Check NT/TN layout consistency of MobileBERT on a tiny CPU config.

    NOTE(review): this redefines ``test_mobilebert_model_small_cfg`` from
    earlier in the file, shadowing the ctx-aware variant at import time;
    renaming one of them would let both run.
    """
    cfg = MobileBertModel.get_cfg()
    cfg.defrost()
    cfg.MODEL.vocab_size = 100
    cfg.MODEL.num_layers = 2
    cfg.MODEL.hidden_size = 128
    cfg.MODEL.num_heads = 2
    cfg.MODEL.compute_layout = compute_layout
    cfg.freeze()

    # Generate TN layout
    cfg_tn = cfg.clone()
    cfg_tn.defrost()
    cfg_tn.MODEL.layout = 'TN'
    cfg_tn.freeze()

    batch_size = 4
    sequence_length = 16
    num_mask = 3
    inputs = mx.np.random.randint(0, 10, (batch_size, sequence_length))
    token_types = mx.np.random.randint(0, 2, (batch_size, sequence_length))
    valid_length = mx.np.random.randint(3, sequence_length, (batch_size, ))
    masked_positions = mx.np.random.randint(0, 3, (batch_size, num_mask))

    # Backbone: the TN model shares the NT model's parameters so their
    # outputs must match (up to transposed layout).
    mobile_bert_model = MobileBertModel.from_cfg(cfg)
    mobile_bert_model.initialize()
    mobile_bert_model.hybridize()
    mobile_bert_model_tn = MobileBertModel.from_cfg(cfg_tn)
    mobile_bert_model_tn.share_parameters(mobile_bert_model.collect_params())
    mobile_bert_model_tn.hybridize()
    contextual_embedding, pooled_out = mobile_bert_model(
        inputs, token_types, valid_length)
    contextual_embedding_tn, pooled_out_tn = mobile_bert_model_tn(
        inputs.T, token_types.T, valid_length)
    assert_allclose(contextual_embedding.asnumpy(),
                    np.swapaxes(contextual_embedding_tn.asnumpy(), 0, 1), 1E-4,
                    1E-4)
    assert_allclose(pooled_out.asnumpy(), pooled_out_tn.asnumpy(), 1E-4, 1E-4)

    # Test for MobileBertForMLM
    mobile_bert_mlm_model = MobileBertForMLM(cfg)
    mobile_bert_mlm_model.initialize()
    mobile_bert_mlm_model.hybridize()
    mobile_bert_mlm_model_tn = MobileBertForMLM(cfg_tn)
    mobile_bert_mlm_model_tn.share_parameters(
        mobile_bert_mlm_model.collect_params())
    # BUGFIX: hybridize the TN MLM model (previously this re-hybridized
    # the backbone TN model by mistake, leaving the MLM model imperative).
    mobile_bert_mlm_model_tn.hybridize()
    contextual_embedding, pooled_out, mlm_scores = mobile_bert_mlm_model(
        inputs, token_types, valid_length, masked_positions)
    contextual_embedding_tn, pooled_out_tn, mlm_scores_tn =\
        mobile_bert_mlm_model_tn(inputs.T, token_types.T, valid_length, masked_positions)
    assert_allclose(contextual_embedding.asnumpy(),
                    np.swapaxes(contextual_embedding_tn.asnumpy(), 0, 1), 1E-4,
                    1E-4)
    assert_allclose(pooled_out_tn.asnumpy(), pooled_out.asnumpy(), 1E-4, 1E-4)
    assert_allclose(mlm_scores_tn.asnumpy(), mlm_scores.asnumpy(), 1E-4, 1E-4)

    # Test for MobileBertForPretrain
    mobile_bert_pretrain_model = MobileBertForPretrain(cfg)
    mobile_bert_pretrain_model.initialize()
    mobile_bert_pretrain_model.hybridize()
    mobile_bert_pretrain_model_tn = MobileBertForPretrain(cfg_tn)
    mobile_bert_pretrain_model_tn.share_parameters(
        mobile_bert_pretrain_model.collect_params())
    mobile_bert_pretrain_model_tn.hybridize()
    contextual_embedding, pooled_out, nsp_score, mlm_scores =\
        mobile_bert_pretrain_model(inputs, token_types, valid_length, masked_positions)
    contextual_embedding_tn, pooled_out_tn, nsp_score_tn, mlm_scores_tn = \
        mobile_bert_pretrain_model_tn(inputs.T, token_types.T, valid_length, masked_positions)
    assert_allclose(contextual_embedding.asnumpy(),
                    np.swapaxes(contextual_embedding_tn.asnumpy(), 0, 1), 1E-4,
                    1E-4)
    assert_allclose(pooled_out.asnumpy(), pooled_out_tn.asnumpy(), 1E-4, 1E-4)
    assert_allclose(nsp_score.asnumpy(), nsp_score_tn.asnumpy(), 1E-4, 1E-4)
    assert_allclose(mlm_scores.asnumpy(), mlm_scores_tn.asnumpy(), 1E-4, 1E-4)