def get_kwargs_and_corpus(args):
    """Load the pickled corpus and the TF checkpoint, and infer Gluon kwargs.

    Args:
        args: parsed arguments providing ``tf_data_dir`` (directory holding
            ``cache.pkl``), ``tf_checkpoint_dir`` and ``tf_model_prefix``.

    Returns:
        A pair ``(to_gluon_kwargs(tf_tensors), corpus)``. The first element is
        returned exactly as ``to_gluon_kwargs`` produced it (this file
        elsewhere unpacks that value as ``kwargs, tie_r``).
    """
    # Expand '~' for the corpus cache too, matching how the checkpoint path
    # below is resolved; previously only the checkpoint path was expanded.
    cache_path = os.path.expanduser(
        os.path.join(args.tf_data_dir, 'cache.pkl'))
    # NOTE: pickle.load can execute arbitrary code; only load trusted caches.
    # encoding='latin1' presumably targets a cache pickled under Python 2.
    with open(cache_path, 'rb') as f:
        corpus = pickle.load(f, encoding='latin1')
    tf_checkpoint_file = os.path.expanduser(
        os.path.join(args.tf_checkpoint_dir, args.tf_model_prefix))
    tf_tensors = read_tf_checkpoint(tf_checkpoint_file)
    return to_gluon_kwargs(tf_tensors), corpus
# Example no. 2 (0)
def convert_transformerxl(args):
    """Convert a TensorFlow Transformer-XL checkpoint into Gluon artifacts.

    Reads the pickled corpus at ``args.cache_pkl`` and the TF checkpoint at
    ``args.tf_checkpoint_dir/args.tf_model_prefix``. It then builds a Gluon
    ``TransformerXL``, copies the TF weights into it, and writes
    ``<hash>.vocab`` and ``<hash>.params`` into ``args.out_dir``. Each
    ``<hash>`` is the short hash that ``get_hash`` reports for that file's
    contents.
    """
    # Load tf model and vocab.
    # NOTE: pickle.load can execute arbitrary code; only load trusted caches.
    with open(os.path.expanduser(args.cache_pkl), 'rb') as f:
        corpus = pickle.load(f, encoding='latin1')
    vocab = to_gluon_vocab(corpus)
    tf_checkpoint_file = os.path.expanduser(
        os.path.join(args.tf_checkpoint_dir, args.tf_model_prefix))
    tf_tensors = read_tf_checkpoint(tf_checkpoint_file)

    # Initialize Gluon model
    kwargs, tie_r = to_gluon_kwargs(tf_tensors)
    model = TransformerXL(vocab_size=len(vocab), **kwargs)
    model.initialize(init=mx.init.Normal(0.02))

    # Shape inference based on a dummy forward pass, so every parameter has a
    # concrete shape before set_params copies the TF weights into it.
    batch_size, seq_len = 2, 16
    mem_length = 100
    mems = model.begin_mems(batch_size, mem_length, context=mx.cpu())
    x = mx.nd.ones(shape=(batch_size, seq_len))
    model(x, x, mems)

    # Convert parameters
    set_params(model, tf_tensors, kwargs, tie_r)

    # Serialization: write each artifact once to a scratch file, hash it, then
    # move that exact file to its hash-based name. This guarantees the name
    # matches the bytes on disk and avoids serializing the model twice.
    out_dir = os.path.expanduser(args.out_dir)
    tmp_file_path = os.path.join(out_dir, 'tmp')
    try:
        with open(tmp_file_path, 'w') as f:
            f.write(vocab.to_json())
        hash_full, hash_short = get_hash(tmp_file_path)
        gluon_vocab_path = os.path.join(out_dir, hash_short + '.vocab')
        os.replace(tmp_file_path, gluon_vocab_path)
        logging.info('vocab file saved to %s. hash = %s', gluon_vocab_path,
                     hash_full)

        model.save_parameters(tmp_file_path)
        hash_full, hash_short = get_hash(tmp_file_path)
        gluon_param_path = os.path.join(out_dir, hash_short + '.params')
        os.replace(tmp_file_path, gluon_param_path)
        logging.info('param saved to %s. hash = %s', gluon_param_path,
                     hash_full)
    finally:
        # Do not leave the scratch file behind if anything above failed.
        if os.path.exists(tmp_file_path):
            os.remove(tmp_file_path)
    mx.nd.waitall()
# NOTE(review): these module-level statements look like the body of a
# conversion function whose `def` line is missing here. They use `args`,
# `vocab`, `get_hash` and `read_tf_checkpoint`, none of which are defined in
# this span -- confirm where they are meant to come from.
# Write the vocab JSON to a scratch file, hash it, then write the same JSON
# again to a file named after the short hash.
tmp_file_path = os.path.expanduser(os.path.join(args.out_dir, 'tmp'))
with open(tmp_file_path, 'w') as f:
    f.write(vocab.to_json())
hash_full, hash_short = get_hash(tmp_file_path)
gluon_vocab_path = os.path.expanduser(
    os.path.join(args.out_dir, hash_short + '.vocab'))
with open(gluon_vocab_path, 'w') as f:
    f.write(vocab.to_json())
    logging.info('vocab file saved to %s. hash = %s', gluon_vocab_path,
                 hash_full)
# NOTE(review): the scratch file at tmp_file_path is not removed in this span;
# presumably later code reuses or deletes it -- verify.

# load tf model
tf_checkpoint_file = os.path.expanduser(
    os.path.join(args.tf_checkpoint_dir, 'bert_model.ckpt'))
logging.info('loading Tensorflow checkpoint %s ...', tf_checkpoint_file)
tf_tensors = read_tf_checkpoint(tf_checkpoint_file)
# Log each checkpoint tensor's name and shape at DEBUG level, sorted by name.
tf_names = sorted(tf_tensors.keys())
for name in tf_names:
    logging.debug('%s: %s', name, tf_tensors[name].shape)

# replace tensorflow parameter names with gluon parameter names
NAME_MAP = [
    ('bert/encoder/layer_', 'encoder.transformer_cells.'),
    ('/attention/self/', '.attention_cell.'),
    ('key', 'proj_key'),
    ('query', 'proj_query'),
    ('value', 'proj_value'),
    ('/attention/output/LayerNorm/', '.layer_norm.'),
    ('/attention/output/dense/', '.proj.'),
    ('cls/seq_relationship/output_weights', 'classifier.weight'),
    ('cls/seq_relationship/output_bias', 'classifier.bias'),
# Example no. 4 (0)
def convert_xlnet(args):
    """Convert a TensorFlow XLNet checkpoint into Gluon artifacts.

    Expects ``spiece.model``, ``xlnet_config.json`` and the
    ``xlnet_model.ckpt`` checkpoint inside ``args.model_dir``. Writes
    ``<hash>.vocab`` and ``<hash>.params`` into ``args.out_dir``, where each
    ``<hash>`` is the short hash that ``get_hash`` reports for that file's
    contents.

    Raises:
        AssertionError: if the config is not ``untie_r``, has unexpected or
            missing keys, or its ``n_token`` disagrees with the vocab size.
    """
    # Resolve '~' once so every file under model_dir is found consistently
    # (previously only the checkpoint path was expanded).
    model_dir = os.path.expanduser(args.model_dir)

    # Load vocab
    vocab_file = os.path.join(model_dir, 'spiece.model')
    vocab = nlp.vocab.BERTVocab.from_sentencepiece(vocab_file,
                                                   cls_token='<cls>',
                                                   sep_token='<sep>',
                                                   mask_token='<mask>')

    # Load config and translate TF option names to Gluon XLNet kwargs.
    tf_config_names_to_gluon_config_names = {
        'd_inner': 'hidden_size',
        'd_model': 'units',
        'ff_activation': 'activation',
        'n_head': 'num_heads',
        'n_layer': 'num_layers',
        'n_token': 'vocab_size',
    }
    with open(os.path.join(model_dir, 'xlnet_config.json'), 'r') as f:
        tf_config = json.load(f)
    # Only untied relative-position biases are handled: set_params is called
    # with tie_r=False below.
    assert tf_config['untie_r'], 'only untie_r=True checkpoints are supported'
    del tf_config['untie_r']
    del tf_config['d_head']
    # Compare key sets, not just counts, so a renamed key fails here with a
    # clear message instead of a KeyError in the comprehension below.
    mismatched = set(tf_config) ^ set(tf_config_names_to_gluon_config_names)
    assert not mismatched, (
        'unexpected or missing XLNet config keys: %s' % sorted(mismatched))
    kwargs = {
        tf_config_names_to_gluon_config_names[k]: v
        for k, v in tf_config.items()
    }
    assert len(vocab) == kwargs['vocab_size'], (
        'vocab size %d does not match n_token %d' %
        (len(vocab), kwargs['vocab_size']))
    logging.info('XLNet kwargs: %s', kwargs)

    # Load TF model
    tf_checkpoint_file = os.path.join(model_dir, 'xlnet_model.ckpt')
    tf_tensors = read_tf_checkpoint(tf_checkpoint_file)

    # Tie the decoder to the embedding only when the checkpoint has no
    # separate LM-loss weight.
    kwargs['tie_decoder_weight'] = 'model/lm_loss/weight' not in tf_tensors

    # Initialize Gluon model
    model = XLNet(**kwargs)
    model.initialize(init=mx.init.Normal(0.02))
    model.hybridize()

    # Shape inference based on a dummy forward pass, so every parameter has a
    # concrete shape before set_params copies the TF weights into it.
    batch_size, qlen, mlen = 2, 16, 100
    mems = model.begin_mems(batch_size, mlen, context=mx.cpu())
    x = mx.nd.ones(shape=(batch_size, qlen))
    segments = mx.nd.random_normal(shape=(batch_size, qlen, mlen + qlen, 2))
    segments = segments < 0
    model(x, segments, mems)

    # Convert parameters
    set_params(model, tf_tensors, kwargs, tie_r=False)

    # Serialization: write each artifact once to a scratch file, hash it, then
    # move that exact file to its hash-based name. This guarantees the name
    # matches the bytes on disk and avoids serializing the model twice.
    out_dir = os.path.expanduser(args.out_dir)
    tmp_file_path = os.path.join(out_dir, 'tmp')
    try:
        with open(tmp_file_path, 'w') as f:
            f.write(vocab.to_json())
        hash_full, hash_short = get_hash(tmp_file_path)
        gluon_vocab_path = os.path.join(out_dir, hash_short + '.vocab')
        os.replace(tmp_file_path, gluon_vocab_path)
        logging.info('vocab file saved to %s. hash = %s', gluon_vocab_path,
                     hash_full)

        model.save_parameters(tmp_file_path)
        hash_full, hash_short = get_hash(tmp_file_path)
        gluon_param_path = os.path.join(out_dir, hash_short + '.params')
        os.replace(tmp_file_path, gluon_param_path)
        logging.info('param saved to %s. hash = %s', gluon_param_path,
                     hash_full)
    finally:
        # Do not leave the scratch file behind if anything above failed.
        if os.path.exists(tmp_file_path):
            os.remove(tmp_file_path)
    mx.nd.waitall()