def test_default_args(self):
    def script_add(script, arg0=0, arg1=0):
        # Emit `result = <arg0> + <arg1>` into the script and report which
        # variable run_script should read back.
        out_var = safe("result")
        script.append_suffix(safe("{:} = {:} + {:}", inline(out_var), arg0, arg1))
        return out_var

    # Each case: (positional args forwarded to script_add, expected result).
    # Omitted positions fall back to script_add's defaults of 0.
    cases = [
        ((), 0),
        ((1,), 1),
        ((1, 2), 3),
    ]
    for call_args, expected in cases:
        assert args_util.run_script(script_add, *call_args) == expected
def get_data_loader(self, user_input_metadata=None):
    """Build the data loader described by this argument group.

    The loader is produced by running ``self._add_to_script`` through
    ``args_util.run_script``. If that yields ``None``, a default
    ``DataLoader()`` is used instead (via ``util.default``). When
    ``self._add_to_script`` signals that its result must be invoked, the
    resulting object is called once and that call's return value is returned.

    Args:
        user_input_metadata: Forwarded as the sole extra argument to
            ``self._add_to_script``.
    """
    from polygraphy.comparator import DataLoader

    # Updated from inside the wrapper below; read only after run_script returns.
    needs_invoke = False

    def _name_only(script, *args, **kwargs):
        # run_script expects its callable to return just the result variable
        # name, but self._add_to_script returns a (name, needs_invoke) pair.
        # Record the flag here and forward only the name.
        nonlocal needs_invoke
        result_name, needs_invoke = self._add_to_script(script, *args, **kwargs)
        return result_name

    loader = util.default(
        args_util.run_script(_name_only, user_input_metadata),
        DataLoader(),
    )
    return loader() if needs_invoke else loader
def get_logger(self):
    """Run ``self.add_to_script`` via ``args_util.run_script`` and return its result."""
    logger = args_util.run_script(self.add_to_script)
    return logger
def load_onnx(self):
    """Build the loader from ``self.add_onnx_loader`` and return the result of calling it."""
    return args_util.run_script(self.add_onnx_loader)()
def save_onnx(self, model, path=None):
    """Save ``model`` using the loader generated by ``self.add_save_onnx``.

    Both building and invoking the loader happen while ``self.path`` is
    temporarily changed through ``util.TempAttrChange(self, "path", path)``.
    NOTE(review): how a ``None`` path is handled is up to TempAttrChange.

    Args:
        model: Passed to ``self.add_save_onnx`` as its extra argument.
        path: Temporary value for ``self.path`` during the save.

    Returns:
        Whatever the generated loader returns when called.
    """
    with util.TempAttrChange(self, "path", path):
        save_loader = args_util.run_script(self.add_save_onnx, model)
        # Invoke inside the context so the temporary path is still in effect.
        return save_loader()
def create_config(self, builder, network):
    """Create a TensorRT builder config for ``builder`` and ``network``.

    Uses the loader generated by ``self.add_trt_config_loader``. If that
    yields ``None``, a default ``CreateConfig()`` is used instead (via
    ``util.default``). The chosen loader is then called with
    ``(builder, network)``.
    """
    from polygraphy.backend.trt import CreateConfig

    generated = args_util.run_script(self.add_trt_config_loader)
    config_loader = util.default(generated, CreateConfig())
    return config_loader(builder, network)
def load_serialized_engine(self):
    """Build the loader from ``self.add_trt_serialized_engine_loader`` and return the result of calling it."""
    engine_loader = args_util.run_script(self.add_trt_serialized_engine_loader)
    engine = engine_loader()
    return engine
def build_engine(self, network=None):
    """Build an engine using the loader generated by ``self.add_trt_build_engine_loader``.

    Args:
        network: Forwarded to ``self.add_trt_build_engine_loader`` as its
            extra argument (``None`` by default).

    Returns:
        Whatever the generated loader returns when called.
    """
    return args_util.run_script(self.add_trt_build_engine_loader, network)()
def save_engine(self, engine, path=None):
    """Save ``engine`` using the loader generated by ``self.add_save_engine``.

    Both building and invoking the loader happen while ``self.path`` is
    temporarily changed through ``util.TempAttrChange(self, "path", path)``.
    NOTE(review): how a ``None`` path is handled is up to TempAttrChange.

    Args:
        engine: Passed to ``self.add_save_engine`` as its extra argument.
        path: Temporary value for ``self.path`` during the save.

    Returns:
        Whatever the generated loader returns when called.
    """
    with util.TempAttrChange(self, "path", path):
        save_loader = args_util.run_script(self.add_save_engine, engine)
        # Invoke inside the context so the temporary path is still in effect.
        return save_loader()
def get_network_loader(self):
    """Return the loader generated by ``self.add_trt_network_loader`` without calling it."""
    network_loader = args_util.run_script(self.add_trt_network_loader)
    return network_loader
def load_graph(self):
    """Build the loader from ``self.add_to_script`` and return the result of calling it."""
    graph_loader = args_util.run_script(self.add_to_script)
    graph = graph_loader()
    return graph
def get_data_loader(self, user_input_metadata=None):
    """Return the data loader generated by ``self.add_to_script``.

    If running the script yields ``None``, a default ``DataLoader()`` is
    returned instead (via ``util.default``).

    Args:
        user_input_metadata: Forwarded to ``self.add_to_script`` as its extra
            argument.
    """
    from polygraphy.comparator import DataLoader

    script_result = args_util.run_script(self.add_to_script, user_input_metadata)
    fallback = DataLoader()
    return util.default(script_result, fallback)
def load_network(self):
    """Build the loader from ``self.add_trt_network_loader`` and return the result of calling it."""
    return args_util.run_script(self.add_trt_network_loader)()