Exemple #1
0
    def test_default_args(self):
        """Verify run_script when the script callable relies on its default parameters.

        ``script_add`` emits ``result = arg0 + arg1`` into the script and returns the
        name of the variable to read back; each case passes zero, one, or two extra
        positional arguments and checks the evaluated sum.
        """
        def script_add(script, arg0=0, arg1=0):
            out_var = safe("result")
            script.append_suffix(safe("{:} = {:} + {:}", inline(out_var), arg0, arg1))
            # run_script reads back the variable whose name we return here.
            return out_var

        cases = [
            ((), 0),
            ((1,), 1),
            ((1, 2), 3),
        ]
        for call_args, expected in cases:
            assert args_util.run_script(script_add, *call_args) == expected
Exemple #2
0
    def get_data_loader(self, user_input_metadata=None):
        """Create a data loader from the script generated by ``self._add_to_script``.

        Args:
            user_input_metadata: Forwarded to ``self._add_to_script`` through
                ``args_util.run_script``.

        Returns:
            The value chosen by ``util.default`` between the script's result and a
            fresh ``DataLoader()``. If ``self._add_to_script`` reported that the
            result must be invoked, the chosen value is called and its return value
            is returned instead.
        """
        from polygraphy.comparator import DataLoader

        must_call = False

        def name_only(script, *args, **kwargs):
            # run_script wants a callable that returns only the variable name, but
            # self._add_to_script returns (name, needs_invoke). Keep the name and
            # report the flag back to the enclosing scope.
            nonlocal must_call
            var_name, must_call = self._add_to_script(script, *args, **kwargs)
            return var_name

        script_result = args_util.run_script(name_only, user_input_metadata)
        data_loader = util.default(script_result, DataLoader())
        return data_loader() if must_call else data_loader
Exemple #3
0
 def get_logger(self):
     """Execute the script built by ``self.add_to_script`` and return the variable it names."""
     logger = args_util.run_script(self.add_to_script)
     return logger
Exemple #4
0
 def load_onnx(self):
     """Build a loader with ``self.add_onnx_loader`` and return the result of calling it."""
     # run_script yields a callable loader; invoke it right away.
     return args_util.run_script(self.add_onnx_loader)()
Exemple #5
0
 def save_onnx(self, model, path=None):
     """Save ``model`` through the loader generated by ``self.add_save_onnx``.

     While the script is built and the loader runs, ``util.TempAttrChange`` overrides
     ``self.path`` with ``path``. It presumably restores the previous value on exit;
     how ``path=None`` is handled depends on TempAttrChange, which is not shown here.

     Returns:
         Whatever the generated save loader returns.
     """
     with util.TempAttrChange(self, "path", path):
         # Call inside the with-block so the temporary path is still in effect.
         return args_util.run_script(self.add_save_onnx, model)()
Exemple #6
0
    def create_config(self, builder, network):
        """Create a builder config for ``builder``/``network``.

        Uses the loader generated by ``self.add_trt_config_loader``. ``util.default``
        chooses between that loader and a plain ``CreateConfig()``.
        """
        from polygraphy.backend.trt import CreateConfig

        config_loader = args_util.run_script(self.add_trt_config_loader)
        config_loader = util.default(config_loader, CreateConfig())
        return config_loader(builder, network)
Exemple #7
0
 def load_serialized_engine(self):
     """Run the loader generated by ``self.add_trt_serialized_engine_loader`` and return its result."""
     return args_util.run_script(self.add_trt_serialized_engine_loader)()
Exemple #8
0
 def build_engine(self, network=None):
     """Build an engine with the loader generated by ``self.add_trt_build_engine_loader``.

     Args:
         network: Passed through ``args_util.run_script`` to ``self.add_trt_build_engine_loader``.

     Returns:
         The result of calling the generated loader.
     """
     build = args_util.run_script(self.add_trt_build_engine_loader, network)
     return build()
Exemple #9
0
 def save_engine(self, engine, path=None):
     """Save ``engine`` through the loader generated by ``self.add_save_engine``.

     While the script is built and the loader runs, ``util.TempAttrChange`` overrides
     ``self.path`` with ``path``. It presumably restores the previous value on exit;
     confirm how ``path=None`` is treated in TempAttrChange.

     Returns:
         Whatever the generated save loader returns.
     """
     with util.TempAttrChange(self, "path", path):
         # Call inside the with-block so the temporary path is still in effect.
         return args_util.run_script(self.add_save_engine, engine)()
Exemple #10
0
 def get_network_loader(self):
     """Return the (uncalled) loader generated by ``self.add_trt_network_loader``."""
     network_loader = args_util.run_script(self.add_trt_network_loader)
     return network_loader
Exemple #11
0
 def load_graph(self):
     """Build a loader with ``self.add_to_script`` and return the result of calling it."""
     return args_util.run_script(self.add_to_script)()
Exemple #12
0
 def get_data_loader(self, user_input_metadata=None):
     """Return the data loader produced by ``self.add_to_script``.

     ``util.default`` chooses between the script's result and a fresh ``DataLoader()``.

     Args:
         user_input_metadata: Forwarded to ``self.add_to_script`` through ``args_util.run_script``.
     """
     from polygraphy.comparator import DataLoader

     script_loader = args_util.run_script(self.add_to_script, user_input_metadata)
     fallback = DataLoader()
     return util.default(script_loader, fallback)
Exemple #13
0
 def load_network(self):
     """Build a loader with ``self.add_trt_network_loader`` and return the result of calling it."""
     return args_util.run_script(self.add_trt_network_loader)()