import json
import logging
import os
import pickle

import mxnet as mx
import gluonnlp as nlp

# Helpers referenced below (to_gluon_kwargs, to_gluon_vocab, set_params,
# get_hash) and the TransformerXL / XLNet model classes are defined elsewhere
# in the full script/package.


def get_kwargs_and_corpus(args):
    # Infer model config
    with open(os.path.join(args.tf_data_dir, 'cache.pkl'), 'rb') as f:
        corpus = pickle.load(f, encoding='latin1')
    tf_checkpoint_file = os.path.expanduser(
        os.path.join(args.tf_checkpoint_dir, args.tf_model_prefix))
    tf_tensors = read_tf_checkpoint(tf_checkpoint_file)
    return to_gluon_kwargs(tf_tensors), corpus
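# read_tf_checkpoint is used by every converter below but defined elsewhere in
# the full script. A minimal sketch using TensorFlow's public checkpoint-reader
# API, in case the original helper is unavailable (the upstream implementation
# may differ in detail):
import tensorflow as tf


def read_tf_checkpoint(path):
    """Load every tensor from a TF checkpoint into a {name: ndarray} dict."""
    reader = tf.train.load_checkpoint(path)
    return {name: reader.get_tensor(name)
            for name in reader.get_variable_to_shape_map()}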
def convert_transformerxl(args):
    # Load tf model and vocab
    with open(args.cache_pkl, 'rb') as f:
        corpus = pickle.load(f, encoding='latin1')
    vocab = to_gluon_vocab(corpus)
    tf_checkpoint_file = os.path.expanduser(
        os.path.join(args.tf_checkpoint_dir, args.tf_model_prefix))
    tf_tensors = read_tf_checkpoint(tf_checkpoint_file)

    # Initialize Gluon model
    kwargs, tie_r = to_gluon_kwargs(tf_tensors)
    model = TransformerXL(vocab_size=len(vocab), **kwargs)
    model.initialize(init=mx.init.Normal(0.02))

    # Shape inference based on forward pass
    batch_size, seq_len = 2, 16
    mem_length = 100
    mems = model.begin_mems(batch_size, mem_length, context=mx.cpu())
    x = mx.nd.ones(shape=(batch_size, seq_len))
    model(x, x, mems)

    # Convert parameters
    set_params(model, tf_tensors, kwargs, tie_r)

    # Serialization
    tmp_file_path = os.path.expanduser(os.path.join(args.out_dir, 'tmp'))
    with open(tmp_file_path, 'w') as f:
        f.write(vocab.to_json())
    hash_full, hash_short = get_hash(tmp_file_path)
    gluon_vocab_path = os.path.expanduser(
        os.path.join(args.out_dir, hash_short + '.vocab'))
    with open(gluon_vocab_path, 'w') as f:
        f.write(vocab.to_json())
    logging.info('vocab file saved to %s. hash = %s', gluon_vocab_path,
                 hash_full)
    model.save_parameters(tmp_file_path)
    hash_full, hash_short = get_hash(tmp_file_path)
    os.remove(tmp_file_path)
    gluon_param_path = os.path.expanduser(
        os.path.join(args.out_dir, hash_short + '.params'))
    logging.info('param saved to %s. hash = %s', gluon_param_path, hash_full)
    model.save_parameters(gluon_param_path)
    mx.nd.waitall()
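# Usage sketch for the artifacts written above, with hypothetical file names
# standing in for the hash-named outputs that convert_transformerxl logs.
# nlp.Vocab.from_json and Block.load_parameters are the standard
# GluonNLP/MXNet loading APIs:
#
#   with open('<hash_short>.vocab', 'r') as f:
#       vocab = nlp.Vocab.from_json(f.read())
#   model = TransformerXL(vocab_size=len(vocab), **kwargs)
#   model.initialize()
#   model.load_parameters('<hash_short>.params')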
tmp_file_path = os.path.expanduser(os.path.join(args.out_dir, 'tmp'))
with open(tmp_file_path, 'w') as f:
    f.write(vocab.to_json())
hash_full, hash_short = get_hash(tmp_file_path)
gluon_vocab_path = os.path.expanduser(
    os.path.join(args.out_dir, hash_short + '.vocab'))
with open(gluon_vocab_path, 'w') as f:
    f.write(vocab.to_json())
logging.info('vocab file saved to %s. hash = %s', gluon_vocab_path, hash_full)

# load tf model
tf_checkpoint_file = os.path.expanduser(
    os.path.join(args.tf_checkpoint_dir, 'bert_model.ckpt'))
logging.info('loading Tensorflow checkpoint %s ...', tf_checkpoint_file)
tf_tensors = read_tf_checkpoint(tf_checkpoint_file)
tf_names = sorted(tf_tensors.keys())
for name in tf_names:
    logging.debug('%s: %s', name, tf_tensors[name].shape)

# replace tensorflow parameter names with gluon parameter names
NAME_MAP = [
    ('bert/encoder/layer_', 'encoder.transformer_cells.'),
    ('/attention/self/', '.attention_cell.'),
    ('key', 'proj_key'),
    ('query', 'proj_query'),
    ('value', 'proj_value'),
    ('/attention/output/LayerNorm/', '.layer_norm.'),
    ('/attention/output/dense/', '.proj.'),
    ('cls/seq_relationship/output_weights', 'classifier.weight'),
    ('cls/seq_relationship/output_bias', 'classifier.bias'),
    # further mappings omitted in this excerpt
]
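# A table like NAME_MAP is consumed by applying every substitution in order to
# each TensorFlow variable name. A minimal sketch of that mapping step (the
# full script additionally transposes/reshapes some tensors before copying,
# which is omitted here):
def to_gluon_name(source_name):
    """Derive the Gluon parameter name for a TensorFlow variable name."""
    target_name = source_name
    for old, new in NAME_MAP:
        target_name = target_name.replace(old, new)
    return target_name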
def convert_xlnet(args):
    # Load vocab
    vocab_file = os.path.join(args.model_dir, 'spiece.model')
    vocab = nlp.vocab.BERTVocab.from_sentencepiece(vocab_file,
                                                   cls_token='<cls>',
                                                   sep_token='<sep>',
                                                   mask_token='<mask>')

    # Load config
    tf_config_names_to_gluon_config_names = {
        'd_inner': 'hidden_size',
        'd_model': 'units',
        'ff_activation': 'activation',
        'n_head': 'num_heads',
        'n_layer': 'num_layers',
        'n_token': 'vocab_size',
    }
    with open(os.path.join(args.model_dir, 'xlnet_config.json'), 'r') as f:
        tf_config = json.load(f)
        assert tf_config['untie_r']
        del tf_config['untie_r']
        del tf_config['d_head']
    assert len(tf_config) == len(tf_config_names_to_gluon_config_names)
    kwargs = {
        tf_config_names_to_gluon_config_names[k]: v
        for k, v in tf_config.items()
    }
    assert len(vocab) == kwargs['vocab_size']
    print(kwargs)

    # Load TF model
    tf_checkpoint_file = os.path.expanduser(
        os.path.join(args.model_dir, 'xlnet_model.ckpt'))
    tf_tensors = read_tf_checkpoint(tf_checkpoint_file)

    # Update kwargs
    kwargs['tie_decoder_weight'] = 'model/lm_loss/weight' not in tf_tensors

    # Initialize Gluon model
    model = XLNet(**kwargs)
    model.initialize(init=mx.init.Normal(0.02))
    model.hybridize()

    # Shape inference based on forward pass
    batch_size, qlen, mlen = 2, 16, 100
    mems = model.begin_mems(batch_size, mlen, context=mx.cpu())
    x = mx.nd.ones(shape=(batch_size, qlen))
    segments = mx.nd.random_normal(shape=(batch_size, qlen, mlen + qlen, 2))
    segments = segments < 0
    model(x, segments, mems)

    # Convert parameters
    set_params(model, tf_tensors, kwargs, tie_r=False)

    # Serialization
    tmp_file_path = os.path.expanduser(os.path.join(args.out_dir, 'tmp'))
    with open(tmp_file_path, 'w') as f:
        f.write(vocab.to_json())
    hash_full, hash_short = get_hash(tmp_file_path)
    gluon_vocab_path = os.path.expanduser(
        os.path.join(args.out_dir, hash_short + '.vocab'))
    with open(gluon_vocab_path, 'w') as f:
        f.write(vocab.to_json())
    logging.info('vocab file saved to %s. hash = %s', gluon_vocab_path,
                 hash_full)
    model.save_parameters(tmp_file_path)
    hash_full, hash_short = get_hash(tmp_file_path)
    os.remove(tmp_file_path)
    gluon_param_path = os.path.expanduser(
        os.path.join(args.out_dir, hash_short + '.params'))
    logging.info('param saved to %s. hash = %s', gluon_param_path, hash_full)
    model.save_parameters(gluon_param_path)
    mx.nd.waitall()
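# Minimal command-line entry point. The flag names below are assumptions that
# mirror the attribute names used above (args.model_dir, args.cache_pkl,
# args.tf_data_dir, args.tf_checkpoint_dir, args.tf_model_prefix,
# args.out_dir); the real script may wire its arguments differently.
if __name__ == '__main__':
    import argparse
    logging.basicConfig(level=logging.INFO)
    parser = argparse.ArgumentParser(
        description='Convert TensorFlow Transformer-XL / XLNet checkpoints '
        'to GluonNLP.')
    parser.add_argument('--model', choices=['transformerxl', 'xlnet'],
                        required=True)
    parser.add_argument('--model_dir',
                        help='XLNet: directory containing spiece.model, '
                        'xlnet_config.json and xlnet_model.ckpt.')
    parser.add_argument('--cache_pkl',
                        help='Transformer-XL: path to the corpus cache.pkl.')
    parser.add_argument('--tf_data_dir',
                        help='Transformer-XL: directory containing cache.pkl.')
    parser.add_argument('--tf_checkpoint_dir',
                        help='Directory containing the TF checkpoint.')
    parser.add_argument('--tf_model_prefix',
                        help='Checkpoint file prefix in tf_checkpoint_dir.')
    parser.add_argument('--out_dir', required=True,
                        help='Output directory for the .vocab/.params files.')
    args = parser.parse_args()
    if args.model == 'xlnet':
        convert_xlnet(args)
    else:
        convert_transformerxl(args)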