Example #1
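These are unit tests from GluonNLP's MXNet model test suite, shown without their module preamble. A plausible preamble is sketched below; the exact module paths and the ctx fixture are assumptions based on GluonNLP's layout, not text copied from the original file.

import numpy as np
import numpy.testing as npt
from numpy.testing import assert_allclose
import pytest
import mxnet as mx
from gluonnlp.models.bart import BartModel
from gluonnlp.models.bert import BertModel, BertForMLM, BertForPretrain
from gluonnlp.models.roberta import RobertaModel, RobertaForMLM
from gluonnlp.models.electra import ElectraModel
from gluonnlp.models.mobilebert import (MobileBertModel, MobileBertForMLM,
                                        MobileBertForPretrain)
from gluonnlp.models.gpt2 import GPT2Model, GPT2ForLM
from gluonnlp.models.transformer import TransformerModel
from gluonnlp.utils.testing import verify_backbone_fp16

mx.npx.set_np()  # the tests use the NumPy-compatible mx.np interface

# Assumed fixture: each test receives a CPU (and, when available, GPU) context.
@pytest.fixture(params=[mx.cpu()] + ([mx.gpu()] if mx.context.num_gpus() > 0 else []))
def ctx(request):
    return request.param

# The remaining test arguments (cfg_key, compute_layout, layout, ...) are supplied
# by pytest.mark.parametrize decorators omitted from the snippets, e.g.
# @pytest.mark.parametrize('compute_layout', ['NT', 'TN']).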
def test_bart_cfg(cfg_key, ctx):
    cfg = BartModel.get_cfg(cfg_key)
    cfg.defrost()
    cfg.MODEL.vocab_size = 32
    cfg.freeze()

    cfg_tn = cfg.clone()
    cfg_tn.defrost()
    cfg_tn.MODEL.layout = 'TN'
    cfg_tn.freeze()

    batch_size = 4
    src_length = 32
    tgt_length = 16

    with ctx:
        src_data = mx.np.random.randint(0,
                                        cfg.MODEL.vocab_size,
                                        (batch_size, src_length),
                                        dtype=np.int32)
        src_valid_length = mx.np.random.randint(src_length // 2,
                                                src_length, (batch_size, ),
                                                dtype=np.int32)
        tgt_data = mx.np.random.randint(0,
                                        cfg.MODEL.vocab_size,
                                        (batch_size, tgt_length),
                                        dtype=np.int32)
        tgt_valid_length = mx.np.random.randint(tgt_length // 2,
                                                tgt_length, (batch_size, ),
                                                dtype=np.int32)
        model = BartModel.from_cfg(cfg, extract_feature=True)
        model.initialize()
        model.hybridize()

        contextual_embedding, pooled_output = model(src_data, src_valid_length,
                                                    tgt_data, tgt_valid_length)
        model_tn = BartModel.from_cfg(cfg_tn, extract_feature=True)
        model_tn.share_parameters(model.collect_params())
        model_tn.hybridize()
        contextual_embedding_tn, pooled_out_tn = model_tn(
            src_data.T, src_valid_length, tgt_data.T, tgt_valid_length)
        npt.assert_allclose(
            contextual_embedding.asnumpy(),
            np.transpose(contextual_embedding_tn.asnumpy(), (1, 0, 2)), 5E-3,
            5E-3)
        npt.assert_allclose(pooled_out_tn.asnumpy(), pooled_output.asnumpy(),
                            5E-3, 5E-3)
        mx.npx.waitall()

        # Verify Float16
        if ctx.device_type == 'gpu':
            verify_backbone_fp16(model_cls=BartModel,
                                 cfg=cfg,
                                 ctx=ctx,
                                 inputs=[
                                     src_data, src_valid_length, tgt_data,
                                     tgt_valid_length
                                 ])
def test_roberta_small_config(compute_layout, ctx):
    with ctx:
        cfg = RobertaModel.get_cfg()
        cfg.defrost()
        cfg.MODEL.vocab_size = 1000
        cfg.MODEL.num_layers = 2
        cfg.MODEL.hidden_size = 128
        cfg.MODEL.num_heads = 2
        cfg.MODEL.compute_layout = compute_layout
        cfg.freeze()

        # Generate TN layout
        cfg_tn = cfg.clone()
        cfg_tn.defrost()
        cfg_tn.MODEL.layout = 'TN'
        cfg_tn.freeze()

        batch_size = 4
        sequence_length = 16
        num_mask = 3
        inputs = mx.np.random.randint(0, 10, (batch_size, sequence_length))
        valid_length = mx.np.random.randint(3, sequence_length, (batch_size,))
        masked_positions = mx.np.random.randint(0, 3, (batch_size, num_mask))

        roberta_model = RobertaModel.from_cfg(cfg)
        roberta_model.initialize()
        roberta_model.hybridize()
        contextual_embeddings, pooled_out = roberta_model(inputs, valid_length)
        roberta_model_tn = RobertaModel.from_cfg(cfg_tn)
        roberta_model_tn.share_parameters(roberta_model.collect_params())
        roberta_model_tn.hybridize()
        contextual_embeddings_tn, pooled_out_tn = roberta_model_tn(inputs.T, valid_length)
        assert_allclose(np.swapaxes(contextual_embeddings_tn.asnumpy(), 0, 1),
                        contextual_embeddings.asnumpy(), 1E-3, 1E-3)
        assert_allclose(pooled_out_tn.asnumpy(), pooled_out.asnumpy(), 1E-3, 1E-3)

        # Test for RobertaForMLM
        roberta_mlm_model = RobertaForMLM(cfg)
        roberta_mlm_model.initialize()
        roberta_mlm_model.hybridize()
        contextual_embedding, pooled_out, mlm_score = roberta_mlm_model(inputs, valid_length,
                                                                         masked_positions)
        roberta_mlm_model_tn = RobertaForMLM(cfg_tn)
        roberta_mlm_model_tn.share_parameters(roberta_mlm_model.collect_params())
        roberta_mlm_model_tn.hybridize()
        contextual_embedding_tn, pooled_out_tn, mlm_score_tn =\
            roberta_mlm_model_tn(inputs.T, valid_length, masked_positions)
        assert_allclose(np.swapaxes(contextual_embedding_tn.asnumpy(), 0, 1),
                        contextual_embedding.asnumpy(), 1E-3, 1E-3)
        assert_allclose(pooled_out_tn.asnumpy(), pooled_out.asnumpy(), 1E-3, 1E-3)
        assert_allclose(mlm_score_tn.asnumpy(), mlm_score.asnumpy(), 1E-3, 1E-3)

        # Test for fp16
        if ctx.device_type == 'gpu':
            verify_backbone_fp16(model_cls=RobertaModel, cfg=cfg, ctx=ctx,
                                 inputs=[inputs, valid_length])
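Every example on this page follows the same pattern: build an NT-layout model, build a TN-layout copy that shares its parameters, run both, and check that the outputs agree after swapping the time and batch axes. A small helper capturing that pattern could look like the sketch below (illustrative only, not part of the original suite):

def check_nt_tn_consistency(model_nt, model_tn, nt_inputs, tn_inputs,
                            rtol=1E-3, atol=1E-3):
    # The first output is the contextual embedding: (batch, time, units) in the
    # NT layout versus (time, batch, units) in the TN layout.
    out_nt = model_nt(*nt_inputs)
    out_tn = model_tn(*tn_inputs)
    assert_allclose(np.swapaxes(out_tn[0].asnumpy(), 0, 1),
                    out_nt[0].asnumpy(), rtol, atol)
    # The remaining outputs (pooled output, MLM/NSP scores, ...) share the same
    # layout in both models and can be compared directly.
    for ref, other in zip(out_nt[1:], out_tn[1:]):
        assert_allclose(other.asnumpy(), ref.asnumpy(), rtol, atol)

With such a helper, the RoBERTa check above reduces to check_nt_tn_consistency(roberta_model, roberta_model_tn, (inputs, valid_length), (inputs.T, valid_length)).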
Example #3
def test_transformer_fp16_amp(enc_pre_norm, dec_pre_norm,
                              enc_units, dec_units,
                              enc_num_layers, dec_num_layers,
                              enc_recurrent, dec_recurrent, tie_weights,
                              layout, ctx):
    if ctx.device_type != 'gpu':
        pytest.skip('Only test amp when running on GPU.')
    # Generate configuration for testing
    cfg = TransformerModel.get_cfg()
    cfg.defrost()
    cfg.MODEL.src_vocab_size = 32
    cfg.MODEL.tgt_vocab_size = 32
    cfg.MODEL.max_src_length = 20
    cfg.MODEL.max_tgt_length = 15
    cfg.MODEL.tie_weights = tie_weights
    cfg.MODEL.layout = layout

    # Encoder config
    cfg.MODEL.ENCODER.pre_norm = enc_pre_norm
    cfg.MODEL.ENCODER.units = enc_units
    cfg.MODEL.ENCODER.num_layers = enc_num_layers
    cfg.MODEL.ENCODER.recurrent = enc_recurrent

    # Decoder config
    cfg.MODEL.DECODER.pre_norm = dec_pre_norm
    cfg.MODEL.DECODER.units = dec_units
    cfg.MODEL.DECODER.num_layers = dec_num_layers
    cfg.MODEL.DECODER.recurrent = dec_recurrent
    cfg.freeze()

    batch_size = 4
    seq_length = 16
    with ctx:
        if layout == 'NT':
            src_data = mx.np.random.randint(0, cfg.MODEL.src_vocab_size,
                                            (batch_size, seq_length), dtype=np.int32)
            src_valid_length = mx.np.random.randint(seq_length // 2, seq_length,
                                                    (batch_size,), dtype=np.int32)
            tgt_data = mx.np.random.randint(0, cfg.MODEL.tgt_vocab_size,
                                            (batch_size, seq_length), dtype=np.int32)
            tgt_valid_length = mx.np.random.randint(seq_length // 2, seq_length,
                                                    (batch_size,), dtype=np.int32)
        elif layout == 'TN':
            src_data = mx.np.random.randint(0, cfg.MODEL.src_vocab_size,
                                            (seq_length, batch_size), dtype=np.int32)
            src_valid_length = mx.np.random.randint(seq_length // 2, seq_length,
                                                    (batch_size,), dtype=np.int32)
            tgt_data = mx.np.random.randint(0, cfg.MODEL.tgt_vocab_size,
                                            (seq_length, batch_size), dtype=np.int32)
            tgt_valid_length = mx.np.random.randint(seq_length // 2, seq_length,
                                                    (batch_size,), dtype=np.int32)
        else:
            raise NotImplementedError
        verify_backbone_fp16(TransformerModel, cfg, ctx,
                             inputs=[src_data, src_valid_length, tgt_data, tgt_valid_length])
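verify_backbone_fp16 comes from GluonNLP's testing utilities and its body is not shown on this page. As a rough idea of what such a check involves, a hand-rolled version could cast an initialized model to float16 and compare its outputs against the float32 reference. The sketch below illustrates that idea; it is not the library's actual implementation.

def manual_fp16_check(model_cls, cfg, ctx, inputs, rtol=1E-2, atol=1E-2):
    with ctx:
        model = model_cls.from_cfg(cfg)
        model.initialize()
        out_fp32 = model(*inputs)       # float32 reference outputs
        model.cast('float16')           # cast the same parameters to fp16
        out_fp16 = model(*inputs)       # token ids / valid lengths stay integer
        if not isinstance(out_fp32, (list, tuple)):
            out_fp32, out_fp16 = [out_fp32], [out_fp16]
        for ref, half in zip(out_fp32, out_fp16):
            assert_allclose(half.astype(np.float32).asnumpy(),
                            ref.asnumpy(), rtol, atol)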
def test_electra_model(compute_layout, ctx):
    with ctx:
        cfg = get_test_cfg()
        cfg.defrost()
        cfg.MODEL.compute_layout = compute_layout
        cfg.freeze()

        # Generate TN layout
        cfg_tn = cfg.clone()
        cfg_tn.defrost()
        cfg_tn.MODEL.layout = 'TN'
        cfg_tn.freeze()

        # Sample data
        batch_size = 4
        sequence_length = 16
        num_mask = 3
        inputs = mx.np.random.randint(0, 10, (batch_size, sequence_length))
        token_types = mx.np.random.randint(0, 2, (batch_size, sequence_length))
        valid_length = mx.np.random.randint(3, sequence_length, (batch_size, ))
        masked_positions = mx.np.random.randint(0, 3, (batch_size, num_mask))

        electra_model = ElectraModel.from_cfg(cfg)
        electra_model.initialize()
        electra_model.hybridize()
        contextual_embedding, pooled_out = electra_model(
            inputs, token_types, valid_length)

        electra_model_tn = ElectraModel.from_cfg(cfg_tn)
        electra_model_tn.share_parameters(electra_model.collect_params())
        electra_model_tn.hybridize()
        contextual_embedding_tn, pooled_out_tn = electra_model_tn(
            inputs.T, token_types.T, valid_length)
        assert_allclose(contextual_embedding.asnumpy(),
                        np.swapaxes(contextual_embedding_tn.asnumpy(), 0, 1),
                        1E-4, 1E-4)
        assert_allclose(pooled_out.asnumpy(), pooled_out_tn.asnumpy(), 1E-4,
                        1E-4)

        # Verify Float16
        if ctx.device_type == 'gpu':
            verify_backbone_fp16(model_cls=ElectraModel,
                                 cfg=cfg,
                                 ctx=ctx,
                                 inputs=[inputs, token_types, valid_length])
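get_test_cfg is a helper defined elsewhere in the original ELECTRA test module and is not reproduced on this page. A hypothetical stand-in that builds a similarly small configuration might look like this (field values are illustrative, not the original helper):

def get_test_cfg():
    # Hypothetical stand-in for the helper used above.
    cfg = ElectraModel.get_cfg()
    cfg.defrost()
    cfg.MODEL.vocab_size = 100
    cfg.MODEL.num_layers = 2
    cfg.MODEL.units = 64
    cfg.MODEL.hidden_size = 128
    cfg.MODEL.num_heads = 2
    cfg.freeze()
    return cfg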
Example #5
def test_bert_small_cfg(compute_layout, ctx):
    with ctx:
        cfg = BertModel.get_cfg()
        cfg.defrost()
        cfg.MODEL.vocab_size = 100
        cfg.MODEL.units = 12 * 4
        cfg.MODEL.hidden_size = 64
        cfg.MODEL.num_layers = 2
        cfg.MODEL.num_heads = 2
        cfg.MODEL.compute_layout = compute_layout
        cfg.freeze()

        # Generate TN layout
        cfg_tn = cfg.clone()
        cfg_tn.defrost()
        cfg_tn.MODEL.layout = 'TN'
        cfg_tn.freeze()

        # Sample data
        batch_size = 4
        sequence_length = 8
        num_mask = 3
        inputs = mx.np.random.randint(0, 10, (batch_size, sequence_length))
        token_types = mx.np.random.randint(0, 2, (batch_size, sequence_length))
        valid_length = mx.np.random.randint(3, sequence_length, (batch_size,))
        masked_positions = mx.np.random.randint(0, 3, (batch_size, num_mask))

        # Test for BertModel
        bert_model = BertModel.from_cfg(cfg)
        bert_model.initialize()
        bert_model.hybridize()
        contextual_embedding, pooled_out = bert_model(inputs, token_types, valid_length)
        bert_model_tn = BertModel.from_cfg(cfg_tn)
        bert_model_tn.share_parameters(bert_model.collect_params())
        bert_model_tn.hybridize()
        contextual_embedding_tn, pooled_out_tn = bert_model_tn(inputs.T, token_types.T, valid_length)
        assert_allclose(contextual_embedding.asnumpy(),
                        mx.np.swapaxes(contextual_embedding_tn, 0, 1).asnumpy(),
                        1E-4, 1E-4)
        assert_allclose(pooled_out.asnumpy(), pooled_out_tn.asnumpy(), 1E-4, 1E-4)

        # Test for BertForMLM
        bert_mlm_model = BertForMLM(cfg)
        bert_mlm_model.initialize()
        bert_mlm_model.hybridize()
        contextual_embedding, pooled_out, mlm_score = bert_mlm_model(inputs, token_types,
                                                                     valid_length, masked_positions)
        bert_mlm_model_tn = BertForMLM(cfg_tn)
        bert_mlm_model_tn.share_parameters(bert_mlm_model.collect_params())
        bert_mlm_model_tn.hybridize()
        contextual_embedding_tn, pooled_out_tn, mlm_score_tn =\
            bert_mlm_model_tn(inputs.T, token_types.T, valid_length, masked_positions)
        assert_allclose(contextual_embedding.asnumpy(),
                        mx.np.swapaxes(contextual_embedding_tn, 0, 1).asnumpy(),
                        1E-4, 1E-4)
        assert_allclose(pooled_out.asnumpy(), pooled_out_tn.asnumpy(), 1E-3, 1E-3)
        assert_allclose(mlm_score.asnumpy(), mlm_score_tn.asnumpy(), 1E-3, 1E-3)

        # Test for BertForPretrain
        bert_pretrain_model = BertForPretrain(cfg)
        bert_pretrain_model.initialize()
        bert_pretrain_model.hybridize()
        contextual_embedding, pooled_out, nsp_score, mlm_scores =\
            bert_pretrain_model(inputs, token_types, valid_length, masked_positions)
        bert_pretrain_model_tn = BertForPretrain(cfg_tn)
        bert_pretrain_model_tn.share_parameters(bert_pretrain_model.collect_params())
        bert_pretrain_model_tn.hybridize()
        contextual_embedding_tn, pooled_out_tn, nsp_score_tn, mlm_scores_tn = \
            bert_pretrain_model_tn(inputs.T, token_types.T, valid_length, masked_positions)
        assert_allclose(contextual_embedding.asnumpy(),
                        mx.np.swapaxes(contextual_embedding_tn, 0, 1).asnumpy(),
                        1E-3, 1E-3)
        assert_allclose(pooled_out.asnumpy(), pooled_out_tn.asnumpy(), 1E-3, 1E-3)
        assert_allclose(nsp_score.asnumpy(), nsp_score_tn.asnumpy(), 1E-3, 1E-3)
        assert_allclose(mlm_scores.asnumpy(), mlm_scores_tn.asnumpy(), 1E-3, 1E-3)

        # Test BertModel FP16
        if ctx.device_type == 'gpu':
            verify_backbone_fp16(model_cls=BertModel, cfg=cfg, ctx=ctx,
                                 inputs=[inputs, token_types, valid_length])
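For orientation, the BERT outputs compared above follow the usual GluonNLP NT-layout shapes. The checks below are an illustrative addition that could be appended at the end of the BertForPretrain block; they are not part of the original test.

        # Shape sanity checks (units = 12 * 4 = 48 for this config).
        assert contextual_embedding.shape == (batch_size, sequence_length, cfg.MODEL.units)
        assert pooled_out.shape == (batch_size, cfg.MODEL.units)
        assert mlm_scores.shape == (batch_size, num_mask, cfg.MODEL.vocab_size)
        assert nsp_score.shape == (batch_size, 2)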
def test_mobilebert_model_small_cfg(compute_layout, ctx):
    with ctx:
        cfg = MobileBertModel.get_cfg()
        cfg.defrost()
        cfg.MODEL.vocab_size = 100
        cfg.MODEL.num_layers = 2
        cfg.MODEL.hidden_size = 128
        cfg.MODEL.num_heads = 2
        cfg.MODEL.compute_layout = compute_layout
        cfg.freeze()

        # Generate TN layout
        cfg_tn = cfg.clone()
        cfg_tn.defrost()
        cfg_tn.MODEL.layout = 'TN'
        cfg_tn.freeze()

        batch_size = 4
        sequence_length = 16
        num_mask = 3
        inputs = mx.np.random.randint(0, 10, (batch_size, sequence_length))
        token_types = mx.np.random.randint(0, 2, (batch_size, sequence_length))
        valid_length = mx.np.random.randint(3, sequence_length, (batch_size, ))
        masked_positions = mx.np.random.randint(0, 3, (batch_size, num_mask))

        mobile_bert_model = MobileBertModel.from_cfg(cfg)
        mobile_bert_model.initialize()
        mobile_bert_model.hybridize()
        mobile_bert_model_tn = MobileBertModel.from_cfg(cfg_tn)
        mobile_bert_model_tn.share_parameters(
            mobile_bert_model.collect_params())
        mobile_bert_model_tn.hybridize()
        contextual_embedding, pooled_out = mobile_bert_model(
            inputs, token_types, valid_length)
        contextual_embedding_tn, pooled_out_tn = mobile_bert_model_tn(
            inputs.T, token_types.T, valid_length)
        assert_allclose(contextual_embedding.asnumpy(),
                        np.swapaxes(contextual_embedding_tn.asnumpy(), 0, 1),
                        1E-3, 1E-3)
        assert_allclose(pooled_out.asnumpy(), pooled_out_tn.asnumpy(), 1E-3,
                        1E-3)

        # Test for MobileBertForMLM
        mobile_bert_mlm_model = MobileBertForMLM(cfg)
        mobile_bert_mlm_model.initialize()
        mobile_bert_mlm_model.hybridize()
        mobile_bert_mlm_model_tn = MobileBertForMLM(cfg_tn)
        mobile_bert_mlm_model_tn.share_parameters(
            mobile_bert_mlm_model.collect_params())
        mobile_bert_mlm_model_tn.hybridize()
        contextual_embedding, pooled_out, mlm_score = mobile_bert_mlm_model(
            inputs, token_types, valid_length, masked_positions)
        contextual_embedding_tn, pooled_out_tn, mlm_score_tn =\
            mobile_bert_mlm_model_tn(inputs.T, token_types.T, valid_length, masked_positions)
        assert_allclose(contextual_embedding.asnumpy(),
                        np.swapaxes(contextual_embedding_tn.asnumpy(), 0, 1),
                        1E-3, 1E-3)
        assert_allclose(pooled_out_tn.asnumpy(), pooled_out.asnumpy(), 1E-3,
                        1E-3)
        assert_allclose(mlm_score_tn.asnumpy(), mlm_score.asnumpy(), 1E-3,
                        1E-3)

        # Test for MobileBertForPretrain
        mobile_bert_pretrain_model = MobileBertForPretrain(cfg)
        mobile_bert_pretrain_model.initialize()
        mobile_bert_pretrain_model.hybridize()
        mobile_bert_pretrain_model_tn = MobileBertForPretrain(cfg_tn)
        mobile_bert_pretrain_model_tn.share_parameters(
            mobile_bert_pretrain_model.collect_params())
        mobile_bert_pretrain_model_tn.hybridize()
        contextual_embedding, pooled_out, nsp_score, mlm_score =\
            mobile_bert_pretrain_model(inputs, token_types, valid_length, masked_positions)
        contextual_embedding_tn, pooled_out_tn, nsp_score_tn, mlm_score_tn = \
            mobile_bert_pretrain_model_tn(inputs.T, token_types.T, valid_length, masked_positions)
        assert_allclose(contextual_embedding.asnumpy(),
                        np.swapaxes(contextual_embedding_tn.asnumpy(), 0, 1),
                        1E-3, 1E-3)
        assert_allclose(pooled_out.asnumpy(), pooled_out_tn.asnumpy(), 1E-3,
                        1E-3)
        assert_allclose(nsp_score.asnumpy(), nsp_score_tn.asnumpy(), 1E-3,
                        1E-3)
        assert_allclose(mlm_score.asnumpy(), mlm_score_tn.asnumpy(), 1E-3,
                        1E-3)

        # Test for fp16
        if ctx.device_type == 'gpu':
            pytest.skip('MobileBERT will have nan values in FP16 mode.')
            verify_backbone_fp16(model_cls=MobileBertModel,
                                 cfg=cfg,
                                 ctx=ctx,
                                 inputs=[inputs, token_types, valid_length])
Example #7
def test_gpt2_small_config(compute_layout, ctx):
    cfg = GPT2Model.get_cfg()
    cfg.defrost()
    cfg.MODEL.vocab_size = 1000
    cfg.MODEL.units = 128
    cfg.MODEL.num_layers = 2
    cfg.MODEL.num_heads = 2
    cfg.MODEL.compute_layout = compute_layout
    cfg.freeze()

    # Generate TN layout
    cfg_tn = cfg.clone()
    cfg_tn.defrost()
    cfg_tn.MODEL.layout = 'TN'
    cfg_tn.freeze()

    with ctx:
        batch_size = 4
        sequence_length = 16
        inputs = mx.np.random.randint(0,
                                      1000, (batch_size, sequence_length),
                                      ctx=ctx)

        gpt2_model = GPT2Model.from_cfg(cfg)
        gpt2_model.initialize(ctx=ctx)
        gpt2_model.hybridize()
        hiddens, _ = gpt2_model(inputs,
                                gpt2_model.init_states(batch_size, ctx))

        gpt2_model_tn = GPT2Model.from_cfg(cfg_tn)
        gpt2_model_tn.share_parameters(gpt2_model.collect_params())
        gpt2_model_tn.hybridize()
        hiddens_tn, _ = gpt2_model_tn(
            inputs.T, gpt2_model_tn.init_states(batch_size, ctx))
        assert_allclose(np.swapaxes(hiddens_tn.asnumpy(), 0, 1),
                        hiddens.asnumpy(), 1E-4, 1E-4)

        # Test for GPT2ForLM
        gpt2_lm_model = GPT2ForLM(cfg)
        gpt2_lm_model.initialize(ctx=ctx)
        gpt2_lm_model.hybridize()
        logits, states = gpt2_lm_model(
            inputs, gpt2_lm_model.init_states(batch_size, ctx))
        gpt2_lm_model_tn = GPT2ForLM(cfg_tn)
        gpt2_lm_model_tn.share_parameters(gpt2_lm_model.collect_params())
        gpt2_lm_model_tn.hybridize()
        logits_tn, states_tn = gpt2_lm_model_tn(
            inputs.T, gpt2_lm_model_tn.init_states(batch_size, ctx))
        assert_allclose(np.swapaxes(logits_tn.asnumpy(), 0, 1),
                        logits.asnumpy(), 1E-4, 1E-4)
        assert_allclose(np.swapaxes(states_tn.asnumpy(), 2, 3),
                        states.asnumpy(), 1E-4, 1E-4)

        # Verify Float16
        if ctx.device_type == 'gpu':
            pytest.skip(
                'GPT-2 test has been turned off. '
                'Issue: https://github.com/apache/incubator-mxnet/issues/19463'
            )
            verify_backbone_fp16(
                model_cls=GPT2Model,
                cfg=cfg,
                ctx=ctx,
                inputs=[inputs,
                        gpt2_model.init_states(batch_size, ctx)],
                check_amp=False)
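The states returned by GPT2Model cache the per-layer attention keys and values, which is what makes incremental decoding possible. The sketch below illustrates that use, reusing the names from the snippet above; it is an illustration rather than part of the original test, and the tolerance may need adjusting.

        # Illustrative incremental pass (NT layout assumed, as in the snippet above).
        states = gpt2_model.init_states(batch_size, ctx)
        step_hiddens = []
        for t in range(sequence_length):
            step_out, states = gpt2_model(inputs[:, t:t + 1], states)  # one token per step
            step_hiddens.append(step_out)
        # Concatenating the per-step outputs along the time axis should match the
        # full-sequence hiddens computed earlier.
        assert_allclose(mx.np.concatenate(step_hiddens, axis=1).asnumpy(),
                        hiddens.asnumpy(), 1E-3, 1E-3)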