def convert_tf2tf(dst_params_file_path,
                  dst_params,
                  dst_param_keys,
                  src_params,
                  src_param_keys):
    """
    Copy parameter values from a source set into destination variables and save them.

    Source and destination names are paired purely by position after both
    lists have been natural-sorted. Source names are temporarily rewritten
    before sorting and restored afterwards, presumably so that they sort in
    the same order as the destination names (confirm against the models
    being converted).

    Parameters
    ----------
    dst_params_file_path : str
        Path handed to `save_model_params` as `file_path`.
    dst_params : dict
        Destination parameters indexed by name. Each value must provide
        `get_shape().as_list()` and `assign(...)` (presumably TF variables).
    dst_param_keys : list of str
        Destination parameter names. NOTE: this list is sorted in place.
    src_params : dict
        Source parameters indexed by name. Each value must provide `shape`
        (presumably numpy arrays).
    src_param_keys : list of str
        Source parameter names. The caller's list is not modified.
    """
    import re

    def natural_key(name):
        # Digit runs become right-aligned, width-10 strings so that they
        # compare numerically; every other character is its own token.
        return ['{:10}'.format(int(tok)) if tok.isdigit() else tok
                for tok in re.findall(r'[^0-9]|[0-9]+', name)]

    def to_sortable(name):
        # Rewrite a source name into the form used only for ordering.
        name = name.replace('/W:', '/kernel:')
        name = name.replace('/b:', '/bias:')
        name = name.replace('linear/', 'output/')
        name = name.replace('stage', 'features/stage')
        name = re.sub('^conv1/', 'features/init_block/conv/', name)
        name = re.sub('^conv5/', 'features/final_block/conv/', name)
        name = name.replace('/dconv_bn/', '/dconv/bn/')
        name = name.replace('/shortcut_dconv_bn/', '/shortcut_dconv/bn/')
        return name

    def from_sortable(name):
        # Undo `to_sortable` so the name can index `src_params` again.
        restore_steps = (
            ('/kernel:', '/W:'),
            ('/bias:', '/b:'),
            ('output/', 'linear/'),
            ('features/stage', 'stage'),
            ('features/init_block/conv/', 'conv1/'),
            ('features/final_block/conv/', 'conv5/'),
            ('/dconv/bn/', '/dconv_bn/'),
            ('/shortcut_dconv/bn/', '/shortcut_dconv_bn/'),
        )
        for sortable_part, source_part in restore_steps:
            name = name.replace(sortable_part, source_part)
        return name

    ordered_src_keys = [to_sortable(name) for name in src_param_keys]

    # Plain sort first, then a stable natural sort: names whose natural keys
    # compare equal keep their plain lexicographic order.
    ordered_src_keys.sort()
    ordered_src_keys.sort(key=natural_key)
    dst_param_keys.sort()
    dst_param_keys.sort(key=natural_key)

    ordered_src_keys = [from_sortable(name) for name in ordered_src_keys]

    assert (len(ordered_src_keys) == len(dst_param_keys))

    import tensorflow as tf
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        for src_key, dst_key in zip(ordered_src_keys, dst_param_keys):
            src_value = src_params[src_key]
            dst_var = dst_params[dst_key]
            assert (src_value.shape == tuple(dst_var.get_shape().as_list()))
            sess.run(dst_var.assign(src_value))

        from tensorflow_.utils import save_model_params
        save_model_params(
            sess=sess,
            file_path=dst_params_file_path)
def convert_gl2tf(dst_params_file_path,
                  dst_params,
                  dst_param_keys,
                  src_params,
                  src_param_keys):
    """
    Copy source parameter values into TensorFlow variables and save them.

    Source and destination names are paired by position after natural
    sorting. Destination names are temporarily rewritten before sorting and
    restored afterwards, presumably so that they sort in the same order as
    the source names (confirm against the models being converted).

    Destination names containing "convgroup" are handled specially: all
    destination variables sharing the same "...convgroup" prefix and the same
    last name component are matched with ONE source parameter, which is split
    into equal parts along axis 0, one part per destination variable.

    Before assignment, 4-D values are transposed with axes (2, 3, 0, 1) when
    the destination name's last component is "dw_kernel", otherwise with axes
    (2, 3, 1, 0); 2-D values are transposed with axes (1, 0). Presumably this
    converts between the two frameworks' weight layouts - TODO confirm.

    Parameters
    ----------
    dst_params_file_path : str
        Path handed to `save_model_params` as `file_path`.
    dst_params : dict
        Destination parameters indexed by name. Each value must provide
        `get_shape().as_list()` and `assign(...)` (presumably TF variables).
    dst_param_keys : list of str
        Destination parameter names. The caller's list is not modified.
    src_params : dict
        Source parameters indexed by name. Each value must provide
        `_data[0].asnumpy()` (presumably Gluon parameters).
    src_param_keys : list of str
        Source parameter names. NOTE: this list is sorted in place.

    Raises
    ------
    AssertionError
        If the number of source and (collapsed) destination names differ, or
        if a converted value's shape differs from its destination variable.
    """
    # (fragment as it appears in destination names, fragment used for sorting).
    sort_aliases = (
        ('/kernel:', '/weight:'),
        ('/dw_kernel:', '/weight_dw:'),
        ('/post_activ/', '/stageN/post_activ/'),
        ('/final_block/', '/stageN/final_block/'),
        ('/stem1_unit/', '/stage0/stem1_unit/'),
        ('/stem2_unit/', '/stage0/stem2_unit/'),
    )

    def natural_key(name):
        # Digit runs become right-aligned, width-10 strings so that they
        # compare numerically; every other character is its own token.
        return ['{:10}'.format(int(tok)) if tok.isdigit() else tok
                for tok in re.findall(r'[^0-9]|[0-9]+', name)]

    def swap_fragments(names, pairs):
        # Apply every (old, new) replacement, in order, to every name.
        swapped = []
        for name in names:
            for old, new in pairs:
                name = name.replace(old, new)
            swapped.append(name)
        return swapped

    dst_param_keys = swap_fragments(dst_param_keys, sort_aliases)

    # Plain sort first, then a stable natural sort: names whose natural keys
    # compare equal keep their plain lexicographic order.
    src_param_keys.sort()
    src_param_keys.sort(key=natural_key)
    dst_param_keys.sort()
    dst_param_keys.sort(key=natural_key)

    dst_param_keys = swap_fragments(
        dst_param_keys, [(alias, original) for original, alias in sort_aliases])

    # Keep the full list: it is needed to find the members of each conv group.
    dst_param_keys_orig = dst_param_keys.copy()

    # Collapse every conv-group member to "<prefix>convgroup/<last component>"
    # so that a whole group occupies a single slot in the pairing below.
    dst_param_keys = [
        s[:(s.find("convgroup") + 9)] + "/" + s.split('/')[-1]
        if s.find("convgroup") >= 0 else s
        for s in dst_param_keys
    ]
    # Drop duplicates while preserving first-occurrence order: np.unique
    # returns sorted values, so they are re-ordered by first-occurrence index.
    dst_param_keys_uniq, dst_param_keys_index = np.unique(
        dst_param_keys, return_index=True)
    dst_param_keys = list(dst_param_keys_uniq[dst_param_keys_index.argsort()])

    assert len(src_param_keys) == len(dst_param_keys), \
        "Parameter count mismatch: {} source vs {} destination".format(
            len(src_param_keys), len(dst_param_keys))

    import tensorflow as tf
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        def assign_weight(dst_key, src_weight):
            # Transpose to the destination's axis order, verify the shape and
            # assign the value to the destination variable.
            if len(src_weight.shape) == 4:
                # The last name component ends with a 2-char suffix (e.g. ":0").
                if dst_key.split("/")[-1][:-2] == "dw_kernel":
                    src_weight = np.transpose(src_weight, axes=(2, 3, 0, 1))
                else:
                    src_weight = np.transpose(src_weight, axes=(2, 3, 1, 0))
            elif len(src_weight.shape) == 2:
                src_weight = np.transpose(src_weight, axes=(1, 0))
            dst_shape = tuple(dst_params[dst_key].get_shape().as_list())
            assert dst_shape == src_weight.shape, \
                "Shape mismatch for {}: destination {} vs source {}".format(
                    dst_key, dst_shape, src_weight.shape)
            sess.run(dst_params[dst_key].assign(src_weight))

        for src_key, dst_key in zip(src_param_keys, dst_param_keys):
            src_weight = src_params[src_key]._data[0].asnumpy()
            if "convgroup" not in dst_key:
                assign_weight(dst_key, src_weight)
                continue

            dst_key_stem = dst_key[:(dst_key.find("convgroup") + 9)]
            group_keys = [
                s for s in dst_param_keys_orig if s.startswith(dst_key_stem)]
            if src_key.endswith("weight"):
                group_keys = [s for s in group_keys if s.endswith("kernel:0")]
            elif src_key.endswith("bias"):
                group_keys = [s for s in group_keys if s.endswith("bias:0")]

            # np.split always returns a list (even for a single group) and
            # raises ValueError if axis 0 cannot be divided equally.
            group_weights = np.split(src_weight, len(group_keys), axis=0)
            for group_key, group_weight in zip(group_keys, group_weights):
                assign_weight(group_key, group_weight)

        from tensorflow_.utils import save_model_params
        save_model_params(
            sess=sess,
            file_path=dst_params_file_path)