Example #1
0
    def __call__(self):
        """
        Applies the configured output marking/unmarking to a TensorRT network.

        The ``(builder, network, parser)`` triple is obtained from ``self._network``
        via ``misc.try_call`` and ``misc.unpack_args``. If ``self.outputs`` equals
        ``constants.MARK_ALL``, every layer is marked as an output. Otherwise, a
        non-None ``self.outputs`` is marked explicitly. Any non-None
        ``self.exclude_outputs`` are then unmarked.

        Returns:
            tuple: ``(builder, network, parser)`` when a parser is present,
            otherwise ``(builder, network)``. ``network`` is the modified
            ``trt.INetworkDefinition``.
        """
        result, owns_network = misc.try_call(self._network)
        builder, network, parser = misc.unpack_args(result, num=3)

        with contextlib.ExitStack() as cleanup:
            # When try_call reports that we own the objects, guard them so they
            # are handed to FreeOnException if any of the marking steps raises.
            if owns_network:
                cleanup.enter_context(misc.FreeOnException([builder, network, parser]))

            requested = self.outputs
            if requested == constants.MARK_ALL:
                trt_util.mark_layerwise(network)
            elif requested is not None:
                trt_util.mark_outputs(network, requested)

            excluded = self.exclude_outputs
            if excluded is not None:
                trt_util.unmark_outputs(network, excluded)

            # Only include the parser in the result when one was supplied.
            return (builder, network) if parser is None else (builder, network, parser)
Example #2
0
    def __call__(self):
        """
        Builds a TensorRT network by parsing a Caffe model.

        Uses ``self.deploy``, ``self.model`` and ``self.dtype`` to populate a fresh
        network with ``trt.CaffeParser``. Outputs are marked here only when
        ``self.outputs`` is truthy and is not ``constants.MARK_ALL``; this method
        does not handle the MARK_ALL case itself.

        Returns:
            tuple: ``(builder, network, parser, self.batch_size)`` where ``builder``
            is the ``trt.Builder``, ``network`` the populated network, and
            ``parser`` the ``trt.CaffeParser`` used to fill it.
        """
        trt_builder = trt.Builder(get_trt_logger())
        trt_network = trt_builder.create_network()
        caffe_parser = trt.CaffeParser()

        parse_kwargs = dict(
            deploy=self.deploy,
            model=self.model,
            network=trt_network,
            dtype=self.dtype,
        )
        caffe_parser.parse(**parse_kwargs)

        # An explicit list of outputs is required; MARK_ALL is left unhandled here.
        explicit_outputs = self.outputs
        if explicit_outputs and explicit_outputs != constants.MARK_ALL:
            trt_util.mark_outputs(trt_network, explicit_outputs)

        return trt_builder, trt_network, caffe_parser, self.batch_size
Example #3
0
    def __call__(self):
        """
        Modifies a TensorRT ``INetworkDefinition`` by marking/unmarking outputs.

        The ``(builder, network, parser)`` triple is obtained from ``self._network``
        via ``misc.try_call`` and ``misc.unpack_args``. If ``self.outputs`` equals
        ``constants.MARK_ALL``, every layer is marked as an output. Otherwise, a
        non-None ``self.outputs`` is marked explicitly. Any non-None
        ``self.exclude_outputs`` are then unmarked.

        Returns:
            tuple: ``(builder, network, parser)`` when a parser is present,
            otherwise ``(builder, network)``. ``network`` is the modified
            ``trt.INetworkDefinition``.
        """
        # Local import so this block does not rely on a file-level import.
        import contextlib

        ret, owns_network = misc.try_call(self._network)
        builder, network, parser = misc.unpack_args(ret, num=3)

        with contextlib.ExitStack() as stack:
            # Previously `owns_network` was discarded, so objects this call owned
            # were not passed to FreeOnException if a marking step raised.
            # NOTE(review): assumes this version of `misc` provides FreeOnException,
            # as the ExitStack-based variant of this loader does.
            if owns_network:
                stack.enter_context(misc.FreeOnException([builder, network, parser]))

            if self.outputs == constants.MARK_ALL:
                trt_util.mark_layerwise(network)
            elif self.outputs is not None:
                trt_util.mark_outputs(network, self.outputs)

            if self.exclude_outputs is not None:
                trt_util.unmark_outputs(network, self.exclude_outputs)

            if parser is not None:
                return builder, network, parser
            return builder, network
Example #4
0
    def call_impl(self):
        """
        Applies the configured output marking/unmarking to a TensorRT network.

        The ``(builder, network, parser)`` triple is obtained from ``self._network``
        via ``util.invoke_if_callable`` and ``util.unpack_args``. If ``self.outputs``
        equals ``constants.MARK_ALL``, every layer is marked as an output.
        Otherwise, a non-None ``self.outputs`` is marked explicitly. Any non-None
        ``self.exclude_outputs`` are then unmarked.

        Returns:
            tuple: ``(builder, network, parser)`` when a parser is present,
            otherwise ``(builder, network)``. ``network`` is the modified
            ``trt.INetworkDefinition``.
        """
        produced, owns_network = util.invoke_if_callable(self._network)
        builder, network, parser = util.unpack_args(produced, num=3)

        with contextlib.ExitStack() as guard:
            # If we own these objects, hand them to FreeOnException so they are
            # dealt with should any step below raise.
            if owns_network:
                guard.enter_context(util.FreeOnException([builder, network, parser]))

            wanted = self.outputs
            if wanted == constants.MARK_ALL:
                trt_util.mark_layerwise(network)
            elif wanted is not None:
                trt_util.mark_outputs(network, wanted)

            unwanted = self.exclude_outputs
            if unwanted is not None:
                trt_util.unmark_outputs(network, unwanted)

            if parser is not None:
                return builder, network, parser
            return builder, network