Code example #1
# Note: the excerpt omits imports and test fixtures. cfg_key and ctx are
# supplied by pytest parametrization in the original test file; the import
# paths below follow GluonNLP's package layout and are assumptions.
import mxnet as mx
import numpy as np
import numpy.testing as npt

from gluonnlp.models.bart import BartModel
from gluonnlp.utils.testing import verify_backbone_fp16


def test_bart_cfg(cfg_key, ctx):
    cfg = BartModel.get_cfg(cfg_key)
    cfg.defrost()
    cfg.MODEL.vocab_size = 32
    cfg.freeze()

    # Clone the config and switch the compute layout from batch-major ('NT')
    # to time-major ('TN').
    cfg_tn = cfg.clone()
    cfg_tn.defrost()
    cfg_tn.MODEL.layout = 'TN'
    cfg_tn.freeze()

    batch_size = 4
    src_length = 32
    tgt_length = 16

    with ctx:
        src_data = mx.np.random.randint(0,
                                        cfg.MODEL.vocab_size,
                                        (batch_size, src_length),
                                        dtype=np.int32)
        src_valid_length = mx.np.random.randint(src_length // 2,
                                                src_length, (batch_size, ),
                                                dtype=np.int32)
        tgt_data = mx.np.random.randint(0,
                                        cfg.MODEL.vocab_size,
                                        (batch_size, tgt_length),
                                        dtype=np.int32)
        tgt_valid_length = mx.np.random.randint(tgt_length // 2,
                                                tgt_length, (batch_size, ),
                                                dtype=np.int32)
        model = BartModel.from_cfg(cfg, extract_feature=True)
        model.initialize()
        model.hybridize()

        contextual_embedding, pooled_output = model(src_data, src_valid_length,
                                                    tgt_data, tgt_valid_length)
        # Build a time-major model that shares parameters with the batch-major
        # one, then check that both produce the same outputs on transposed inputs.
        model_tn = BartModel.from_cfg(cfg_tn, extract_feature=True)
        model_tn.share_parameters(model.collect_params())
        model_tn.hybridize()
        contextual_embedding_tn, pooled_out_tn = model_tn(
            src_data.T, src_valid_length, tgt_data.T, tgt_valid_length)
        npt.assert_allclose(
            contextual_embedding.asnumpy(),
            np.transpose(contextual_embedding_tn.asnumpy(), (1, 0, 2)), 5E-3,
            5E-3)
        npt.assert_allclose(pooled_out_tn.asnumpy(), pooled_output.asnumpy(),
                            5E-3, 5E-3)
        mx.npx.waitall()

        # Verify Float16
        if ctx.device_type == 'gpu':
            verify_backbone_fp16(model_cls=BartModel,
                                 cfg=cfg,
                                 ctx=ctx,
                                 inputs=[
                                     src_data, src_valid_length, tgt_data,
                                     tgt_valid_length
                                 ])
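
Example #1 exercises BartModel in both of its data layouts: 'NT' (batch, time) and 'TN' (time, batch). The two models share parameters, so their outputs must agree once the time-major result is transposed back, which is what np.transpose(..., (1, 0, 2)) does in the assertion. A minimal NumPy-only sketch of that shape relationship (the array here is illustrative random data, not model output):

import numpy as np

batch_size, seq_len, hidden = 4, 16, 8

# 'NT' layout stores activations as (batch, time, channels).
out_nt = np.random.normal(size=(batch_size, seq_len, hidden))

# 'TN' layout stores the same values as (time, batch, channels); swapping the
# first two axes converts one into the other.
out_tn = out_nt.transpose(1, 0, 2)

# Transposing back recovers the batch-major tensor exactly, which is the
# equivalence the test asserts up to numerical tolerance.
np.testing.assert_allclose(out_nt, np.transpose(out_tn, (1, 0, 2)))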
Code example #2
File: test_models_bart.py  Project: zheyuye/gluon-nlp
# Imports are as in example #1; cfg_key is again supplied by pytest
# parametrization in the original test file.
def test_bart_cfg(cfg_key):
    cfg = BartModel.get_cfg(cfg_key)
    cfg.defrost()
    cfg.MODEL.vocab_size = 32
    cfg.freeze()
    model = BartModel.from_cfg(cfg)
    model.initialize()
    model.hybridize()
    cfg.defrost()
    cfg.MODEL.layout = 'TN'
    cfg.freeze()
    model_tn = BartModel.from_cfg(cfg)
    model_tn.share_parameters(model.collect_params())
    model_tn.hybridize()
    mx.npx.waitall()
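
Both tests rely on the freeze/defrost protocol of the configuration object returned by BartModel.get_cfg(), which behaves like a yacs-style CfgNode: it is immutable while frozen and must be defrosted before keys such as MODEL.vocab_size or MODEL.layout can be overridden. A minimal sketch of that pattern using yacs directly (the keys below are illustrative, not the real BART config schema):

from yacs.config import CfgNode as CN

cfg = CN()
cfg.MODEL = CN()
cfg.MODEL.vocab_size = 50265   # illustrative value
cfg.MODEL.layout = 'NT'
cfg.freeze()                   # lock the config against accidental edits

cfg.defrost()                  # unlock before overriding values
cfg.MODEL.vocab_size = 32
cfg.MODEL.layout = 'TN'
cfg.freeze()                   # re-lock once the overrides are in place

print(cfg.dump())              # YAML text, as written to model.yml in example #3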
Code example #3
# Note: this excerpt is from GluonNLP's fairseq BART conversion script. The
# helpers convert_vocab, convert_config, convert_params, test_model and
# naming_convention are defined elsewhere in the same script; the import paths
# below are assumptions based on the code.
import os
import shutil
import logging

import mxnet as mx
from fairseq.models.bart import BARTModel as fairseq_BARTModel
from gluonnlp.models.bart import BartModel


def convert_fairseq_model(args):
    if not args.save_dir:
        args.save_dir = os.path.basename(args.fairseq_model_path) + '_gluon'
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    fairseq_bart = fairseq_BARTModel.from_pretrained(
        args.fairseq_model_path, checkpoint_file='model.pt')
    vocab_size = convert_vocab(args, fairseq_bart)
    gluon_cfg = convert_config(fairseq_bart.args, vocab_size,
                               BartModel.get_cfg().clone())
    with open(os.path.join(args.save_dir, 'model.yml'), 'w') as of:
        of.write(gluon_cfg.dump())

    ctx = mx.gpu(args.gpu) if args.gpu is not None else mx.cpu()
    gluon_bart = convert_params(fairseq_bart, gluon_cfg, ctx)
    if args.test:
        test_model(fairseq_bart, gluon_bart, args.gpu)

    gluon_bart.save_parameters(os.path.join(args.save_dir, 'model.params'),
                               deduplicate=True)
    logging.info('Convert the BART MLM model in {} to {}'.format(
        os.path.join(args.fairseq_model_path, 'model.pt'),
        os.path.join(args.save_dir, 'model.params')))

    logging.info('Conversion finished!')
    logging.info('Statistics:')
    # Rename the saved files following the hash-based naming convention and
    # log each file's hash and size.
    old_names = os.listdir(args.save_dir)
    for old_name in old_names:
        new_name, long_hash = naming_convention(args.save_dir, old_name)
        old_path = os.path.join(args.save_dir, old_name)
        new_path = os.path.join(args.save_dir, new_name)
        shutil.move(old_path, new_path)
        file_size = os.path.getsize(new_path)
        logging.info('\t{}/{} {} {}'.format(args.save_dir, new_name, long_hash,
                                            file_size))
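
convert_fairseq_model reads all of its options from an args namespace. Below is a hypothetical argparse front end covering only the attributes the function actually uses; the flag names mirror those attributes and are an assumption, not the real script's parser.

import argparse

def parse_args():
    # Hypothetical parser; flag names are inferred from the attributes
    # accessed in convert_fairseq_model above.
    parser = argparse.ArgumentParser(
        description='Convert a fairseq BART checkpoint to GluonNLP parameters.')
    parser.add_argument('--fairseq_model_path', required=True,
                        help='Directory that contains the fairseq model.pt checkpoint.')
    parser.add_argument('--save_dir', default=None,
                        help='Output directory; defaults to the checkpoint directory name plus _gluon.')
    parser.add_argument('--gpu', type=int, default=None,
                        help='GPU id to run the conversion on; CPU when omitted.')
    parser.add_argument('--test', action='store_true',
                        help='Compare fairseq and Gluon outputs after conversion.')
    return parser.parse_args()

if __name__ == '__main__':
    convert_fairseq_model(parse_args())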