Esempio n. 1
0
def ast_if(node):
    """Handle an `if` statement AST node.

    When the test is the ``__name__ == "__main__"`` guard, record that the
    file has a main entry point and, depending on the global conversion
    flags, wrap the guard body with keras-session and/or NPU-resource
    setup/teardown calls.

    Returns the (possibly modified) node for the handled guard.
    """
    if isinstance(node.test, ast.Compare):
        # Only a single-comparator compare against a string literal matters here.
        if len(node.test.comparators) == 1 and isinstance(
                node.test.comparators[0], ast.Str):
            if node.test.comparators[0].s == "__main__":
                util_global.set_value("is_main_file", False)
                util_global.set_value("has_main_func", True)
                if util_global.get_value("is_keras_net", False):
                    log_msg(getattr(node, "lineno", "None"),
                            " add keras session npu config")
                    # close_session(npu_keras_sess) — appended after the guard body.
                    close_sess_call = ast.Call(
                        func=ast.Name(id="close_session", ctx=ast.Load()),
                        args=[ast.Name(id="npu_keras_sess", ctx=ast.Load())],
                        keywords=[])
                    # npu_keras_sess = set_keras_session_npu_config() — prepended.
                    keras_sess_assign = ast.Assign(
                        targets=[
                            ast.Name(id="npu_keras_sess", ctx=ast.Store())
                        ],
                        value=ast.Call(func=ast.Name(
                            id="set_keras_session_npu_config", ctx=ast.Load()),
                                       args=[],
                                       keywords=[]))
                    node.body = [keras_sess_assign] + node.body + [
                        ast.Expr(value=close_sess_call)
                    ]
                    util_global.set_value('need_conver', True)
                if util_global.get_value("has_hccl_api", False):
                    log_msg(getattr(node, "lineno", "None"),
                            " add npu resource init api")
                    # close_session(npu_sess) — appended last.
                    close_sess_call = ast.Call(
                        func=ast.Name(id="close_session", ctx=ast.Load()),
                        args=[ast.Name(id="npu_sess", ctx=ast.Load())],
                        keywords=[])
                    # (npu_sess, npu_shutdown) = init_resource() — prepended.
                    init_assign = ast.Assign(targets=[
                        ast.Tuple(elts=[
                            ast.Name(id="npu_sess", ctx=ast.Store()),
                            ast.Name(id="npu_shutdown", ctx=ast.Store())
                        ],
                                  ctx=ast.Store())
                    ],
                                             value=ast.Call(func=ast.Name(
                                                 id="init_resource",
                                                 ctx=ast.Load()),
                                                            args=[],
                                                            keywords=[]))
                    # shutdown_resource(npu_sess, npu_shutdown) — appended before close.
                    shutdown_call = ast.Call(func=ast.Name(
                        id="shutdown_resource", ctx=ast.Load()),
                                             args=[
                                                 ast.Name(id="npu_sess",
                                                          ctx=ast.Load()),
                                                 ast.Name(id="npu_shutdown",
                                                          ctx=ast.Load())
                                             ],
                                             keywords=[])
                    node.body = [init_assign] + node.body + [
                        ast.Expr(value=shutdown_call),
                        ast.Expr(value=close_sess_call)
                    ]
                    util_global.set_value('need_conver', True)
                return node
Esempio n. 2
0
def log_migration_report(lineno, msg):
    """Print and persist a manual-migration hint for the API named by *msg*."""
    content = (
        f"{util_global.get_value('path', '')}:{lineno} \"{msg}\""
        '" feature needs to be migrated manually, Please refer to the migration guide: '
        .replace('" f', ' f')  # placeholder removed below
    )
    # Rebuild exactly as the original concatenation does.
    content = (
        util_global.get_value('path', '') + ':' + str(lineno) + ' "' + msg +
        '" feature needs to be migrated manually, Please refer to the migration guide: '
        + util_global.get_value(msg)[0])
    print(content)
    write_conver_report(content, util_global.get_value('report_file')[2])
Esempio n. 3
0
def scan_file(path, file_name, api, lineno):
    """Scan one script's TF API usage and merge the rows into the directory report.

    Args:
        path: directory containing the script.
        file_name: script file name.
        api: list of TF API names found in the script.
        lineno: line numbers parallel to *api* (rebound for each later scan).

    Side effect: updates the 'generate_dir_report' DataFrame in util_global.
    """
    api_list = pd.read_excel(util_global.get_value('list'), sheet_name=0)
    api_module = api_list['模块名'].values.tolist()
    api_name = api_list['API名'].values.tolist()
    api_support = api_list['工具迁移API支持度'].values.tolist()
    api_advice = api_list['说明'].values.tolist()

    script_name = []
    code_line = []
    code_module = []
    code_api = []
    support_type = []
    migrate_advice = []

    def _record(name, line):
        # Append one matched-API row; look the index up once
        # (the original called api_name.index(name) three times per hit).
        idx = api_name.index(name)
        script_name.append(file_name)
        code_api.append(name)
        code_line.append(line)
        code_module.append(api_module[idx])
        support_type.append(api_support[idx])
        migrate_advice.append(api_advice[idx])

    for i, name in enumerate(api):
        if name in api_name:
            _record(name, lineno[i])

    # search for tf enumeration
    enume_list = pd.read_excel(util_global.get_value('list'), sheet_name=1)
    enume_name = enume_list['API名'].values.tolist()
    (enume, lineno) = get_tf_enume(os.path.join(path, file_name), enume_name)

    for i, name in enumerate(enume):
        # An enum like a.b.C maps to its owning class a.b; skip if either
        # form was already recorded.
        class_name = '.'.join(name.split('.')[:-1])
        if name not in code_api and class_name not in code_api:
            if class_name in api_name:
                _record(class_name, lineno[i])

    # record unsupported api
    (unsupport, unsupport_module, lineno) = get_unsupport_api(os.path.join(path, file_name))
    for i in range(len(unsupport)):
        script_name.append(file_name)
        code_api.append(unsupport[i])
        code_line.append(lineno[i])
        code_module.append(unsupport_module[i])
        support_type.append('不支持(无迁移方案,建议用户不使用)')
        migrate_advice.append('第三方非TF官网API,暂不支持')

    analyse_result = pd.DataFrame({'脚本文件名': script_name, '代码行': code_line,
                                   '模块名': code_module, 'API名': code_api,
                                   '工具迁移API支持度': support_type, '说明': migrate_advice})

    # when there are tf apis used in script, analysis report will be generated
    report = util_global.get_value('generate_dir_report')
    if len(script_name):
        # DataFrame.append was removed in pandas 2.0; pd.concat is the
        # documented equivalent and produces the same frame.
        report = pd.concat([report, analyse_result])
        util_global.set_value('generate_dir_report', report)
Esempio n. 4
0
def before_clear():
    """Remove any previously generated output and report directories."""
    for key in ('output', 'report'):
        target = util_global.get_value(key)
        if os.path.exists(target):
            shutil.rmtree(target)
Esempio n. 5
0
def conver():
    """The entry point to convert Tensorflow script"""
    print("Begin conver, input file: " + util_global.get_value('input') + '\n')
    out_path = util_global.get_value('output')
    # Destination root is <input basename> + timestamp so repeated runs don't collide.
    dst_path = os.path.split(util_global.get_value('input').rstrip('\\/'))[-1]
    dst_path_new = dst_path + util_global.get_value('timestap')
    conver_path = os.walk(util_global.get_value('input'))
    report_dir = util_global.get_value('report')
    mkdir(report_dir)
    report_xlsx = os.path.join(report_dir, 'api_analysis_report.xlsx')
    util_global.set_value('generate_dir_report', pd.DataFrame())

    for path, _, file_list in conver_path:
        for file_name in file_list:
            # Mirror the input sub-directory layout under the destination root.
            out_path_dst = abs_join(
                dst_path_new,
                path.split(util_global.get_value('input'))[1])
            file_path = os.path.join(path, file_name).replace('\\', '/')
            if not check_path_length(file_path):
                content = "".join([
                    "The file:", file_path, " length is invalid, skip convert."
                ])
                log_warning(content)
                continue
            content = "".join(["Begin conver file: ", file_path])
            print(content)
            # Python files over 10 MB are skipped rather than parsed.
            threshold_file_size = 10 * 1024 * 1024
            if file_name.endswith(".py"):
                if os.path.getsize(file_path) > threshold_file_size:
                    content = "".join([
                        "The file:", file_path,
                        " size is over 10M, skip convert."
                    ])
                    log_warning(content)
                    continue
                util_global.set_value('path', file_path)
                mkdir(os.path.join(out_path, out_path_dst))
                conver_ast(path, out_path_dst, file_name)
                if util_global.get_value('need_conver', False):
                    content = "".join(
                        ["Finish conver file: ", file_path, '\n'])
                    print(content)
                    write_report_terminator(content)
                else:
                    # Nothing changed: copy the script through unmodified.
                    mkdir_and_copyfile(path, abs_join(out_path, out_path_dst),
                                       file_name)
            else:
                # Non-Python files are copied through unmodified.
                mkdir_and_copyfile(path, abs_join(out_path, out_path_dst),
                                   file_name)

    adjust_index()
    analysis_report = util_global.get_value('generate_dir_report')
    if analysis_report.empty:
        print('No api data in the report')
    else:
        analysis_report.to_excel(report_xlsx, index=True)
        get_api_statistic(analysis_report)
    print("Finish conver, output file: " + out_path + "; report file: " +
          util_global.get_value('report'))
Esempio n. 6
0
def attribute(node):
    """Rewrite a matched attribute node to its mapped NPU name.

    For 'dropout' only the value is replaced (the attribute access is kept);
    for everything else the whole node becomes a plain Name.
    """
    log_success_report(getattr(node, "lineno", "None"), node.attr)
    replacement = ast.Name(id=util_global.get_value(node.attr)[0],
                           ctx=ast.Load())
    if node.attr == 'dropout':
        node.value = replacement
    else:
        node = replacement
    util_global.set_value('need_conver', True)
    return node
Esempio n. 7
0
def log_success_report(lineno, msg):
    """Log a successful API conversion and mark the success-report bit."""
    mapping = util_global.get_value(msg)
    content = (util_global.get_value('path', '') + ':' + str(lineno) +
               ' change ' + mapping[1] + ' to ' + mapping[2])
    print(content)
    write_conver_report(content, util_global.get_value('report_file')[0])
    status = util_global.get_value('report_file_status')
    util_global.set_value('report_file_status', status | 0b1)
Esempio n. 8
0
def write_report_terminator(content):
    """Append *content* plus two blank lines to every existing report file.

    Fixes in this revision: the file handle is opened with a context manager
    (the original leaked it if a write raised) and no longer shadows the
    loop variable ``file``.
    """
    report_path = util_global.get_value('report')
    for file_name in util_global.get_value('report_file'):
        file_path = os.path.join(report_path, file_name)
        if os.path.exists(file_path):
            # 'with' guarantees the handle is closed even on write failure.
            with open(file_path, 'a') as report_fp:
                report_fp.write(content)
                report_fp.write("\r\n")
                report_fp.write("\r\n")
Esempio n. 9
0
def conver_ast(path, out_path_dst, file_name):
    """Parse one script, run the conversion visitor, scan its TF API usage,
    and write the converted output when anything changed."""
    # Reset all per-file flags before visiting.
    util_global.set_value('need_conver', False)
    util_global.set_value('is_keras_net', False)
    util_global.set_value('has_hccl_api', False)
    util_global.set_value('is_main_file', False)
    util_global.set_value('has_main_func', False)
    if os.path.join(path, file_name) == util_global.get_value('main', ""):
        util_global.set_value('is_main_file', True)
    with open(os.path.join(path, file_name), "r", encoding='utf-8') as file:
        source = file.read()
    try:
        r_node = pasta.parse(source)
    except Exception as e:
        # Unparseable script: report the error and skip this file.
        print(repr(e))
        return

    # Deep ASTs can exceed the default recursion limit while visiting.
    sys.setrecursionlimit(10000)
    visitor = ConverByAst()
    visitor.visit(r_node)
    ast.fix_missing_locations(r_node)

    (api, lineno) = get_tf_api(os.path.join(path, file_name))
    if len(api) == 0:
        print(
            "No Tensorflow module is imported in script {}.".format(file_name))
    scan_file(path, file_name, api, lineno)

    if util_global.get_value('need_conver', False):
        insert_npu_import(r_node)
        # Warn when distributed/keras code has no __main__ guard to hook into.
        if not util_global.get_value('has_main_func', False) and (
                util_global.get_value('has_hccl_api', False)
                or util_global.get_value('is_keras_net', False)):
            log_warning(
                'the network of keras and horovod, or using dataset.shard script do not have main func, '
                'should set -m or --main parameter')
        if util_global.get_value('is_main_file',
                                 False) and util_global.get_value(
                                     'has_hccl_api', False):
            insert_npu_resource_init(r_node)
            insert_npu_resource_shutdown(r_node)
        if util_global.get_value('is_main_file',
                                 False) and util_global.get_value(
                                     'is_keras_net', False):
            insert_keras_sess_npu_config(r_node)
            insert_keras_sess_close(r_node)
        dst_content = pasta.dump(r_node)
        write_output_after_conver(
            os.path.join(util_global.get_value('output'), out_path_dst,
                         file_name), dst_content)

    # NOTE(review): test-only hook — dumps the node tree for files ending in "a.py".
    if file_name.endswith("a.py"):
        write_report_after_conver("only_for_test", file_name,
                                  node_tree(ast.dump(r_node)))
Esempio n. 10
0
 def visit_Attribute(self, node):
     """Visit and transform attr node"""
     self.generic_visit(node)
     if node.attr == "keras":
         util_global.set_value('is_keras_net', True)
     if node.attr in util_global.get_value('hvd'):
         distributed_mode = util_global.get_value("distributed_mode", "")
         # Only attributes accessed on an object whose name contains 'hvd'
         # are treated as horovod calls.
         if isinstance(node.value, ast.Name) and 'hvd' in str(node.value.id):
             # hvd APIs are incompatible with tf_strategy (or unset) mode.
             if distributed_mode in ("tf_strategy", ""):
                 log_strategy_distributed_mode_error(node)
                 return node
             return attribute(node)
     return node
Esempio n. 11
0
def write_report_terminator(content):
    """Append *content* to each report file whose status bit is set, then
    clear the status.

    Fix in this revision: the handle is opened with a context manager — the
    original leaked it if a write raised.
    """
    report_path = util_global.get_value('report')
    value = util_global.get_value('report_file_status')
    # Walk the set bits from the highest down to bit 0.
    times = value.bit_length()
    while times > 0:
        if get_bit_val(value, times - 1):
            file_name = util_global.get_value('report_file')[times - 1]
            file_path = os.path.join(report_path, file_name)
            if os.path.exists(file_path):
                with open(file_path, 'a') as report_fp:
                    report_fp.write(content)
                    report_fp.write("\r\n")
                    report_fp.write("\r\n")
        times = times - 1
    util_global.set_value('report_file_status', 0)
Esempio n. 12
0
def write_report_terminator(content):
    """Write content to report and update global variable"""
    report_path = util_global.get_value('report')
    status = util_global.get_value('report_file_status')
    # Walk the set bits from the highest down to bit 0.
    bit = status.bit_length()
    while bit > 0:
        if get_bit_val(status, bit - 1):
            report_name = util_global.get_value('report_file')[bit - 1]
            report_file = os.path.join(report_path, report_name)
            if os.path.exists(report_file):
                with open(report_file, 'a') as handle:
                    handle.write(content)
                    handle.write("\r\n")
                    handle.write("\r\n")
        bit -= 1
    util_global.set_value('report_file_status', 0)
Esempio n. 13
0
def convert_loss_scale_api(node):
    """Convert loss scale related Tensorflow APIs

    Rewrites FixedLossScale / DynamicLossScale /
    MixedPrecisionLossScaleOptimizer call nodes to their NPU equivalents.
    NOTE(review): falls through (implicitly returning None) for any other
    attribute call — confirm callers handle a None result.
    """
    if isinstance(node.func, ast.Attribute):
        if node.func.attr == "FixedLossScale":
            log_msg(
                getattr(node, 'lineno', 'None'),
                "change tf.train.experimental.FixedLossScale"
                " to FixedLossScaleManager")
            node.func = ast.Name(id="FixedLossScaleManager", ctx=ast.Load())
            # The single keyword is renamed to the manager's parameter name.
            if len(node.keywords) == 1:
                node.keywords[0].arg = "loss_scale"
            util_global.set_value('need_conver', True)
            return node
        if node.func.attr == "DynamicLossScale":
            return convert_dynamic_loss_scale(node)
        if node.func.attr == "MixedPrecisionLossScaleOptimizer":
            log_msg(
                getattr(node, 'lineno', 'None'),
                "change tf.train.experimental.MixedPrecisionLossScaleOptimizer"
                " to NPULossScaleOptimizer")
            node.func = ast.Name(id="NPULossScaleOptimizer", ctx=ast.Load())
            for keyword in node.keywords:
                if keyword.arg == "loss_scale":
                    keyword.arg = "loss_scale_manager"
            # In distributed mode the optimizer needs is_distributed=True.
            if (len(util_global.get_value("distributed_mode", "")) != 0):
                node.keywords.append(
                    ast.keyword(arg="is_distributed",
                                value=pasta.parse("True")))
            util_global.set_value('need_conver', True)
            return node
Esempio n. 14
0
def convert_origin_func_to_npu(node,
                               origin_func,
                               org_func_name,
                               params_list,
                               is_class_func=None):
    """Convert original Tensorflow function to NPU function

    Args:
        node: the ast.Call node being converted.
        origin_func: the original function used to validate arguments.
        org_func_name: display name of the original function.
        params_list: parameter names to match and convert.
        is_class_func: whether the original is a class method (affects
            argument checking).

    Returns the (possibly modified) call node.
    """
    if not check_func_arguments(origin_func, node.args, node.keywords,
                                is_class_func):
        return node
    if org_func_name == "Estimator.train":
        content = "".join([
            util_global.get_value('path'), ":",
            str(getattr(node, "lineno", "None"))
        ])
        # Interactive confirmation: the tool cannot tell a user-defined
        # train() from Estimator.train, so it asks the operator.
        while True:
            message = input(
                "Check if the train function in " + content +
                " is the Estimator train function. If yes, "
                "enter 'y' to perform distributed porting on the train function. if no, enter 'n': "
            )
            if message == "y":
                break
            if message == "n":
                log_warning("".join([
                    "The train func in ", content,
                    " is user-defined functions, will not perform distributed porting"
                ]))
                return node
            print("Input is error, Please enter 'y' or 'n'.")
    for param_name in params_list:
        node = match_func_params_and_convert(node, origin_func, org_func_name,
                                             param_name, is_class_func)

    util_global.set_value('need_conver', True)
    return node
Esempio n. 15
0
def write_conver_report(content, file):
    """Append *content* plus a CRLF to the given report file, creating the
    report directory first if needed.

    Fix in this revision: the handle is opened with a context manager — the
    original leaked it if a write raised.
    """
    report_path = util_global.get_value('report')
    mkdir(report_path)
    with open(os.path.join(report_path, file), 'a') as report_fp:
        report_fp.write(content)
        report_fp.write("\r\n")
Esempio n. 16
0
def write_conver_report(content, file):
    """Add content to existed report file"""
    report_path = util_global.get_value('report')
    mkdir(report_path)
    target = os.path.join(report_path, file)
    with open(target, 'a') as report_fp:
        report_fp.write(content + "\r\n")
0
def adjust_index():
    """Renumber the analysis report rows from 1 and label the index column."""
    report = util_global.get_value('generate_dir_report')
    report.index = list(range(1, len(report) + 1))
    report.index.name = '序号'
    util_global.set_value('generate_dir_report', report)
Esempio n. 18
0
 def visit_Attribute(self, node):
     """Dispatch nn/estimator/horovod attribute nodes to attribute() for conversion."""
     # tf.nn.* APIs listed in the 'nn' mapping.
     if node.attr in util_global.get_value('nn') and isinstance(
             node.value, ast.Attribute):
         if node.value.attr == 'nn':
             return attribute(node)
     # tf.estimator.* APIs listed in the 'estimator' mapping.
     if node.attr in util_global.get_value('estimator') and isinstance(
             node.value, ast.Attribute):
         if node.value.attr == 'estimator':
             return attribute(node)
     # horovod APIs: accessed via a name or attribute containing 'hvd'.
     if node.attr in util_global.get_value('hvd'):
         if isinstance(node.value, ast.Name):
             if 'hvd' in str(node.value.id):
                 return attribute(node)
         if isinstance(node.value, ast.Attribute):
             if 'hvd' in str(node.value.attr):
                 return attribute(node)
     return node
Esempio n. 19
0
 def get_output_dir(self):
     """Return the user-selected output directory, or a timestamped default."""
     chosen = self.output_path.get()
     if not chosen:
         return "output" + util_global.get_value('timestap')
     # Drop a single trailing slash and normalise separators.
     chosen = chosen[:-1] if str(chosen).endswith('/') else chosen
     return chosen.replace('\\', '/')
Esempio n. 20
0
def conver():
    """Walk the input tree, convert every Python script, and emit the
    API analysis report."""
    print("Begin conver, input file: " + util_global.get_value('input') + '\n')
    out_path = util_global.get_value('output')
    # Destination root is <input basename> + timestamp so repeated runs don't collide.
    dst_path = os.path.split(util_global.get_value('input').rstrip('\\/'))[-1]
    dst_path_new = dst_path + util_global.get_value('timestap')
    conver_path = os.walk(util_global.get_value('input'))
    report_dir = util_global.get_value('report')
    mkdir(report_dir)
    report_xlsx = os.path.join(report_dir, 'api_analysis_report.xlsx')
    util_global.set_value('generate_dir_report', pd.DataFrame())

    for path, dir_list, file_list in conver_path:
        for file_name in file_list:
            # Mirror the input sub-directory layout under the destination root.
            out_path_dst = abs_join(
                dst_path_new,
                path.split(util_global.get_value('input'))[1])
            file_path = os.path.join(path, file_name).replace('\\', '/')
            content = "Begin conver file: " + file_path
            print(content)
            if file_name.endswith(".py"):
                util_global.set_value('path', file_path)
                mkdir(os.path.join(out_path, out_path_dst))
                conver_ast(path, out_path_dst, file_name)
                if util_global.get_value('need_conver', False):
                    content = "Finish conver file: " + file_path + '\n'
                    print(content)
                    write_report_terminator(content)
                else:
                    # Nothing changed: copy the script through unmodified.
                    mkdir_and_copyfile(path, abs_join(out_path, out_path_dst),
                                       file_name)
            else:
                # Non-Python files are copied through unmodified.
                mkdir_and_copyfile(path, abs_join(out_path, out_path_dst),
                                   file_name)

    adjust_index()
    analysis_report = util_global.get_value('generate_dir_report')
    if analysis_report.empty:
        print('No api data in the report')
    else:
        analysis_report.to_excel(report_xlsx, index=True)
        get_api_statistic(analysis_report)
    print("Finish conver, output file: " + out_path + "; report file: " +
          util_global.get_value('report'))
Esempio n. 21
0
def add_npu_func_to_params(node, param_index, org_func_name, param_name,
                           npu_func, npu_func_args):
    """Add npu function to parameters

    Wraps the matching positional/keyword argument of *node* in a call to
    *npu_func* (passing the original value as *npu_func_args*), or appends
    ``param_name=npu_func()`` when the parameter was absent.

    Fix in this revision: the original popped from ``node.args`` while
    iterating it with enumerate; replaced with an equivalent bounds-checked
    pop (the enumerate condition could only fire once, at index
    ``param_index``).
    """
    param_node = None
    # In non-horovod (or unset) distributed mode these params are left alone.
    if ((not util_global.get_value("distributed_mode", "")
         or util_global.get_value("distributed_mode", "") == "horovod")
            and (param_name in ("callbacks", "hooks", "optimizer"))):
        return node
    log_param_msg = "".join([org_func_name, " add npu ", param_name])
    log_msg(getattr(node, "lineno", "None"), log_param_msg)
    # Take the positional argument at param_index, if present.
    if param_index is not None and 0 <= param_index < len(node.args):
        param_node = node.args.pop(param_index)

    for keyword in node.keywords:
        if keyword.arg == param_name:
            param_node = keyword

    if param_node:
        if isinstance(param_node, ast.keyword):
            # Keyword arg: wrap its value in npu_func(npu_func_args=<value>).
            new_value = ast.Call(func=ast.Name(id=npu_func, ctx=ast.Load()),
                                 args=[],
                                 keywords=[
                                     ast.keyword(arg=npu_func_args,
                                                 value=param_node.value)
                                 ])
            ast.copy_location(new_value, param_node.value)
            param_node.value = new_value
        else:
            # Positional arg: re-attach it as a wrapped keyword argument.
            node.keywords.append(
                ast.keyword(arg=param_name,
                            value=ast.Call(func=ast.Name(id=npu_func,
                                                         ctx=ast.Load()),
                                           args=[],
                                           keywords=[
                                               ast.keyword(arg=npu_func_args,
                                                           value=param_node)
                                           ])))
    else:
        # Parameter absent entirely: append param_name=npu_func().
        node.keywords.append(
            ast.keyword(arg=param_name,
                        value=pasta.parse("".join([npu_func, "()"]))))
    return node
Esempio n. 22
0
def conver(r_node, out_path_dst, file_name):
    """Add necessary imported modules"""
    # __init__.py must stay import-free to avoid changing package semantics.
    if file_name != "__init__.py":
        insert_npu_import(r_node)
    if util_global.get_value('use_keras_dropout', False):
        insert_keras_dropout_import(r_node)
    distributed_mode = util_global.get_value('distributed_mode', "")
    # Warn when distributed/keras code has no main func and no -m/--main given.
    if not util_global.get_value('has_main_func', False) and \
            (util_global.get_value('has_hvd_api', False) or
             util_global.get_value('is_keras_net', False)) and \
            not util_global.get_value('main', ""):
        log_warning_main_arg_not_set()
    if distributed_mode == "horovod" and util_global.get_value('is_main_file', False):
        insert_npu_resource_init(r_node)
        insert_npu_resource_shutdown(r_node)
    if util_global.get_value('is_main_file', False) and util_global.get_value('is_keras_net', False):
        insert_keras_sess_npu_config(r_node)
        insert_keras_sess_close(r_node)
    dst_content = pasta.dump(r_node)
    write_output_after_conver(os.path.join(util_global.get_value('output'), out_path_dst, file_name), dst_content)
Esempio n. 23
0
 def get_report_dir(self):
     """Get selected report directory"""
     default_name = "report" + util_global.get_value('timestap')
     chosen = self.report_path.get()
     if not chosen:
         return default_name
     # Drop a single trailing slash, nest the timestamped folder inside the
     # chosen directory, and normalise separators.
     chosen = chosen[:-1] if str(chosen).endswith('/') else chosen
     joined = os.path.join(chosen, default_name)
     return joined.replace('\\', '/')
Esempio n. 24
0
def ast_function_def(node):
    """Rewrite the function body to delegate to <mapped module>.gelu(x)."""
    log_success_report(getattr(node, "lineno", "None"), node.name)
    npu_module = ast.Name(id=util_global.get_value(node.name)[0],
                          ctx=ast.Load())
    gelu_attr = ast.Attribute(value=npu_module, attr='gelu', ctx=ast.Load())
    gelu_call = ast.Call(func=gelu_attr,
                         args=[ast.Name(id='x', ctx=ast.Load())],
                         keywords=[])
    node.body = [ast.Return(value=gelu_call)]
    util_global.set_value('need_conver', True)
    return node
Esempio n. 25
0
def conver_ast(path, out_path_dst, file_name):
    """Parse one script, run the conversion visitor, and write the converted
    output when anything changed.

    Fix in this revision: the source file is opened with a context manager —
    the original never closed the handle.
    """
    util_global.set_value('need_conver', False)
    with open(os.path.join(path, file_name), "r") as file:
        source = file.read()
    r_node = ast.parse(source)

    # Deep ASTs can exceed the default recursion limit while visiting.
    sys.setrecursionlimit(10000)
    visitor = ConverByAst()
    visitor.visit(r_node)
    ast.fix_missing_locations(r_node)

    if util_global.get_value('need_conver', False):
        insert_npu_import(r_node)
        dst_content = astunparse.unparse(r_node)
        write_output_after_conver(
            os.path.join(util_global.get_value('output'), out_path_dst,
                         file_name), dst_content)

    # NOTE(review): test-only hook — dumps the node tree for files ending in "a.py".
    if file_name.endswith("a.py"):
        write_report_after_conver("only_for_test", file_name,
                                  node_tree(ast.dump(r_node)))
Esempio n. 26
0
 def visit_Attribute(self, node):
     """Flag keras usage and convert horovod attribute accesses."""
     self.generic_visit(node)
     if node.attr == "keras":
         util_global.set_value('is_keras_net', True)
     # horovod APIs: accessed via a name or attribute containing 'hvd'.
     if node.attr in util_global.get_value('hvd'):
         if isinstance(node.value, ast.Name):
             if 'hvd' in str(node.value.id):
                 return attribute(node)
         if isinstance(node.value, ast.Attribute):
             if 'hvd' in str(node.value.attr):
                 return attribute(node)
     return node
Esempio n. 27
0
def conver():
    """Walk the input tree and convert every Python script in place of a
    mirrored output tree."""
    print("Begin conver, input file: " + util_global.get_value('input'))
    out_path = util_global.get_value('output')
    dst_path = os.path.split(util_global.get_value('input').rstrip('\\/'))[-1]
    conver_path = os.walk(util_global.get_value('input'))
    for path, dir_list, file_list in conver_path:
        for file_name in file_list:
            # Mirror the input sub-directory layout under the destination root.
            out_path_dst = abs_join(dst_path, path.split(dst_path)[1])
            if file_name.endswith(".py"):
                util_global.set_value('path', os.path.join(path, file_name))
                mkdir(os.path.join(out_path, out_path_dst))
                conver_ast(path, out_path_dst, file_name)
                if util_global.get_value('need_conver', False):
                    content = "Finish conver file: " + os.path.join(
                        path, file_name)
                    print(content)
                    write_report_terminator(content)
                else:
                    # Nothing changed: copy the script through unmodified.
                    mkdir_and_copyfile(path, abs_join(out_path, out_path_dst),
                                       file_name)
            else:
                # Non-Python files are copied through unmodified.
                mkdir_and_copyfile(path, abs_join(out_path, out_path_dst),
                                   file_name)

    print("Finish conver, output file: " + out_path + "; report file: " +
          util_global.get_value('report'))
Esempio n. 28
0
def log_failed_api(lineno, api_msg, is_third_party):
    """Log message for NPU unsupported APIs

    Prints a colored ERROR/WARNING line for the unsupported API and records
    it in the failed-API logger. Third-party APIs and unknown APIs are
    errors; hvd.* and tf.is_* get document-reference warnings instead.
    """
    # NOTE(review): presumably a no-op shell call to make ANSI color codes
    # render on Windows consoles — confirm.
    os.system("cd .")
    if is_third_party:
        content = "".join([
            util_global.get_value('path', ''), ":",
            str(lineno), ", NPU Unsupport API: ", api_msg,
            ", Please modify user scripts manually."
        ])
        print("".join(["\033[1;31mERROR\033[0m:", content]), flush=True)

    elif api_msg.startswith("hvd"):
        # Horovod APIs: point at the migration-guide chapter.
        doc_msg = "{}, chapter: {}".format('"Tensorflow模型迁移和训练',
                                           '"Horovod脚本迁移示例"')
        content = "".join([
            util_global.get_value('path', ''), ":",
            str(lineno), ", NPU Unsupport API: ", api_msg,
            ", Please refer to the online document: ", doc_msg
        ])
        print("".join(["\033[1;33mWARNING\033[0m:", content]), flush=True)

    elif api_msg.startswith("tf.is_"):
        # tf.is_* APIs: point at the manual-migration and loss-scale chapters.
        doc_msg = "{}, chapter: {}".format(
            '"Tensorflow模型迁移和训练', '"tf.is_finite接口手工迁移" and "Loss Scale"')
        content = "".join([
            util_global.get_value('path', ''), ":",
            str(lineno), ", NPU Unsupport API: ", api_msg,
            ", Please refer to the online document: ", doc_msg
        ])
        print("".join(["\033[1;33mWARNING\033[0m:", content]), flush=True)

    else:
        content = "".join([
            util_global.get_value('path', ''), ":",
            str(lineno), ", NPU Unsupport API: ", api_msg
        ])
        print("".join(["\033[1;31mERROR\033[0m:", content]), flush=True)
    logger_failed_report.info(content)
Esempio n. 29
0
def conver_ast(path, out_path_dst, file_name):
    """Convert script by python ast"""
    # Reset all per-file flags before visiting.
    util_global.set_value('need_conver', False)
    util_global.set_value('is_keras_net', False)
    util_global.set_value('has_hvd_api', False)
    util_global.set_value('is_main_file', False)
    util_global.set_value('has_main_func', False)
    if os.path.join(path, file_name) == util_global.get_value('main', ""):
        util_global.set_value('is_main_file', True)
    with open(os.path.join(path, file_name), "r", encoding='utf-8') as file:
        # NOTE(review): trailing newline appended — presumably so pasta can
        # parse files that lack one; confirm.
        source = file.read() + "\n"
    try:
        r_node = pasta.parse(source)
    except Exception as e:
        # Unparseable script: report the error and skip this file.
        print(repr(e))
        content = ("There is a format problem in the script, please check the python code "
                   "specification or whether it is converted to a linux file through 'dos2unix'")
        os.system("cd .")
        print("".join(["\033[1;31mERROR\033[0m:", content]))
        return

    # Deep ASTs can exceed the default recursion limit while visiting.
    sys.setrecursionlimit(10000)
    visitor = ConverByAst()
    visitor.visit(r_node)
    ast.fix_missing_locations(r_node)

    (api, lineno) = get_tf_api(os.path.join(path, file_name))
    if len(api) == 0:
        print("No Tensorflow module is imported in script {}.".format(file_name))
    scan_file(path, file_name, api, lineno)

    if util_global.get_value('need_conver', False):
        conver(r_node, out_path_dst, file_name)

    # NOTE(review): test-only hook — dumps the node tree for files ending in "a.py".
    if file_name.endswith("a.py"):
        write_report_after_conver("only_for_test", file_name, node_tree(ast.dump(r_node)))
Esempio n. 30
0
def convert_tf_gradient_distributed(node):
    """Convert Tensorflow gradient APIs in distributed mode"""
    location = util_global.get_value('path') + ":" + str(
        getattr(node, "lineno", "None"))
    log_warning(
        location +
        " is tf.gradient api, tool inserts npu_allreduce after computing grads by default." +
        " You can adjust the allreduce position according to the algorithm")
    # Wrap the original expression: npu_allreduce(<node>).
    wrapped = ast.Call(func=ast.Name(id="npu_allreduce", ctx=ast.Load()),
                       args=[node],
                       keywords=[])
    ast.copy_location(wrapped, node)
    util_global.set_value("need_conver", True)
    return wrapped