def try_permute(arr, shape):
    """Best-effort transpose of ``arr`` towards the layout implied by ``shape``.

    The formats of ``arr.shape`` and ``shape`` are deduced with
    ``FormatManager.determine_format`` and the permutation between them is
    obtained from ``FormatManager.permutation``. If any step raises, the
    failure is logged at extra-verbose level and ``arr`` is returned as-is.

    Args:
        arr: Array-like object exposing ``.shape`` that ``np.transpose`` accepts.
        shape: The shape whose format ``arr`` should be permuted to match.

    Returns:
        The transposed array on success, otherwise the original ``arr``.
    """
    try:
        source_format = FormatManager.determine_format(arr.shape)
        target_format = FormatManager.determine_format(shape)
        axes = FormatManager.permutation(source_format, target_format)
        G_LOGGER.verbose("Permuting shape: {:} using permutation {:}".format(arr.shape, axes))
        return np.transpose(arr, axes)
    except Exception as err:
        # FormatManager may not recognize the format, or may be unable to
        # generate the permutation for this combination of formats.
        G_LOGGER.extra_verbose("Skipping permutation due to {:}".format(err))
    return arr
def __call__(self):
    """Parse the UFF model supplied by ``self.uff_loader`` into a TensorRT network.

    ``self.uff_loader()`` is expected to return a 4-tuple of
    ``(uff_model, input_names, input_shapes, output_names)``. Every input is
    registered with the parser without its leading dimension, and outputs are
    registered only when ``output_names`` is non-empty and differs from
    ``constants.MARK_ALL``.

    Returns:
        A tuple ``(builder, network, parser, input_shapes[0][0])``. The last
        element is the leading dimension of the first input shape
        (presumably the batch size - confirm against callers).
    """
    uff_model, input_names, input_shapes, output_names = self.uff_loader()

    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network()
    parser = trt.UffParser()

    # Input names should come from the converter, as a preprocessing script
    # may have been applied to the frozen model.
    for input_name, full_shape in zip(input_names, input_shapes):
        order = self.uff_order
        if not order:
            # No explicit order was configured: fall back to NCHW and only
            # pick NHWC when the deduced format says so.
            looks_nhwc = FormatManager.determine_format(full_shape) == DataFormat.NHWC
            order = trt.UffInputOrder.NHWC if looks_nhwc else trt.UffInputOrder.NCHW

        # The leading dimension is dropped before registering the input.
        trimmed_shape = full_shape[1:]
        message = "Registering UFF input: {:} with shape: {:} and input order: {:}".format(
            input_name, trimmed_shape, order
        )
        G_LOGGER.verbose(message)
        parser.register_input(input_name, trimmed_shape, order)

    # Nothing is registered when no outputs were given or when the
    # MARK_ALL sentinel was passed.
    has_explicit_outputs = output_names and output_names != constants.MARK_ALL
    if has_explicit_outputs:
        for output_name in output_names:
            G_LOGGER.verbose("Registering UFF output: " + str(output_name))
            parser.register_output(output_name)

    G_LOGGER.info("Parsing UFF model with inputs: {:} and outputs: {:}".format(input_names, output_names))
    if not parser.parse_buffer(uff_model, network):
        # NOTE(review): whether execution continues past this point depends
        # on what G_LOGGER.critical does - confirm.
        G_LOGGER.critical("Could not parse UFF correctly")

    leading_dim = input_shapes[0][0]
    return builder, network, parser, leading_dim
def test_format_deduction(test_case):
    """Check that FormatManager deduces the expected format for a shape.

    ``test_case`` must expose ``.shape`` (the input to
    ``FormatManager.determine_format``) and ``.format`` (the expected result).
    """
    expected_format = test_case.format
    deduced_format = FormatManager.determine_format(test_case.shape)
    assert expected_format == deduced_format