def save_model(new_gluon_parameters, output_dir):
    """Write the converted vocabulary and BERT parameters to disk.

    Steps, in order:

    1. Copy the first tab-separated column of every line of
       ``args.ernie_vocab_path`` into ``<output_dir>/vocab.txt`` and build a
       Gluon vocabulary from that file.
    2. Serialize the vocabulary as ``<output_dir>/<short hash>.vocab``.
    3. Check the JSON config at ``args.ernie_config_path`` against the
       predefined hyperparameters ``bert_hparams[args.gluon_bert_model_base]``.
    4. Build a ``BERTModel``, copy ``new_gluon_parameters`` into it and save
       it as ``<args.out_dir>/<short hash>.params``.

    Reads the module-level ``args`` namespace in addition to its arguments.

    Args:
        new_gluon_parameters: Mapping from Gluon parameter name to an
            array-like value. It is modified in place: the key
            ``'decoder.3.weight'`` is set to the same object as
            ``'word_embed.0.weight'``.
        output_dir: Directory receiving ``vocab.txt``, the scratch file
            ``tmp`` and the ``.vocab`` file. Created if it does not exist.

    Raises:
        AssertionError: If the config file does not have the expected set of
            keys, a config value differs from the predefined hyperparameters,
            or the number of model parameters differs from
            ``len(new_gluon_parameters)``.
        RuntimeError: If not every model parameter was assigned a value.
    """
    print('save model start'.center(60, '='))
    os.makedirs(output_dir, exist_ok=True)

    # Keep only the token column of the source vocabulary. Both files are
    # managed by the same ``with`` so the output handle is closed even if
    # reading the source file fails part-way.
    vocab_txt_path = os.path.join(output_dir, "vocab.txt")
    with open(vocab_txt_path, "wt", encoding='utf-8') as vocab_f, \
            open(args.ernie_vocab_path, "rt", encoding='utf-8') as f:
        for line in f:
            vocab_f.write(line.strip().split("\t")[0] + "\n")
    vocab = tf_vocab_to_gluon_vocab(load_text_vocab(vocab_txt_path))

    # vocab serialization: write once to a scratch file to obtain the hash
    # that names the final file, then write the final file.
    tmp_file_path = os.path.expanduser(os.path.join(output_dir, 'tmp'))
    # NOTE(review): the parameters are saved under ``args.out_dir`` while the
    # vocabulary goes to ``output_dir`` -- confirm callers always pass the
    # same directory for both.
    os.makedirs(args.out_dir, exist_ok=True)
    with open(tmp_file_path, 'w') as f:
        f.write(vocab.to_json())
    hash_full, hash_short = get_hash(tmp_file_path)
    gluon_vocab_path = os.path.expanduser(
        os.path.join(output_dir, hash_short + '.vocab'))
    with open(gluon_vocab_path, 'w') as f:
        f.write(vocab.to_json())
        logging.info('vocab file saved to %s. hash = %s', gluon_vocab_path,
                     hash_full)

    # BERT config: source config key -> predefined hyperparameter key.
    # Entries mapped to None are expected in the config but not compared.
    tf_config_names_to_gluon_config_names = {
        'attention_probs_dropout_prob': 'dropout',
        'hidden_act': None,
        'hidden_dropout_prob': 'dropout',
        'hidden_size': 'units',
        'initializer_range': None,
        # 'intermediate_size': 'hidden_size',
        'max_position_embeddings': 'max_length',
        'num_attention_heads': 'num_heads',
        'num_hidden_layers': 'num_layers',
        'type_vocab_size': 'token_type_vocab_size',
        'vocab_size': None
    }
    predefined_args = bert_hparams[args.gluon_bert_model_base]
    with open(args.ernie_config_path, 'r') as f:
        tf_config = json.load(f)
    # ignore layer_norm_eps
    tf_config.pop('layer_norm_eps', None)
    assert len(tf_config) == len(tf_config_names_to_gluon_config_names)
    for tf_name, gluon_name in tf_config_names_to_gluon_config_names.items():
        if tf_name is None or gluon_name is None:
            continue
        # max_length is the one mapped value allowed to differ.
        if gluon_name != 'max_length':
            assert tf_config[tf_name] == predefined_args[gluon_name]

    encoder = BERTEncoder(attention_cell=predefined_args['attention_cell'],
                          num_layers=predefined_args['num_layers'],
                          units=predefined_args['units'],
                          hidden_size=predefined_args['hidden_size'],
                          max_length=predefined_args['max_length'],
                          num_heads=predefined_args['num_heads'],
                          scaled=predefined_args['scaled'],
                          dropout=predefined_args['dropout'],
                          use_residual=predefined_args['use_residual'],
                          activation='relu')

    bert = BERTModel(
        encoder,
        len(vocab),
        token_type_vocab_size=predefined_args['token_type_vocab_size'],
        units=predefined_args['units'],
        embed_size=predefined_args['embed_size'],
        word_embed=predefined_args['word_embed'],
        use_pooler=True,
        use_decoder=False,
        use_classifier=False)

    bert.initialize(init=mx.init.Normal(0.02))

    # Run one dummy forward pass; the result itself is not needed.
    # NOTE(review): presumably this is what gives every parameter a concrete
    # shape before the shapes are compared below -- confirm.
    ones = mx.nd.ones((2, 8))
    bert(ones, ones, mx.nd.array([5, 6]), mx.nd.array([[1], [2]]))
    params = bert._collect_params_with_prefix()
    assert len(params) == len(new_gluon_parameters), "Gluon model does not match paddle model. " \
                                                   "Please fix the BERTModel hyperparameters"

    # post processings for parameters:
    # - handle tied decoder weight
    new_gluon_parameters['decoder.3.weight'] = new_gluon_parameters[
        'word_embed.0.weight']
    # set parameter data
    loaded_params = {}
    for name in params:
        if name == 'word_embed.0.weight':
            # Keep only as many embedding rows as the model has.
            arr = mx.nd.array(
                new_gluon_parameters[name][:params[name].shape[0]])
        else:
            arr = mx.nd.array(new_gluon_parameters[name])
        # Report a shape mismatch by name and still attempt the assignment.
        # An explicit comparison is used instead of ``assert`` inside a bare
        # ``except`` so the check survives ``python -O`` and does not hide
        # unrelated exceptions.
        if arr.shape != params[name].shape:
            print(name)
        params[name].set_data(arr)
        loaded_params[name] = True

    if len(params) != len(loaded_params):
        raise RuntimeError(
            'The Gluon BERTModel comprises {} parameter arrays, '
            'but {} have been extracted from the paddle model. '.format(
                len(params), len(loaded_params)))

    # param serialization: same scratch-file-then-hash scheme as the vocab.
    bert.save_parameters(tmp_file_path)
    hash_full, hash_short = get_hash(tmp_file_path)
    gluon_param_path = os.path.expanduser(
        os.path.join(args.out_dir, hash_short + '.params'))
    logging.info('param saved to %s. hash = %s', gluon_param_path, hash_full)
    bert.save_parameters(gluon_param_path)
    mx.nd.waitall()
    # save config
    print('finish save vocab')
    print('save model done!'.center(60, '='))
# Script excerpt: infer which PyTorch parameter corresponds to each Gluon
# parameter by comparing array shapes and values.
# NOTE(review): `parser` and the other CLI options read below (`--debug`,
# `--model`, `--dataset_name`, `--pytorch_checkpoint_dir`) are defined above
# this excerpt.
parser.add_argument('--out', default='gluon_to_pytorch_naming.json',
                    help='Output file to store gluon to pytorch name mapping.')
args = parser.parse_args()
# DEBUG verbosity when --debug is given, INFO otherwise.
logging.getLogger().setLevel(logging.DEBUG if args.debug else logging.INFO)
logging.info(args)

# Load Gluon Model
bert, vocab = nlp.model.get_model(args.model, dataset_name=args.dataset_name, pretrained=True)
parameters = bert._collect_params_with_prefix()
# Replace each parameter object by its value as a NumPy array.
parameters = {k: v.data().asnumpy() for k, v in parameters.items()}

# Load PyTorch Model
# map_location hands every storage back unchanged instead of moving it to the
# device recorded in the checkpoint.
pytorch_parameters = torch.load(os.path.join(args.pytorch_checkpoint_dir, 'pytorch_model.bin'),
                                map_location=lambda storage, loc: storage)
pytorch_vocab = tf_vocab_to_gluon_vocab(
    load_text_vocab(os.path.join(args.pytorch_checkpoint_dir, 'vocab.txt')))
# Convert the tensors to NumPy so they can be compared with the Gluon arrays.
pytorch_parameters = {k: v.numpy() for k, v in pytorch_parameters.items()}

# Assert that vocabularies are equal
assert pytorch_vocab.idx_to_token == vocab.idx_to_token

mapping = dict()

# For every Gluon parameter, scan all PyTorch parameters for one with the
# same shape and element-wise identical values.
# NOTE(review): this excerpt stops inside the loop; the code that sets
# `found_match` and fills `mapping` is not visible here.
for name, param in parameters.items():
    found_match = False
    for pytorch_name, pytorch_param in pytorch_parameters.items():
        if param.shape == pytorch_param.shape:
            if (param == pytorch_param).all():
                if found_match:
                    print('Found multiple matches for {}. '
                          'Ignoring new match {}'.format(name, pytorch_name))
# Example no. 3
# 0
# Script excerpt: convert a TensorFlow BERT checkpoint -- CLI options,
# vocabulary conversion and loading of the frozen graph.
# NOTE(review): `parser` and the `--tf_model_path` option read below are
# defined above this excerpt.
parser.add_argument('--tf_config_name',
                    type=str,
                    default='bert_config.json',
                    help='Name of Bert config file')
# NOTE(review): the help text says the folder must exist, yet the code below
# calls nlp.utils.mkdir(out_dir) -- presumably that creates it; confirm and
# align the help text.
parser.add_argument('--out_dir',
                    type=str,
                    default=os.path.join('~', 'output'),
                    help='Path to output folder. The folder must exist.')
parser.add_argument('--debug', action='store_true', help='debugging mode')
args = parser.parse_args()
# DEBUG verbosity when --debug is given, INFO otherwise.
logging.getLogger().setLevel(logging.DEBUG if args.debug else logging.INFO)
logging.info(args)

# convert vocabulary
vocab_path = os.path.join(args.tf_model_path, 'vocab.txt')
vocab = tf_vocab_to_gluon_vocab(load_text_vocab(vocab_path))

# vocab serialization
# The default out_dir starts with '~', so expand it before use.
out_dir = os.path.expanduser(args.out_dir)
nlp.utils.mkdir(out_dir)
gluon_vocab_path = os.path.join(out_dir, 'tf.vocab')
with open(gluon_vocab_path, 'w') as f:
    f.write(vocab.to_json())
    logging.info('vocab file saved to %s.', gluon_vocab_path)

# load tf model from pb file
tf_pb_file = os.path.join(args.tf_model_path, 'model.pb')
logging.info('loading Tensorflow pb file %s ...', tf_pb_file)
tf_tensors = read_tf_pb(tf_pb_file)
# Sorted key list of the loaded tensors.
tf_names = sorted(tf_tensors.keys())
# Script excerpt: convert a PyTorch BERT checkpoint -- CLI options and
# vocabulary serialization.
# NOTE(review): `parser` and the `--vocab_file` option read below are defined
# above this excerpt.
parser.add_argument(
    '--gluon_pytorch_name_mapping',
    type=str,
    default='gluon_to_pytorch_naming.json',
    help='Output of infer_pytorch_gluon_parameter_name_mapping.py')
parser.add_argument('--out_dir',
                    type=str,
                    default=os.path.join('~', 'output'),
                    help='Path to output folder. The folder must exist.')
parser.add_argument('--debug', action='store_true', help='debugging mode')
args = parser.parse_args()
# DEBUG verbosity when --debug is given, INFO otherwise.
logging.getLogger().setLevel(logging.DEBUG if args.debug else logging.INFO)
logging.info(args)

# convert vocabulary
vocab = tf_vocab_to_gluon_vocab(load_text_vocab(args.vocab_file))

# vocab serialization
# The JSON is first written to a scratch file so get_hash can be run on it;
# the short hash then names the final '<hash>.vocab' file, which receives the
# same JSON again.
tmp_file_path = os.path.expanduser(os.path.join(args.out_dir, 'tmp'))
with open(tmp_file_path, 'w') as f:
    f.write(vocab.to_json())
hash_full, hash_short = get_hash(tmp_file_path)
gluon_vocab_path = os.path.expanduser(
    os.path.join(args.out_dir, hash_short + '.vocab'))
with open(gluon_vocab_path, 'w') as f:
    f.write(vocab.to_json())
    logging.info('vocab file saved to %s. hash = %s', gluon_vocab_path,
                 hash_full)

# Load PyTorch Model
pytorch_parameters = torch.load(os.path.join(args.pytorch_checkpoint_dir,
# Script excerpt: save a PyTorch model and tokenizer to `dir_name`, list the
# saved parameter names and serialize the vocabulary for Gluon.
# NOTE(review): `model`, `tokenizer`, `dir_name` and `gluon_dir_name` are
# defined above this excerpt.
model.save_pretrained(dir_name)
tokenizer.save_pretrained(dir_name)

####################################################################
#                  SHOW PYTORCH PARAMETER LIST                     #
####################################################################
# Reload the checkpoint that was just written and print its sorted keys.
pytorch_parameters = torch.load(os.path.join(dir_name, 'pytorch_model.bin'))
print('parameters in pytorch')
print(sorted(list(pytorch_parameters)))

####################################################################
#                        CONVERT VOCAB                             #
####################################################################
# convert vocabulary
vocab = tf_vocab_to_gluon_vocab(
    load_text_vocab(os.path.join(dir_name, 'vocab.txt')))
# vocab serialization
# The JSON is first written to a scratch file so get_hash can be run on it;
# the short hash then names the final '<hash>.vocab' file, which receives the
# same JSON again.
tmp_file_path = os.path.expanduser(os.path.join(gluon_dir_name, 'temp'))
with open(tmp_file_path, 'w') as f:
    f.write(vocab.to_json())

hash_full, hash_short = get_hash(tmp_file_path)
gluon_vocab_path = os.path.expanduser(
    os.path.join(gluon_dir_name, hash_short + '.vocab'))
with open(gluon_vocab_path, 'w') as f:
    f.write(vocab.to_json())
    print('vocab file saved to {}. hash = {}'.format(gluon_vocab_path,
                                                     hash_full))
####################################################################
#                       CONVERT PARAMS OPTIONS                     #