Example #1
def pytorch_to_nnabla(input_file, h5_file):
    # Load the PyTorch state dict saved in the input file
    read = torch.load(input_file)
    for k, v in read.items():
        # Map the PyTorch parameter name to the corresponding NNabla name
        key = rename_params(k)
        # Register an NNabla parameter of the same shape and copy the tensor data
        params = PF.get_parameter_or_create(key, shape=v.shape)
        params.d = v.numpy()
    # Write all registered parameters to the HDF5 file
    nn.parameter.save_parameters(h5_file)
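A minimal usage sketch for the converter above (assuming torch, nnabla as nn, nnabla.parametric_functions as PF and the repository's rename_params helper are importable; file names are placeholders): convert the state dict and check that the resulting HDF5 file loads back into NNabla.

import nnabla as nn

# Convert a PyTorch .pth file and reload it to verify the result.
pytorch_to_nnabla('model.pth', 'model.h5')   # placeholder file names
nn.clear_parameters()
nn.parameter.load_parameters('model.h5')
for name, param in nn.get_parameters(grad_only=False).items():
    print(name, param.shape)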
Example #2
def parse_tf_ckpt(ckpt_file, h5_file):
    '''Parse the TF checkpoint file and save the tensors as NNabla parameters.
    '''
    # Get tensorflow checkpoint reader
    reader = pywrap_tensorflow.NewCheckpointReader(ckpt_file)
    var_to_shape_map = reader.get_variable_to_shape_map()

    # Loop through each tensor name from the variable to shape map
    for key in sorted(var_to_shape_map):

        # Read tensor values for each tensor name
        weight = reader.get_tensor(key)
        if 'depthwise' in key and weight.ndim == 4:
            # Depthwise conv: TF stores (H, W, C, 1); NNabla expects (C, H, W)
            weight = numpy.squeeze(weight, axis=3)
            weight = numpy.transpose(weight, (2, 0, 1))
        elif weight.ndim == 4:
            # Regular conv: TF stores (H, W, Cin, Cout); NNabla expects (Cout, Cin, H, W)
            weight = numpy.transpose(weight, (3, 2, 0, 1))

        if 'BatchNorm' in key:
            # Batch-norm parameters become (1, C, 1, 1) in NNabla
            weight = weight.reshape((1, -1, 1, 1))

        # Skip optimizer state that is not part of the network weights
        if 'Momentum' in key or 'ExponentialMovingAverage' in key:
            continue

        key = rename_params(key)

        # Create parameter with the same tensor name and shape
        params = pf.get_parameter_or_create(key, shape=weight.shape)
        params.d = weight

    # Save to a h5 file
    nn.parameter.save_parameters(h5_file)
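The convolution-kernel transposes used above can be sanity-checked on dummy arrays; a quick NumPy-only sketch (shapes are arbitrary examples):

import numpy

# Regular conv: TF stores (H, W, Cin, Cout); NNabla expects (Cout, Cin, H, W).
tf_kernel = numpy.zeros((3, 3, 64, 128))
print(numpy.transpose(tf_kernel, (3, 2, 0, 1)).shape)        # (128, 64, 3, 3)

# Depthwise conv: TF stores (H, W, C, 1); NNabla expects (C, H, W).
tf_dw_kernel = numpy.zeros((3, 3, 64, 1))
print(numpy.transpose(numpy.squeeze(tf_dw_kernel, axis=3), (2, 0, 1)).shape)  # (64, 3, 3)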
Example #3
def convert_ckpt_to_h5(input_file, h5_file):
    """
    Convert the input checkpoint file to output hdf5 file
    """
    # Get tensorflow checkpoint reader
    reader = pywrap_tensorflow.NewCheckpointReader(input_file)
    var_to_shape_map = reader.get_variable_to_shape_map()

    # Loop through each tensor name from the variable to shape map
    for key in var_to_shape_map:
        # Read tensor values for each tensor name
        weight = reader.get_tensor(key)
        # Skip tensors that are not network weights (dataset mean RGB, global step)
        if not key.startswith(("vgg_19/mean_rgb", "global_step")):
            s = key.split('/')
            if s[-2].startswith("fc"):
                # Fully connected layers are renamed as affine parameters
                key = rename_params('/'.join(s), affine=True)
            else:
                # Convolution layers drop the block-level scope before renaming
                s.remove(s[1])
                key = rename_params('/'.join(s), affine=False)
            if weight.ndim == 4:
                # transpose TF weight to NNabla weight format
                weight = np.transpose(weight, (3, 0, 1, 2))

            # Create parameter with the same tensor name and shape
            params = PF.get_parameter_or_create(key, shape=weight.shape)
            params.d = weight

    # Save to a h5 file
    nn.parameter.save_parameters(h5_file)
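Before converting, it can help to list what a checkpoint actually contains. A small inspection sketch using the same checkpoint-reader API (written against tf.compat.v1, mirroring the version handling in the BERT example further down; 'model.ckpt' is a placeholder path):

import tensorflow as tf

def list_ckpt_variables(ckpt_file):
    # Print every tensor name and shape stored in the checkpoint.
    reader = tf.compat.v1.train.NewCheckpointReader(ckpt_file)
    for name, shape in sorted(reader.get_variable_to_shape_map().items()):
        print(name, shape)

list_ckpt_variables('model.ckpt')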
Example #4
def pytorch_to_nnabla(input_file, h5_file):
    read = torch.load(input_file)
    for k, v in read.items():
        if not str(k).startswith("classifier"):
            # Non-classifier layers are renamed as convolution parameters
            key = rename_params(str(k), affine=False)
            key = key.replace("/", "", 1)
        else:
            # Classifier layers are renamed as affine parameters
            key = rename_params(str(k), affine=True)
        params = pf.get_parameter_or_create(key, shape=v.shape)
        params.d = v.numpy()
    nn.parameter.save_parameters(h5_file)
Example #5
def pytorch_to_nnabla(input_file, h5_file):
    read = torch.load(input_file)
    for k, v in read.items():
        split = k.split('.')[-2]
        if split.startswith('bn') or split.startswith('0'):
            # Batch-norm style parameters are 1-D; reshape to NNabla's (1, C, 1, 1)
            key = rename_params(k, conv=False)
            v = v.reshape((1, ) + v.shape + (1, 1))
        else:
            key = rename_params(k, conv=True)
        params = PF.get_parameter_or_create(key, shape=v.shape)
        params.d = v.cpu().numpy()
    nn.parameter.save_parameters(h5_file)
Example #6
def pytorch_to_nnabla(input_file, h5_file):
    read = torch.load(input_file)
    for k, v in read['state_dict'].items():
        # Strip the 'module.' prefix added by torch.nn.DataParallel
        k = k.replace('module.', '')
        split = k.split('.')[-2]
        if split.startswith('bn') or split.startswith('1'):
            # Batch-norm style parameters are 1-D; reshape to NNabla's (1, C, 1, 1)
            key = rename_params(k, conv=False)
            v = v.reshape((1, ) + v.shape + (1, 1))
        else:
            if k == 'fc.weight':
                # PyTorch Linear weights are (out, in); NNabla affine expects (in, out)
                v = v.T
            key = rename_params(k, conv=True)
        params = PF.get_parameter_or_create(key, shape=v.shape)
        params.d = v.cpu().numpy()
    nn.parameter.save_parameters(h5_file)
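The two layout adjustments made by these PyTorch converters can be illustrated on dummy tensors; a short sketch using plain PyTorch only (no repository code assumed):

import torch

# Batch-norm parameters are 1-D in PyTorch; NNabla expects (1, C, 1, 1).
bn_gamma = torch.ones(64)
print(bn_gamma.reshape((1, ) + bn_gamma.shape + (1, 1)).shape)  # torch.Size([1, 64, 1, 1])

# torch.nn.Linear stores weights as (out_features, in_features);
# NNabla's affine expects (in_features, out_features), hence the transpose.
fc_weight = torch.zeros(1000, 512)
print(fc_weight.T.shape)                                        # torch.Size([512, 1000])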
Example #7
def convert(ckpt_file, h5_file):
    """
    Convert BERT TensorFlow weights to NNabla

    Args:
        ckpt_file: Input TensorFlow ckpt file
        h5_file: Output NNabla h5 file

    """
    # Check the TensorFlow version for compatibility
    if int(tf.__version__[0]) == 2:
        reader = tf.compat.v1.train.NewCheckpointReader(ckpt_file)
    else:
        reader = pywrap_tensorflow.NewCheckpointReader(ckpt_file)
    var_to_shape_map = reader.get_variable_to_shape_map()

    for key in sorted(var_to_shape_map):
        weight = reader.get_tensor(key)
        if 'encoder' in key:
            # Keys look like 'bert/encoder/layer_<n>/...'; extract the layer index
            layer_id = int(key.split('/')[2].replace('layer_', ''))
            if 'query/bias' in key:
                key = 'encoder{:02d}/transformer_encode/src_self_attn/multi_head_attention/q_bias'.format(
                    layer_id)
            if 'query/kernel' in key:
                key = 'encoder{:02d}/transformer_encode/src_self_attn/multi_head_attention/q_weight'.format(
                    layer_id)
            if 'key/bias' in key:
                key = 'encoder{:02d}/transformer_encode/src_self_attn/multi_head_attention/k_bias'.format(
                    layer_id)
            if 'key/kernel' in key:
                key = 'encoder{:02d}/transformer_encode/src_self_attn/multi_head_attention/k_weight'.format(
                    layer_id)
            if 'value/bias' in key:
                key = 'encoder{:02d}/transformer_encode/src_self_attn/multi_head_attention/v_bias'.format(
                    layer_id)
            if 'value/kernel' in key:
                key = 'encoder{:02d}/transformer_encode/src_self_attn/multi_head_attention/v_weight'.format(
                    layer_id)
            if 'attention/output/LayerNorm/beta' in key:
                key = 'encoder{:02d}/transformer_encode/enc_layer_norm1/layer_normalization/beta'.format(
                    layer_id)
                # 768 is the hidden size of BERT-base
                weight = numpy.reshape(weight, (1, 1, 768))
            elif 'output/LayerNorm/beta' in key:
                key = 'encoder{:02d}/transformer_encode/enc_layer_norm2/layer_normalization/beta'.format(
                    layer_id)
                weight = numpy.reshape(weight, (1, 1, 768))
            if 'attention/output/LayerNorm/gamma' in key:
                key = 'encoder{:02d}/transformer_encode/enc_layer_norm1/layer_normalization/gamma'.format(
                    layer_id)
                weight = numpy.reshape(weight, (1, 1, 768))
            elif 'output/LayerNorm/gamma' in key:
                key = 'encoder{:02d}/transformer_encode/enc_layer_norm2/layer_normalization/gamma'.format(
                    layer_id)
                weight = numpy.reshape(weight, (1, 1, 768))
            if 'attention/output/dense/bias' in key:
                key = 'encoder{:02d}/transformer_encode/src_self_attn/multi_head_attention/out_bias'.format(
                    layer_id)
            elif 'output/dense/bias' in key:
                key = 'encoder{:02d}/transformer_encode/enc_affine2/affine/b'.format(
                    layer_id)
            if 'intermediate/dense/bias' in key:
                key = 'encoder{:02d}/transformer_encode/enc_affine1/affine/b'.format(
                    layer_id)
            if 'attention/output/dense/kernel' in key:
                key = 'encoder{:02d}/transformer_encode/src_self_attn/multi_head_attention/out_weight'.format(
                    layer_id)
            elif 'output/dense/kernel' in key:
                key = 'encoder{:02d}/transformer_encode/enc_affine2/affine/W'.format(
                    layer_id)
            if 'intermediate/dense/kernel' in key:
                key = 'encoder{:02d}/transformer_encode/enc_affine1/affine/W'.format(
                    layer_id)
        if 'embeddings/LayerNorm/' in key:
            key = key.replace('bert/embeddings/LayerNorm',
                              'embed/layer_normalization')
            weight = numpy.reshape(weight, (1, 1, 768))
        if 'word_embeddings' in key:
            key = 'word_embeddings/embed/W'
        if 'token_type_embeddings' in key:
            key = 'token_type_embeddings/embed/W'
        if 'position_embeddings' in key:
            key = 'position_embeddings/embed/W'
        if 'pooler/dense/bias' in key:
            key = 'pooler/affine/b'
        if 'pooler/dense/kernel' in key:
            key = 'pooler/affine/W'
        if 'seq_relationship/output_weights' in key:
            key = 'affine_seq_class/affine/W'
            weight = numpy.transpose(weight)
        if 'seq_relationship/output_bias' in key:
            key = 'affine_seq_class/affine/b'

        params = PF.get_parameter_or_create(key, shape=weight.shape)
        params.d = weight

    nn.parameter.save_parameters(h5_file)
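After conversion, the HDF5 file can be checked against the expected NNabla parameter names; a small verification sketch ('bert.h5' is a placeholder file name):

import nnabla as nn

nn.clear_parameters()
nn.parameter.load_parameters('bert.h5')

# Show the converted parameters of the first transformer encoder layer.
for name, param in nn.get_parameters(grad_only=False).items():
    if name.startswith('encoder00/'):
        print(name, param.shape)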