def test_bart_cfg(cfg_key, ctx):
    """Smoke-test BartModel built from ``cfg_key`` in both NT and TN layouts.

    Two parameter-sharing models differing only in layout are run on the
    same random batch; their feature outputs must agree within 5e-3.
    On GPU contexts the backbone is additionally verified in float16.
    """
    cfg = BartModel.get_cfg(cfg_key)
    cfg.defrost()
    cfg.MODEL.vocab_size = 32  # tiny vocab keeps the test cheap
    cfg.freeze()

    # Time-major ('TN') variant of the same configuration.
    cfg_tn = cfg.clone()
    cfg_tn.defrost()
    cfg_tn.MODEL.layout = 'TN'
    cfg_tn.freeze()

    n_batch, enc_len, dec_len = 4, 32, 16
    with ctx:
        # Random encoder/decoder token ids plus valid lengths in [len/2, len).
        enc_tokens = mx.np.random.randint(
            0, cfg.MODEL.vocab_size, (n_batch, enc_len), dtype=np.int32)
        enc_valid = mx.np.random.randint(
            enc_len // 2, enc_len, (n_batch,), dtype=np.int32)
        dec_tokens = mx.np.random.randint(
            0, cfg.MODEL.vocab_size, (n_batch, dec_len), dtype=np.int32)
        dec_valid = mx.np.random.randint(
            dec_len // 2, dec_len, (n_batch,), dtype=np.int32)

        # Reference model in the default (batch-major) layout.
        model = BartModel.from_cfg(cfg, extract_feature=True)
        model.initialize()
        model.hybridize()
        ctx_emb, pooled = model(enc_tokens, enc_valid, dec_tokens, dec_valid)

        # TN model shares the exact same parameters; inputs are transposed.
        model_tn = BartModel.from_cfg(cfg_tn, extract_feature=True)
        model_tn.share_parameters(model.collect_params())
        model_tn.hybridize()
        ctx_emb_tn, pooled_tn = model_tn(
            enc_tokens.T, enc_valid, dec_tokens.T, dec_valid)

        # Outputs must match once the TN embedding is transposed back.
        npt.assert_allclose(
            ctx_emb.asnumpy(),
            np.transpose(ctx_emb_tn.asnumpy(), (1, 0, 2)),
            5E-3, 5E-3)
        npt.assert_allclose(pooled_tn.asnumpy(), pooled.asnumpy(),
                            5E-3, 5E-3)
        mx.npx.waitall()

        # Verify Float16
        if ctx.device_type == 'gpu':
            verify_backbone_fp16(
                model_cls=BartModel, cfg=cfg, ctx=ctx,
                inputs=[enc_tokens, enc_valid, dec_tokens, dec_valid])
def test_bart_cfg(cfg_key):
    """Smoke-test that BartModel builds and hybridizes for NT and TN layouts.

    Only model construction is exercised; no forward pass is run.
    """
    cfg = BartModel.get_cfg(cfg_key)
    cfg.defrost()
    cfg.MODEL.vocab_size = 32  # tiny vocab keeps construction cheap
    cfg.freeze()
    model = BartModel.from_cfg(cfg)
    model.initialize()
    model.hybridize()

    # Fix: derive the TN-layout config from a clone instead of defrosting and
    # mutating the frozen `cfg` in place — the original silently changed the
    # config object already handed to the first model. This also matches the
    # clone pattern used by the layout-comparison test in this file.
    cfg_tn = cfg.clone()
    cfg_tn.defrost()
    cfg_tn.MODEL.layout = 'TN'
    cfg_tn.freeze()
    model_tn = BartModel.from_cfg(cfg_tn)
    model_tn.share_parameters(model.collect_params())
    model_tn.hybridize()
    mx.npx.waitall()
def convert_fairseq_model(args):
    """Convert a fairseq BART checkpoint into GluonNLP format.

    Loads the fairseq model from ``args.fairseq_model_path``, converts the
    vocabulary, configuration, and parameters, optionally cross-checks the
    two implementations (``args.test``), then writes the artifacts to
    ``args.save_dir`` and renames them to hash-based canonical names.
    """
    if not args.save_dir:
        args.save_dir = os.path.basename(args.fairseq_model_path) + '_gluon'
    # Fix: exist_ok avoids the check-then-create race of the original
    # `if not os.path.exists(...): os.makedirs(...)` pattern.
    os.makedirs(args.save_dir, exist_ok=True)
    fairseq_bart = fairseq_BARTModel.from_pretrained(
        args.fairseq_model_path, checkpoint_file='model.pt')
    vocab_size = convert_vocab(args, fairseq_bart)
    gluon_cfg = convert_config(fairseq_bart.args, vocab_size,
                               BartModel.get_cfg().clone())
    # Persist the converted configuration next to the parameters.
    with open(os.path.join(args.save_dir, 'model.yml'), 'w') as of:
        of.write(gluon_cfg.dump())
    ctx = mx.gpu(args.gpu) if args.gpu is not None else mx.cpu()
    gluon_bart = convert_params(fairseq_bart, gluon_cfg, ctx)
    if args.test:
        test_model(fairseq_bart, gluon_bart, args.gpu)
    gluon_bart.save_parameters(os.path.join(args.save_dir, 'model.params'),
                               deduplicate=True)
    logging.info('Convert the BART MLM model in {} to {}'.format(
        os.path.join(args.fairseq_model_path, 'model.pt'),
        os.path.join(args.save_dir, 'model.params')))
    logging.info('Conversion finished!')
    logging.info('Statistics:')
    # Rename every produced file to its hash-based name and log its size.
    old_names = os.listdir(args.save_dir)
    for old_name in old_names:
        new_name, long_hash = naming_convention(args.save_dir, old_name)
        old_path = os.path.join(args.save_dir, old_name)
        new_path = os.path.join(args.save_dir, new_name)
        shutil.move(old_path, new_path)
        file_size = os.path.getsize(new_path)
        logging.info('\t{}/{} {} {}'.format(args.save_dir, new_name,
                                            long_hash, file_size))