Exemple #1
0
def convert_tf_checkpoint_to_pytorch(
        tf_checkpoint_path='weights/biobert_v1.1_pubmed/model.ckpt-1000000',
        bert_config_file='weights/biobert_v1.1_pubmed/bert_config.json',
        pytorch_dump_path='weights/pytorch_weight'):
    """Copy the variables of a TensorFlow BERT checkpoint into a PyTorch model and save it.

    Every checkpoint variable whose name does not contain an optimizer marker
    is loaded, matched to an attribute path on a freshly built
    ``BertForPreTraining`` model, and written into that parameter. The
    resulting ``state_dict`` is saved with ``torch.save``.

    Args:
        tf_checkpoint_path: Checkpoint path passed to ``tf.train.list_variables``
            and ``tf.train.load_variable``.
        bert_config_file: JSON file read by ``BertConfig.from_json_file``.
        pytorch_dump_path: Destination passed to ``torch.save``.

    Raises:
        AssertionError: If a checkpoint array and its target parameter differ
            in shape; ``args`` holds ``(parameter_shape, array_shape)``.
    """
    # gave error originally. Solution found at: https://github.com/tensorflow/models/issues/2676
    excluded = ('BERTAdam', '_power', 'global_step')
    init_vars = [
        var for var in tf.train.list_variables(tf_checkpoint_path)
        if not any(marker in var[0] for marker in excluded)
    ]
    print(init_vars)

    # Load every array up front so all "Loading" messages precede model construction.
    weights = []
    for name, shape in init_vars:
        print("Loading TF weights {} with shape {}".format(name, shape))
        weights.append((name, tf.train.load_variable(tf_checkpoint_path, name)))

    config = BertConfig.from_json_file(bert_config_file)
    print('Building Pytorch model from configuration {}'.format(str(config)))
    model = BertForPreTraining(config)

    # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
    # which are not required for using the pretrained model.
    optimizer_scopes = ("adam_v", "adam_m", "global_step")
    # Scope names such as "layer_3" address element 3 of the attribute "layer".
    indexed_scope = re.compile(r'([A-Za-z]+)_(\d+)')
    # TensorFlow variable names that map onto differently named PyTorch attributes.
    attr_aliases = {
        'kernel': 'weight',
        'gamma': 'weight',
        'output_weights': 'weight',
        'output_bias': 'bias',
        'beta': 'bias',
    }

    for tf_name, array in weights:
        scopes = tf_name.split('/')
        if any(scope in optimizer_scopes for scope in scopes):
            print("Skipping {}".format(tf_name))
            continue

        # Walk the attribute path from the model root down to the target parameter.
        pointer = model
        for scope in scopes:
            indexed = indexed_scope.fullmatch(scope)
            attr = indexed.group(1) if indexed else scope
            pointer = getattr(pointer, attr_aliases.get(attr, attr))
            if indexed:
                pointer = pointer[int(indexed.group(2))]

        leaf = scopes[-1]
        if leaf.endswith('_embeddings'):
            pointer = getattr(pointer, 'weight')
        elif leaf == 'kernel':
            array = np.transpose(array)

        # Raised explicitly (not via `assert`) so the check survives `python -O`.
        if pointer.shape != array.shape:
            raise AssertionError(pointer.shape, array.shape)
        print("Initialize PyTorch weight {}".format(scopes))
        pointer.data = torch.from_numpy(array)

    # Save pytorch-model
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)
Exemple #2
0
def load_BFTC_from_TF_ckpt(bert_config, ckpt_path, num_labels):
    """
    Build a ``BertForTokenClassification`` model from a TensorFlow checkpoint.

    The checkpoint is first loaded into a ``BertForPreTraining`` model through
    ``load_tf_weights_in_bert``; that model's state dict is then copied,
    module by module, into a new token-classification model.

    Args:
        bert_config: JSON file read by ``BertConfig.from_json_file``.
        ckpt_path: Checkpoint path handed to ``load_tf_weights_in_bert``.
        num_labels: Forwarded to ``BertForTokenClassification``.

    Returns:
        The ``BertForTokenClassification`` instance after loading.

    Raises:
        RuntimeError: If any module reported an error message while loading.
    """
    config = BertConfig.from_json_file(bert_config)
    pretraining_model = BertForPreTraining(config)
    load_tf_weights_in_bert(pretraining_model, ckpt_path)
    state_dict = pretraining_model.state_dict()
    model = BertForTokenClassification(config, num_labels=num_labels)

    # Collect the renames first so the dict is not mutated while iterating.
    # A key containing 'beta' is renamed on 'beta' only, even if it also
    # contains 'gamma'.
    renames = []
    for key in state_dict.keys():
        if 'beta' in key:
            renames.append((key, key.replace('beta', 'bias')))
        elif 'gamma' in key:
            renames.append((key, key.replace('gamma', 'weight')))
    for stale_key, fresh_key in renames:
        state_dict[fresh_key] = state_dict.pop(stale_key)

    missing_keys, unexpected_keys, error_msgs = [], [], []

    # Work on a copy so _load_from_state_dict can modify it; carry over
    # the _metadata attribute when the source dict has one.
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    def _copy_into(module, prefix=''):
        # Load this module's own entries, then recurse into its children.
        if metadata is None:
            local_metadata = {}
        else:
            local_metadata = metadata.get(prefix[:-1], {})
        module._load_from_state_dict(
            state_dict, prefix, local_metadata, True,
            missing_keys, unexpected_keys, error_msgs)
        for child_name, child in module._modules.items():
            if child is None:
                continue
            _copy_into(child, prefix + child_name + '.')

    # Start under 'bert.' only when the keys carry that prefix and the model
    # has no `bert` attribute of its own.
    needs_bert_prefix = (
        not hasattr(model, 'bert')
        and any(key.startswith('bert.') for key in state_dict)
    )
    _copy_into(model, prefix='bert.' if needs_bert_prefix else '')

    model_name = model.__class__.__name__
    if missing_keys:
        print("Weights of {} not initialized from pretrained model: {}".format(
            model_name, missing_keys))
    if unexpected_keys:
        print("Weights from pretrained model not used in {}: {}".format(
            model_name, unexpected_keys))
    if error_msgs:
        raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                           model_name, "\n\t".join(error_msgs)))
    return model
Exemple #3
0
def _convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
    """Load a TensorFlow BERT checkpoint into a PyTorch model and save its state dict.

    Adapted from
    https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py#L30

    Args:
        tf_checkpoint_path: Checkpoint path handed to ``_load_tf_weights_in_bert``.
        bert_config_file: JSON file read by ``BertConfig.from_json_file``.
        pytorch_dump_path: Destination passed to ``torch.save``.
    """
    # Build the PyTorch model described by the JSON configuration.
    model = BertForPreTraining(BertConfig.from_json_file(bert_config_file))

    # Fill the model with the weights stored in the TF checkpoint.
    _load_tf_weights_in_bert(model, tf_checkpoint_path)

    # Only the state dict is written out, not the module object itself.
    weights = model.state_dict()
    torch.save(weights, pytorch_dump_path)
Exemple #4
0
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file,
                                     pytorch_dump_path):
    """Load a TensorFlow BERT checkpoint into a PyTorch model and save its state dict.

    Progress is reported on stdout before the model is built and before the
    state dict is written.

    Args:
        tf_checkpoint_path: Checkpoint path handed to ``load_tf_weights_in_bert``.
        bert_config_file: JSON file read by ``BertConfig.from_json_file``.
        pytorch_dump_path: Destination passed to ``torch.save``.
    """
    # Build the PyTorch model described by the JSON configuration.
    bert_config = BertConfig.from_json_file(bert_config_file)
    building_msg = "Building PyTorch model from configuration: {}".format(str(bert_config))
    print(building_msg)
    bert_model = BertForPreTraining(bert_config)

    # Fill the model with the weights stored in the TF checkpoint.
    load_tf_weights_in_bert(bert_model, tf_checkpoint_path)

    # Write out the state dict only.
    saving_msg = "Save PyTorch model to {}".format(pytorch_dump_path)
    print(saving_msg)
    torch.save(bert_model.state_dict(), pytorch_dump_path)
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file,
                                     pytorch_dump_path):
    """Copy the variables of a TensorFlow BERT checkpoint into a PyTorch model and save it.

    Every checkpoint variable whose name does not contain an optimizer marker
    is loaded, matched to an attribute path on a freshly built
    ``BertForPreTraining`` model, and written into that parameter. The
    resulting ``state_dict`` is saved with ``torch.save``.

    Args:
        tf_checkpoint_path: Checkpoint path; made absolute before being passed
            to ``tf.train.list_variables`` and ``tf.train.load_variable``.
        bert_config_file: JSON file read by ``BertConfig.from_json_file``.
        pytorch_dump_path: Destination passed to ``torch.save``.

    Raises:
        AssertionError: If a checkpoint array and its target parameter differ
            in shape; ``args`` holds ``(parameter_shape, array_shape)``.
    """
    config_path = os.path.abspath(bert_config_file)
    tf_path = os.path.abspath(tf_checkpoint_path)
    print("Converting TensorFlow checkpoint from {} with config at {}".format(
        tf_path, config_path))

    # Load weights from TF model, leaving out optimizer bookkeeping variables.
    excluded = ('BERTAdam', '_power', 'global_step')
    init_vars = [
        var for var in tf.train.list_variables(tf_path)
        if not any(marker in var[0] for marker in excluded)
    ]
    weights = []
    for name, shape in init_vars:
        print("Loading TF weight {} with shape {}".format(name, shape))
        weights.append((name, tf.train.load_variable(tf_path, name)))

    # Initialise PyTorch model
    config = BertConfig.from_json_file(bert_config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))
    model = BertForPreTraining(config)

    # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
    # which are not required for using the pretrained model.
    optimizer_scopes = ("adam_v", "adam_m", "global_step")
    # Scope names such as "layer_3" address element 3 of the attribute "layer".
    indexed_scope = re.compile(r'([A-Za-z]+)_(\d+)')
    # TensorFlow variable names that map onto differently named PyTorch attributes.
    attr_aliases = {
        'kernel': 'weight',
        'gamma': 'weight',
        'output_weights': 'weight',
        'output_bias': 'bias',
        'beta': 'bias',
    }

    for tf_name, array in weights:
        scopes = tf_name.split('/')
        if any(scope in optimizer_scopes for scope in scopes):
            print("Skipping {}".format(tf_name))
            continue

        # Walk the attribute path from the model root down to the target parameter.
        pointer = model
        for scope in scopes:
            indexed = indexed_scope.fullmatch(scope)
            attr = indexed.group(1) if indexed else scope
            pointer = getattr(pointer, attr_aliases.get(attr, attr))
            if indexed:
                pointer = pointer[int(indexed.group(2))]

        leaf = scopes[-1]
        if leaf.endswith('_embeddings'):
            pointer = getattr(pointer, 'weight')
        elif leaf == 'kernel':
            array = np.transpose(array)

        # Raised explicitly (not via `assert`) so the check survives `python -O`.
        if pointer.shape != array.shape:
            raise AssertionError(pointer.shape, array.shape)
        print("Initialize PyTorch weight {}".format(scopes))
        pointer.data = torch.from_numpy(array)

    # Save pytorch-model. Saving is the final step: nothing may touch
    # `pointer` afterwards, because it then refers to the last assigned
    # parameter rather than the model root.
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)

# NOTE(review): these two statements run at module import time and reference
# `model`, which is not defined at module scope anywhere in the visible
# source -- confirm where it is supposed to come from before relying on them.
# The log message names 'weights/' while the state dict is actually written
# to 'biobert_pubmed/pytorch_model.bin'.
print("Save PyTorch model to {}".format('weights/'))
torch.save(model.state_dict(),'biobert_pubmed/pytorch_model.bin')