Code Example #1
    def RunTest(self, run_params):
        if not self.ShouldRunTest(run_params):
            return
        assert run_params.precision_mode in PRECISION_MODES
        np.random.seed(12345)

        params = self._GetParamsCached()
        input_gdef = params.gdef
        input_dtypes = {}
        for node in input_gdef.node:
            if self._ToString(node.name) in params.input_names:
                assert self._ToString(node.op) == "Placeholder"
                input_dtypes[self._ToString(node.name)] = (dtypes.as_dtype(
                    node.attr["dtype"].type).as_numpy_dtype())
        assert len(params.input_names) == len(input_dtypes)

        input_data = []
        for i in range(len(params.input_names)):
            dtype = input_dtypes[params.input_names[i]]
            # Multiply the input by some constant to avoid an all-zeros input for
            # integer types, since random_sample only yields values in [0, 1).
            scale = 10.0 if np.issubdtype(dtype, np.integer) else 1.0
            dims = params.input_dims[i]
            input_data.append(
                (scale * np.random.random_sample(dims)).astype(dtype))
        self._VerifyGraphDef(run_params, input_gdef, GraphState.ORIGINAL)

        # Get reference result without running trt.
        config_no_trt = self._GetConfigProto(run_params, GraphState.ORIGINAL)
        logging.info("Running original graph w/o trt, config:\n%s",
                     str(config_no_trt))
        ref_result = self._RunGraph(run_params, input_gdef, input_data,
                                    config_no_trt, GraphState.ORIGINAL)

        # Run calibration if necessary.
        if IsQuantizationMode(run_params.precision_mode):

            calib_config = self._GetConfigProto(run_params,
                                                GraphState.CALIBRATE)
            logging.info("Running calibration graph, config:\n%s",
                         str(calib_config))
            if run_params.use_optimizer:
                result = self._RunCalibration(run_params, input_gdef,
                                              input_data, calib_config)
            else:
                calib_gdef = self._GetTrtGraphDef(run_params, input_gdef)
                self._VerifyGraphDef(run_params, calib_gdef,
                                     GraphState.CALIBRATE)
                result = self._RunCalibration(run_params, calib_gdef,
                                              input_data, calib_config)
            infer_gdef = trt_convert.calib_graph_to_infer_graph(
                calib_gdef, run_params.dynamic_engine)
            self._VerifyGraphDef(run_params, infer_gdef, GraphState.INFERENCE)

            self.assertAllClose(
                ref_result,
                result,
                atol=self.ExpectedAbsoluteTolerance(run_params),
                rtol=self.ExpectedRelativeTolerance(run_params))
        else:
            infer_gdef = input_gdef

        # Run inference.
        infer_config = self._GetConfigProto(run_params, GraphState.INFERENCE)
        logging.info("Running final inference graph, config:\n%s",
                     str(infer_config))
        if not run_params.use_optimizer:
            infer_gdef = self._GetTrtGraphDef(run_params, infer_gdef)
            self._VerifyGraphDef(run_params, infer_gdef, GraphState.INFERENCE)

        result = self._RunGraph(run_params, infer_gdef, input_data,
                                infer_config, GraphState.INFERENCE)
        self.assertAllClose(ref_result,
                            result,
                            atol=self.ExpectedAbsoluteTolerance(run_params),
                            rtol=self.ExpectedRelativeTolerance(run_params))
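
For context, Examples #1 and #3 assume several module-level helpers from TensorFlow's TF-TRT integration test base. Below is a minimal sketch of their likely shape; the INT8 check, the GraphState values, and the RunParams fields are assumptions for illustration, not the library's exact definitions.

import collections

# Sketch of the module-level helpers the example above relies on. The names
# mirror the TF-TRT test base, but values and fields here are assumptions.
PRECISION_MODES = ["FP32", "FP16", "INT8"]


def IsQuantizationMode(mode):
    # Only INT8 requires the separate calibration pass taken above.
    return mode == "INT8"


class GraphState(object):
    ORIGINAL = 0
    CALIBRATE = 1
    INFERENCE = 2


# Hypothetical shape of run_params; the real namedtuple may carry more fields.
RunParams = collections.namedtuple(
    "RunParams",
    ["use_optimizer", "precision_mode", "dynamic_engine", "test_name"])
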
Code Example #2
  def RunTest(self, params, use_optimizer, precision_mode,
              dynamic_infer_engine, dynamic_calib_engine):
    assert precision_mode in PRECISION_MODES
    input_data = [np.random.random_sample(dims) for dims in params.input_dims]
    input_gdef = params.gdef
    self._VerifyGraphDef(params, input_gdef)

    # Get reference result without running trt.
    config_no_trt = self._GetConfigProto(params, False)
    logging.info("Running original graph w/o trt, config:\n%s",
                 str(config_no_trt))
    ref_result = self._RunGraph(params, input_gdef, input_data, config_no_trt)

    # Run calibration if necessary.
    if _IsQuantizationMode(precision_mode):

      calib_config = self._GetConfigProto(params, use_optimizer, precision_mode,
                                          dynamic_calib_engine)
      logging.info("Running calibration graph, config:\n%s", str(calib_config))
      if use_optimizer:
        self.assertTrue(False)
        # TODO(aaroey): uncomment this and get infer_gdef when this mode is
        # supported.
        # result = self._RunCalibration(params, input_gdef, input_data,
        #                               calib_config)
      else:
        calib_gdef = self._GetTrtGraphDef(params, input_gdef, precision_mode,
                                          dynamic_calib_engine)
        self._VerifyGraphDef(params, calib_gdef, precision_mode, False,
                             dynamic_calib_engine)
        result = self._RunCalibration(params, calib_gdef, input_data,
                                      calib_config)
        infer_gdef = trt_convert.calib_graph_to_infer_graph(calib_gdef)
        self._VerifyGraphDef(params, infer_gdef, precision_mode, True,
                             dynamic_calib_engine)

      self.assertAllClose(
          ref_result,
          result,
          atol=params.allclose_atol,
          rtol=params.allclose_rtol)
    else:
      infer_gdef = input_gdef

    # Run inference.
    infer_config = self._GetConfigProto(params, use_optimizer, precision_mode,
                                        dynamic_infer_engine)
    logging.info("Running final inference graph, config:\n%s",
                 str(infer_config))
    if use_optimizer:
      result = self._RunGraph(params, infer_gdef, input_data, infer_config)
    else:
      trt_infer_gdef = self._GetTrtGraphDef(params, infer_gdef, precision_mode,
                                            dynamic_infer_engine)
      self._VerifyGraphDef(params, trt_infer_gdef, precision_mode, True,
                           dynamic_infer_engine)
      result = self._RunGraph(params, trt_infer_gdef, input_data, infer_config)

    self.assertAllClose(
        ref_result,
        result,
        atol=params.allclose_atol,
        rtol=params.allclose_rtol)
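
Example #2 reads graph data and tolerances straight off a params object and checks quantization with a private helper. A hedged sketch of what those might look like, with field names inferred from the attribute accesses above; the namedtuple layout itself is an assumption.

import collections

# Hypothetical containers and helper behind Example #2; field names follow the
# attribute accesses above (params.gdef, params.input_dims,
# params.allclose_atol, ...), but the exact definitions may differ.
TfTrtIntegrationTestParams = collections.namedtuple(
    "TfTrtIntegrationTestParams",
    ["gdef", "input_names", "input_dims", "allclose_atol", "allclose_rtol"])


def _IsQuantizationMode(mode):
  # Calibration is only required for INT8.
  return mode == "INT8"
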
Code Example #3
  def RunTest(self, run_params):
    if not self.ShouldRunTest(run_params):
      return
    assert run_params.precision_mode in PRECISION_MODES
    np.random.seed(12345)

    params = self._GetParamsCached()
    input_gdef = params.gdef
    input_dtypes = {}
    for node in input_gdef.node:
      if self._ToString(node.name) in params.input_names:
        assert self._ToString(node.op) == "Placeholder"
        input_dtypes[self._ToString(node.name)] = (
            dtypes.as_dtype(node.attr["dtype"].type).as_numpy_dtype())
    assert len(params.input_names) == len(input_dtypes)

    input_data = []
    for i in range(len(params.input_names)):
      dtype = input_dtypes[params.input_names[i]]
      # Multiply the input by some constant to avoid an all-zeros input for
      # integer types, since random_sample only yields values in [0, 1).
      scale = 10.0 if np.issubdtype(dtype, np.integer) else 1.0
      dims = params.input_dims[i]
      input_data.append((scale * np.random.random_sample(dims)).astype(dtype))
    self._VerifyGraphDef(run_params, input_gdef, GraphState.ORIGINAL)

    # Get reference result without running trt.
    config_no_trt = self._GetConfigProto(run_params, GraphState.ORIGINAL)
    logging.info("Running original graph w/o trt, config:\n%s",
                 str(config_no_trt))
    ref_result = self._RunGraph(run_params, input_gdef, input_data,
                                config_no_trt, GraphState.ORIGINAL)

    # Run calibration if necessary.
    if IsQuantizationMode(run_params.precision_mode):

      calib_config = self._GetConfigProto(run_params, GraphState.CALIBRATE)
      logging.info("Running calibration graph, config:\n%s", str(calib_config))
      if run_params.use_optimizer:
        result = self._RunCalibration(run_params, input_gdef, input_data,
                                      calib_config)
      else:
        calib_gdef = self._GetTrtGraphDef(run_params, input_gdef)
        self._VerifyGraphDef(run_params, calib_gdef, GraphState.CALIBRATE)
        result = self._RunCalibration(run_params, calib_gdef, input_data,
                                      calib_config)
      infer_gdef = trt_convert.calib_graph_to_infer_graph(
          calib_gdef, run_params.dynamic_engine)
      self._VerifyGraphDef(run_params, infer_gdef, GraphState.INFERENCE)

      self.assertAllClose(
          ref_result,
          result,
          atol=self.ExpectedAbsoluteTolerance(run_params),
          rtol=self.ExpectedRelativeTolerance(run_params))
    else:
      infer_gdef = input_gdef

    # Run inference.
    infer_config = self._GetConfigProto(run_params, GraphState.INFERENCE)
    logging.info("Running final inference graph, config:\n%s",
                 str(infer_config))
    if not run_params.use_optimizer:
      infer_gdef = self._GetTrtGraphDef(run_params, infer_gdef)
      self._VerifyGraphDef(run_params, infer_gdef, GraphState.INFERENCE)

    result = self._RunGraph(run_params, infer_gdef, input_data, infer_config,
                            GraphState.INFERENCE)
    self.assertAllClose(
        ref_result,
        result,
        atol=self.ExpectedAbsoluteTolerance(run_params),
        rtol=self.ExpectedRelativeTolerance(run_params))
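
The input-generation loop in Examples #1 and #3 scales the random samples before casting, because random_sample draws from [0, 1) and a plain cast to an integer dtype would yield all zeros. A small standalone illustration of that step, using made-up dims and dtypes:

import numpy as np

# Standalone illustration of the input-scaling step above; the dims and dtypes
# are arbitrary example values, not taken from a real test graph.
input_dims = [[2, 3], [4]]
input_dtypes = [np.float32, np.int32]

input_data = []
for dims, dtype in zip(input_dims, input_dtypes):
  # random_sample yields values in [0, 1); casting those to an integer dtype
  # would produce all zeros, hence the extra scale factor for integer types.
  scale = 10.0 if np.issubdtype(dtype, np.integer) else 1.0
  input_data.append((scale * np.random.random_sample(dims)).astype(dtype))

print([arr.dtype for arr in input_data])  # [dtype('float32'), dtype('int32')]
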
Code Example #4
    def RunTest(self, params, run_params):
        assert run_params.precision_mode in PRECISION_MODES
        input_data = [
            np.random.random_sample(dims) for dims in params.input_dims
        ]
        input_gdef = params.gdef
        self._VerifyGraphDef(params, run_params, input_gdef,
                             GraphState.ORIGINAL)

        # Get reference result without running trt.
        config_no_trt = self._GetConfigProto(params, run_params,
                                             GraphState.ORIGINAL)
        logging.info("Running original graph w/o trt, config:\n%s",
                     str(config_no_trt))
        ref_result = self._RunGraph(params, input_gdef, input_data,
                                    config_no_trt, GraphState.ORIGINAL)

        # Run calibration if necessary.
        if _IsQuantizationMode(run_params.precision_mode):

            calib_config = self._GetConfigProto(params, run_params,
                                                GraphState.CALIBRATE)
            logging.info("Running calibration graph, config:\n%s",
                         str(calib_config))
            if run_params.use_optimizer:
                result = self._RunCalibration(params, input_gdef, input_data,
                                              calib_config)
            else:
                calib_gdef = self._GetTrtGraphDef(params, run_params,
                                                  input_gdef)
                self._VerifyGraphDef(params, run_params, calib_gdef,
                                     GraphState.CALIBRATE)
                result = self._RunCalibration(params, calib_gdef, input_data,
                                              calib_config)
            infer_gdef = trt_convert.calib_graph_to_infer_graph(calib_gdef)
            self._VerifyGraphDef(params, run_params, infer_gdef,
                                 GraphState.INFERENCE)

            self.assertAllClose(ref_result,
                                result,
                                atol=params.allclose_atol,
                                rtol=params.allclose_rtol)
        else:
            infer_gdef = input_gdef

        # Run inference.
        infer_config = self._GetConfigProto(params, run_params,
                                            GraphState.INFERENCE)
        logging.info("Running final inference graph, config:\n%s",
                     str(infer_config))
        if run_params.use_optimizer:
            result = self._RunGraph(params, infer_gdef, input_data,
                                    infer_config, GraphState.INFERENCE)
        else:
            trt_infer_gdef = self._GetTrtGraphDef(params, run_params,
                                                  infer_gdef)
            self._VerifyGraphDef(params, run_params, trt_infer_gdef,
                                 GraphState.INFERENCE)
            result = self._RunGraph(params, trt_infer_gdef, input_data,
                                    infer_config, GraphState.INFERENCE)

        self.assertAllClose(ref_result,
                            result,
                            atol=params.allclose_atol,
                            rtol=params.allclose_rtol)
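
Example #4 separates graph parameters (params, which carries the GraphDef and tolerances) from execution parameters (run_params). A hedged sketch of a run-parameter container and the kind of sweep a concrete test might perform; the field list and the sweep are illustrative assumptions, and the real run_params likely carries more fields.

import collections
import itertools

# Hypothetical run-parameter container for Example #4; the field names follow
# the attribute accesses above, everything else is illustrative.
RunParams = collections.namedtuple(
    "RunParams", ["use_optimizer", "precision_mode"])

PRECISION_MODES = ["FP32", "FP16", "INT8"]

# A concrete test class could sweep combinations like these, calling
# self.RunTest(params, run_params) once per combination.
for precision_mode, use_optimizer in itertools.product(
        PRECISION_MODES, [False, True]):
    run_params = RunParams(
        use_optimizer=use_optimizer, precision_mode=precision_mode)
    print(run_params)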