示例#1
0
def convert_kerastf_2_script(model_path=None, out_format=None):
    """
    Convert a keras .model file into a script (plus weights) for another framework.

    The conversion runs mmdnn in two steps: keras -> intermediate
    representation (IR), then IR -> source code. The generated script
    (``.py``) and weights (``.npy``) are copied next to ``model_path``,
    i.e. ``<model_path without extension>.py`` / ``.npy``.

    model_path: string, path to the keras .model file. The file is first
        copied into the temp folder returned by aid_bin.create_temp_folder()
        and referenced by a path relative to the current directory.
    out_format: string indicating the target framework, e.g. "tensorflow",
        "pytorch", "caffe", "cntk", "mxnet", "onnx".

    Raises TypeError if model_path or out_format is None.

    All temporary files are removed and the TensorFlow session is closed even
    if one of the conversion steps raises.
    """
    import uuid  # local import: only needed to build a unique temp filename

    # Fail fast, before copying files or running any conversion step.
    if model_path is None or out_format is None:
        raise TypeError("model_path and out_format are required")

    temp_path = aid_bin.create_temp_folder()  # created if it does not exist yet
    # A uuid-based name avoids collisions with other temp files and does not
    # consume the global numpy random state.
    tmp_model = os.path.join(temp_path, uuid.uuid4().hex + ".model")
    # Relative path to the temp copy; the IR/code outputs share its stem.
    relpath = os.path.relpath(tmp_model, start=os.curdir)
    dst_path = os.path.splitext(relpath)[0]  # strip the .model extension
    dst_model_path = dst_path + ".py"
    dst_weight_path = dst_path + ".npy"
    ir_model_path = dst_path + ".pb"
    ir_json_path = dst_path + ".json"
    temp_files = [ir_model_path, dst_model_path, dst_weight_path, relpath,
                  ir_json_path]

    tf.reset_default_graph()  # make sure to start with a fresh graph
    sess = tf.InteractiveSession()
    try:
        shutil.copyfile(model_path, tmp_model)  # work on a copy in temp

        # Keras -> intermediate representation (IR). Arguments are passed as
        # a list (not a split string) so paths containing spaces survive.
        parser = convertToIR._get_parser()
        args = parser.parse_args([
            "--srcFramework", "keras",
            "--weights", relpath,
            "--dstPath", dst_path,
        ])
        convertToIR._convert(args)

        # IR -> code for the requested framework.
        parser = IRToCode._get_parser()
        args = parser.parse_args([
            "--dstFramework", out_format,
            "--dstModelPath", dst_model_path,
            "--IRModelPath", ir_model_path,
            "--IRWeightPath", dst_weight_path,
            "--dstWeightPath", dst_weight_path,
        ])
        IRToCode._convert(args)

        # Copy the final script and weights next to the original model file.
        out_stem = os.path.splitext(model_path)[0]
        shutil.copyfile(dst_model_path, out_stem + ".py")  # .model -> .py
        shutil.copyfile(dst_weight_path, out_stem + ".npy")  # .model -> .npy
    finally:
        # Best-effort cleanup of every temp file, also after a failure.
        for path in temp_files:
            try:
                os.remove(path)
                print("temp. file deleted: " + path)
            except OSError:
                print("temp. file not found/ not deleted: " + path)
        sess.close()
示例#2
0
def _main():
    """Command-line entry point: convert a source model in one go.

    Parses the command line and converts the source model into mmdnn's IR
    under a random (uuid4 hex) temporary file stem. Then, for every target
    except CoreML, it generates code for ``args.dstFramework`` and dumps it
    to ``args.outputModel``. For CoreML it builds the model directly and
    exits with the converter's return code. Any non-zero converter return
    code terminates the process with that code.
    """
    parser = _get_parser()
    args, unknown_args = parser.parse_known_args()
    temp_filename = uuid.uuid4().hex

    ir_args, unknown_args = _extract_ir_args(args, unknown_args, temp_filename)
    ir_status = int(convertToIR._convert(ir_args))
    if ir_status != 0:
        _sys.exit(ir_status)

    # NOTE: temporary files are currently left on disk on both paths
    # (the remove_temp_files(temp_filename) calls are disabled).
    if args.dstFramework == 'coreml':
        model_args, _ = _extract_model_args(args, unknown_args, temp_filename)
        _sys.exit(int(IRToModel._convert(model_args)))
    else:
        network_filename = get_network_filename(
            args.dstFramework, temp_filename, args.outputModel)
        code_args, _ = _extract_code_args(
            args, unknown_args, temp_filename, network_filename)
        code_status = int(IRToCode._convert(code_args))
        if code_status != 0:
            _sys.exit(code_status)

        from mmdnn.conversion._script.dump_code import dump_code
        script_path = network_filename + '.py'
        weight_path = temp_filename + '.npy'
        dump_code(args.dstFramework, script_path, weight_path,
                  args.outputModel, args.dump_tag)
示例#3
0
def _extract_code_args(args, unknown_args, temp_filename, network_filename):
    """Parse the arguments for the IR-to-code conversion step.

    Builds an argument list from the leftover command-line arguments
    (``unknown_args``) followed by paths derived from ``temp_filename``
    (IR graph/weights) and ``network_filename`` (generated script). It then
    parses that list with ``IRToCode``'s own parser.

    Returns the ``(namespace, remaining_args)`` pair produced by
    ``parse_known_args``. ``unknown_args`` itself is left unmodified.
    """
    # Derived options come after the user's leftovers, as before; with
    # argparse's default 'store' action the last occurrence of an option wins.
    code_argv = list(unknown_args) + [
        '--dstFramework', args.dstFramework,
        '--IRModelPath', temp_filename + '.pb',
        '--IRWeightPath', temp_filename + '.npy',
        '--dstModelPath', network_filename + '.py',
        # dstWeightPath deliberately reuses the IR weight file path.
        '--dstWeightPath', temp_filename + '.npy',
    ]
    code_parser = IRToCode._get_parser()
    return code_parser.parse_known_args(code_argv)
示例#4
0
def _extract_code_args(args, unknown_args, temp_filename):
    """Parse the arguments for the IR-to-code conversion step.

    Builds an argument list from the leftover command-line arguments
    (``unknown_args``) followed by paths derived from ``temp_filename``: IR
    graph (``.pb``), IR weights (``.npy``) and generated script (``.py``).
    It then parses that list with ``IRToCode``'s own parser.

    Returns the ``(namespace, remaining_args)`` pair produced by
    ``parse_known_args``. ``unknown_args`` itself is left unmodified.
    """
    # Derived options come after the user's leftovers, as before; with
    # argparse's default 'store' action the last occurrence of an option wins.
    code_argv = list(unknown_args) + [
        '--dstFramework', args.dstFramework,
        '--IRModelPath', temp_filename + '.pb',
        '--IRWeightPath', temp_filename + '.npy',
        '--dstModelPath', temp_filename + '.py',
        # dstWeightPath deliberately reuses the IR weight file path.
        '--dstWeightPath', temp_filename + '.npy',
    ]
    code_parser = IRToCode._get_parser()
    return code_parser.parse_known_args(code_argv)
示例#5
0
def _main():
    """Command-line entry point: convert a source model in one go.

    Parses the command line and converts the source model into mmdnn's IR
    under a random (uuid4 hex) temporary file stem. Then, for every target
    except CoreML, it generates code for ``args.dstFramework``, dumps it to
    ``args.outputModel`` and removes the temp files. For CoreML with
    ``dstType == 'model'`` it builds the model directly, removes the temp
    files and exits with the converter's return code. A non-zero return code
    from the IR or code step terminates the process with that code.
    """
    parser = _get_parser()
    args, unknown_args = parser.parse_known_args()
    temp_filename = uuid.uuid4().hex

    ir_args, unknown_args = _extract_ir_args(args, unknown_args, temp_filename)
    ir_status = int(convertToIR._convert(ir_args))
    if ir_status != 0:
        _sys.exit(ir_status)

    if args.dstFramework == 'coreml':
        # NOTE(review): coreml with any dstType other than 'model' does
        # nothing here and leaves the temp files behind -- confirm intended.
        if args.dstType == 'model':
            model_args, _ = _extract_model_args(args, unknown_args, temp_filename)
            model_status = IRToModel._convert(model_args)
            remove_temp_files(temp_filename)
            _sys.exit(int(model_status))
    else:
        code_args, _ = _extract_code_args(args, unknown_args, temp_filename)
        code_status = int(IRToCode._convert(code_args))
        if code_status != 0:
            _sys.exit(code_status)

        from mmdnn.conversion._script.dump_code import dump_code
        script_path = temp_filename + '.py'
        weight_path = temp_filename + '.npy'
        dump_code(args.dstFramework, script_path, weight_path, args.outputModel)
        remove_temp_files(temp_filename)