예제 #1
0
def build_script(args, cmd_run=None):
    """Generate the source text of a comparison script from parsed arguments.

    Args:
        args: Parsed arguments. ``args.model_file``, ``args.runners`` and
            ``args.load_results`` are read here; the whole object is also
            forwarded to the ``tool_util`` and ``add_*`` helpers.
        cmd_run: Forwarded unchanged to ``add_comparator``.

    Returns:
        str: The rendered script.

    Raises:
        KeyError: If an entry of ``args.runners`` is not one of "tf",
            "onnxrt", "onnxtf", "cntk", "trt" or "trt_legacy".
    """
    summary = generate_summary(args.model_file, args.runners, args.load_results)
    script = Script(summary=summary)
    tool_util.add_logger_settings(script, args)

    loader_var = tool_util.add_data_loader(script, args)

    def _add_trt(scr, parsed_args):
        # Unlike the other runners, the TensorRT runner also needs the name of
        # the data loader variable emitted above.
        return add_trt_runner(scr, parsed_args, loader_var)

    for runner_key in args.runners:
        runner_adders = {
            "tf": add_tf_runner,
            "onnxrt": add_onnxrt_runner,
            "onnxtf": add_onnxtf_runner,
            "cntk": add_cntk_runner,
            "trt": _add_trt,
            "trt_legacy": add_trt_legacy_runner,
        }
        chosen = runner_adders[runner_key]
        chosen(script, args)

    add_comparator(
        script,
        args,
        data_loader_name=loader_var,
        cmd_run=cmd_run,
    )
    return str(script)
예제 #2
0
 def get_data_loader(self):
     """Build the configured data loader by generating and executing a script.

     Returns:
         The object that the generated script binds to the variable name
         returned by ``self.add_to_script``. If that name is None (all
         arguments are default), a default-constructed
         ``polygraphy.comparator.DataLoader`` is returned instead.
     """
     script = Script()
     data_loader_name = self.add_to_script(script)
     if data_loader_name is None:  # All arguments are default
         from polygraphy.comparator import DataLoader
         return DataLoader()
     # Execute into an explicit namespace instead of reading back through a
     # second ``locals()`` call. Since Python 3.13 (PEP 667), each ``locals()``
     # call inside a function returns a fresh snapshot, so names assigned by
     # ``exec`` would be lost. Seeding from the current locals keeps every name
     # the generated code could previously see.
     namespace = dict(locals())
     exec(str(script), globals(), namespace)
     return namespace[data_loader_name]
예제 #3
0
    def build_script(self, args):
        """Generate the source text of a comparison script.

        The model path comes from ``self.makers[ModelArgs]``, and the runner
        list and loaded-results list come from ``args``. Every other section
        of the script is contributed by the matching entry in ``self.makers``.

        Args:
            args: Parsed arguments; ``args.runners`` and ``args.load_results``
                are read here.

        Returns:
            str: The rendered script. It ends by logging PASSED/FAILED together
            with the invoking command line, then exits with status 0 on
            success or 1 on failure.

        Raises:
            KeyError: If an entry of ``args.runners`` is not one of "tf",
                "onnxrt", "onnxtf", "trt" or "trt_legacy".
        """
        makers = self.makers
        model_file = makers[ModelArgs].model_file
        script = Script(
            summary=generate_summary(model_file, args.runners, args.load_results))

        makers[LoggerArgs].add_to_script(script)

        loader_var = makers[DataLoaderArgs].add_to_script(script)

        def _add_trt(scr):
            # Unlike the other runners, the TensorRT runner also needs the name
            # of the data loader variable emitted above.
            return makers[TrtRunnerArgs].add_to_script(scr, loader_var)

        for runner_key in args.runners:
            adders = {
                "tf": makers[TfRunnerArgs].add_to_script,
                "onnxrt": makers[OnnxrtRunnerArgs].add_to_script,
                "onnxtf": makers[OnnxtfRunnerArgs].add_to_script,
                "trt": _add_trt,
                "trt_legacy": makers[TrtLegacyArgs].add_to_script,
            }
            adders[runner_key](script)

        results_var = makers[ComparatorRunArgs].add_to_script(
            script, data_loader_name=loader_var)
        success_var = makers[ComparatorCompareArgs].add_to_script(
            script, results_name=results_var)

        # Trailer of the generated script: record the command line it was run
        # with, then log PASSED or FAILED depending on the comparison result.
        report_template = (
            "# Report Results\n"
            "cmd_run={cmd}\n"
            "if {success}:\n"
            '{tab}G_LOGGER.finish("PASSED | Command: {{}}".format(cmd_run))\n'
            "else:\n"
            '{tab}G_LOGGER.error("FAILED | Command: {{}}".format(cmd_run))'
        )
        script.append_suffix(
            Script.format_str(
                report_template,
                cmd=Inline("' '.join(sys.argv)"),
                success=success_var,
                tab=Inline(constants.TAB),
            ))
        # Exit status of the generated script: 0 on success, 1 otherwise.
        script.append_suffix(
            "sys.exit(0 if {success} else 1)".format(success=success_var))

        return str(script)
예제 #4
0
 def get_tf_loader(self):
     """Generate the TensorFlow loader script, execute it, and return the loader.

     Returns:
         The object that the generated script binds to the variable name
         returned by ``self.add_to_script``.
     """
     script = Script()
     loader_name = self.add_to_script(script)
     # Execute into an explicit namespace: since Python 3.13 (PEP 667), each
     # ``locals()`` call in a function returns a fresh snapshot, so a later
     # ``locals()[loader_name]`` would not see names assigned by ``exec``.
     # Seeding from the current locals preserves what the script could see.
     namespace = dict(locals())
     exec(str(script), globals(), namespace)
     return namespace[loader_name]
예제 #5
0
 def get_trt_serialized_engine_loader(self):
     """Generate the serialized-engine loader script, execute it, and return the loader.

     Returns:
         The object that the generated script binds to the variable name
         returned by ``self.add_trt_serialized_engine_loader``.
     """
     script = Script()
     loader_name = self.add_trt_serialized_engine_loader(script)
     # Execute into an explicit namespace: since Python 3.13 (PEP 667), each
     # ``locals()`` call in a function returns a fresh snapshot, so a later
     # ``locals()[loader_name]`` would not see names assigned by ``exec``.
     # Seeding from the current locals preserves what the script could see.
     namespace = dict(locals())
     exec(str(script), globals(), namespace)
     return namespace[loader_name]
예제 #6
0
 def get_trt_config_loader(self, data_loader):
     """Generate the TensorRT config loader script, execute it, and return the loader.

     Args:
         data_loader: Made visible to the generated code under the name
             ``data_loader``. The script is generated with
             ``data_loader_name="data_loader"``, so it presumably refers to
             this argument by that name.

     Returns:
         The object that the generated script binds to the variable name
         returned by ``self.add_trt_config_loader``.
     """
     script = Script()
     loader_name = self.add_trt_config_loader(
         script, data_loader_name="data_loader")
     # Execute into an explicit namespace: since Python 3.13 (PEP 667), each
     # ``locals()`` call in a function returns a fresh snapshot, so a later
     # ``locals()[loader_name]`` would not see names assigned by ``exec``.
     # Seeding from the current locals keeps ``data_loader`` (and every other
     # local) resolvable by the generated code, as before.
     namespace = dict(locals())
     exec(str(script), globals(), namespace)
     return namespace[loader_name]
예제 #7
0
def get_trt_network_loader(args):
    """Generate the TensorRT network loader script from ``args``, execute it, and return the loader.

    Args:
        args: Forwarded to ``add_trt_network_loader``.

    Returns:
        The object that the generated script binds to the variable name
        returned by ``add_trt_network_loader``.
    """
    script = Script()
    loader_name = add_trt_network_loader(script, args)
    # Execute into an explicit namespace: since Python 3.13 (PEP 667), each
    # ``locals()`` call in a function returns a fresh snapshot, so a later
    # ``locals()[loader_name]`` would not see names assigned by ``exec``.
    # Seeding from the current locals preserves what the script could see.
    namespace = dict(locals())
    exec(str(script), globals(), namespace)
    return namespace[loader_name]
예제 #8
0
def get_onnx_model_loader(args):
    """Generate the ONNX model loader script from ``args``, execute it, and return the loader.

    Args:
        args: Forwarded to ``add_onnx_loader``.

    Returns:
        The object that the generated script binds to the variable name
        returned by ``add_onnx_loader``.
    """
    script = Script()
    loader_name = add_onnx_loader(script, args)
    # Execute into an explicit namespace: since Python 3.13 (PEP 667), each
    # ``locals()`` call in a function returns a fresh snapshot, so a later
    # ``locals()[loader_name]`` would not see names assigned by ``exec``.
    # Seeding from the current locals preserves what the script could see.
    namespace = dict(locals())
    exec(str(script), globals(), namespace)
    return namespace[loader_name]