Example #1
0
 def parse(self, args):
     """Read the TensorFlow-related options from ``args`` onto this object.

     Sets ``ckpt``, ``outputs`` (from the ``tf_outputs`` option via
     ``args_util.get_outputs``), ``save_pb``, ``save_tensorboard``,
     ``freeze_graph``, ``tftrt``, ``minimum_segment_size`` and ``dynamic_op``.

     Args:
         args: The parsed arguments object. Every option is read through
             ``args_util`` (presumably an ``argparse.Namespace`` -- confirm
             against the caller).
     """

     def _opt(name):
         # Single place that forwards ``args`` so each assignment below only
         # has to name the option it reads.
         return args_util.get(args, name)

     self.ckpt = _opt("ckpt")
     # Outputs go through the dedicated helper rather than a plain lookup.
     self.outputs = args_util.get_outputs(args, "tf_outputs")
     self.save_pb = _opt("save_pb")
     self.save_tensorboard = _opt("save_tensorboard")
     self.freeze_graph = _opt("freeze_graph")

     # Options grouped under the ``tftrt`` flag.
     self.tftrt = _opt("tftrt")
     self.minimum_segment_size = _opt("minimum_segment_size")
     self.dynamic_op = _opt("dynamic_op")
Example #2
0
    def parse(self, args):
        """Read the TensorRT-related options from ``args`` onto this object.

        Covers outputs (``trt_outputs`` via ``args_util.get_outputs``), the
        Caffe/UFF model options, calibration settings, and the DLA flags.

        When ``calibration_base_class`` is supplied, it is passed through
        ``assert_identifier`` and wrapped as ``trt.<name>`` using
        ``safe``/``inline``. Otherwise the attribute stays ``None``.

        Args:
            args: The parsed arguments object. Options are read through
                ``args_util`` (presumably an ``argparse.Namespace`` --
                confirm against the caller).
        """

        def _opt(name):
            # Forward ``args`` once so the assignments below stay terse.
            return args_util.get(args, name)

        self.trt_outputs = args_util.get_outputs(args, "trt_outputs")
        self.caffe_model = _opt("caffe_model")
        self.batch_size = _opt("batch_size")
        self.save_uff = _opt("save_uff")
        self.uff_order = _opt("uff_order")
        self.preprocessor = _opt("preprocessor")

        # Calibration settings.
        self.calibration_cache = _opt("calibration_cache")
        requested_base = _opt("calibration_base_class")
        self.calibration_base_class = None
        if requested_base is not None:
            # NOTE(review): assert_identifier presumably rejects anything that
            # is not a bare identifier before it is embedded as ``trt.<name>``;
            # confirm against its definition.
            checked_base = safe(assert_identifier(requested_base))
            self.calibration_base_class = inline(safe("trt.{:}", inline(checked_base)))

        self.quantile = _opt("quantile")
        self.regression_cutoff = _opt("regression_cutoff")

        # DLA settings.
        self.use_dla = _opt("use_dla")
        self.allow_gpu_fallback = _opt("allow_gpu_fallback")
Example #3
0
 def parse(self, args):
     """Read the ONNX-related options from ``args`` onto this object.

     Sets ``outputs`` (from ``onnx_outputs`` via ``args_util.get_outputs``),
     ``exclude_outputs`` (from ``onnx_exclude_outputs``) and
     ``load_external_data``.

     Args:
         args: The parsed arguments object (presumably an
             ``argparse.Namespace`` -- confirm against the caller).
     """

     def _opt(name):
         return args_util.get(args, name)

     self.outputs = args_util.get_outputs(args, "onnx_outputs")
     self.exclude_outputs = _opt("onnx_exclude_outputs")
     self.load_external_data = _opt("load_external_data")
Example #4
0
 def parse(self, args):
     """Read the TensorRT network options from ``args`` onto this object.

     Sets ``outputs`` (from ``trt_outputs`` via ``args_util.get_outputs``),
     ``explicit_precision``, ``exclude_outputs`` (from
     ``trt_exclude_outputs``) and ``trt_network_func_name``.

     Args:
         args: The parsed arguments object (presumably an
             ``argparse.Namespace`` -- confirm against the caller).
     """

     def _opt(name):
         return args_util.get(args, name)

     self.outputs = args_util.get_outputs(args, "trt_outputs")
     self.explicit_precision = _opt("explicit_precision")
     self.exclude_outputs = _opt("trt_exclude_outputs")
     self.trt_network_func_name = _opt("trt_network_func_name")