Example #1
0
 def try_permute(arr, shape):
     """
     Best-effort transpose of ``arr`` towards the layout implied by ``shape``.

     The data formats of ``arr.shape`` and ``shape`` are deduced via
     ``FormatManager.determine_format``; the axis permutation between them is
     then applied with ``np.transpose``.

     Args:
         arr: Array to permute; must expose ``.shape`` and be accepted by ``np.transpose``.
         shape: Shape whose deduced format is the permutation target.

     Returns:
         The transposed array on success, otherwise ``arr`` unchanged.
     """
     try:
         source_format = FormatManager.determine_format(arr.shape)
         target_format = FormatManager.determine_format(shape)
         axes = FormatManager.permutation(source_format, target_format)
         G_LOGGER.verbose("Permuting shape: {:} using permutation {:}".format(arr.shape, axes))
         return np.transpose(arr, axes)
     except Exception as err:
         # Deliberately broad: FormatManager may not recognize a format, or may be
         # unable to generate a permutation for this format combination. Any failure
         # simply means the array is passed through untouched.
         G_LOGGER.extra_verbose("Skipping permutation due to {:}".format(err))
     return arr
Example #2
0
    def __call__(self):
        """
        Build a TensorRT network by parsing the UFF model supplied by ``self.uff_loader``.

        Every input reported by the loader is registered with the UFF parser
        (with the leading dimension of its shape stripped), outputs are
        registered unless they are empty or equal to ``constants.MARK_ALL``,
        and the model buffer is then parsed into a freshly created network.

        Returns:
            A tuple ``(builder, network, parser, input_shapes[0][0])`` - the
            last element is the leading dimension of the first input shape
            (presumably the batch size - confirm against callers).
        """
        uff_model, input_names, input_shapes, output_names = self.uff_loader()

        builder = trt.Builder(TRT_LOGGER)
        network = builder.create_network()
        parser = trt.UffParser()

        # Input names are taken from the converter rather than the original model,
        # since a preprocessing script may have been applied to the frozen model.
        for input_name, full_shape in zip(input_names, input_shapes):
            order = self.uff_order
            if not order:
                # No explicit order was configured: fall back to NCHW and only
                # choose NHWC when the shape is reasonably certain to be NHWC.
                default_order = trt.UffInputOrder.NCHW
                is_nhwc = FormatManager.determine_format(full_shape) == DataFormat.NHWC
                order = trt.UffInputOrder.NHWC if is_nhwc else default_order

            # The parser is given the shape without its leading dimension.
            trimmed_shape = full_shape[1:]
            message = "Registering UFF input: {:} with shape: {:} and input order: {:}".format(
                input_name, trimmed_shape, order)
            G_LOGGER.verbose(message)
            parser.register_input(input_name, trimmed_shape, order)

        # Outputs are only registered explicitly when specific names were requested.
        if output_names and output_names != constants.MARK_ALL:
            for output_name in output_names:
                G_LOGGER.verbose("Registering UFF output: " + str(output_name))
                parser.register_output(output_name)

        G_LOGGER.info("Parsing UFF model with inputs: {:} and outputs: {:}".format(input_names, output_names))
        if not parser.parse_buffer(uff_model, network):
            # NOTE(review): whether a critical log aborts execution depends on
            # G_LOGGER, which is not visible here.
            G_LOGGER.critical("Could not parse UFF correctly")

        leading_dim = input_shapes[0][0]
        return builder, network, parser, leading_dim
Example #3
0
def test_format_deduction(test_case):
    """Check that the format deduced from ``test_case.shape`` equals ``test_case.format``."""
    expected_format = test_case.format
    deduced_format = FormatManager.determine_format(test_case.shape)
    assert expected_format == deduced_format