示例#1
0
def build_default_profile(builder, network, default_shape_value=None):
    """
    Build an optimization profile that pins every network input to a single shape.

    Dynamic dimensions of ordinary input tensors are replaced with ``default_shape_value``;
    every value of an input shape-tensor is set to ``default_shape_value``. The same
    shape/values are used for min, opt and max, so the resulting profile is static.

    Args:
        builder: Object providing ``create_optimization_profile()``.
        network: Network whose inputs (``num_inputs`` / ``get_input(idx)``) are covered.
        default_shape_value (int):
                Value substituted for dynamic dimensions and shape-tensor values.
                Falls back to ``DEFAULT_SHAPE_VALUE`` when None.

    Returns:
        The result of ``check_profile`` on the populated optimization profile.
    """
    default_shape_value = misc.default_value(default_shape_value, DEFAULT_SHAPE_VALUE)

    def override_shape(shape):
        # Keep static dimensions; substitute the default for dynamic ones.
        return tuple(default_shape_value if misc.is_dimension_dynamic(dim) else dim for dim in shape)

    trt_profile = builder.create_optimization_profile()
    for idx in range(network.num_inputs):
        inp = network.get_input(idx)

        with G_LOGGER.verbosity(G_LOGGER.CRITICAL): # WAR for spam from TRT
            is_shape_tensor = inp.is_shape_tensor

        if is_shape_tensor:
            # inp.shape[0] is taken as the number of values the shape-tensor holds
            # (assumes a 1-D shape-tensor — TODO confirm handling of 0-D shape-tensors).
            rank = inp.shape[0]
            shape = (default_shape_value, ) * rank
            G_LOGGER.warning("Input shape-tensor: {:24} | Will use input values: {:} in profile.\n"
                             "If this is incorrect, please provide a profile "
                             "that sets the values for this input shape-tensor.".format(inp.name, shape), mode=LogMode.ONCE)
            trt_profile.set_shape_input(inp.name, shape, shape, shape)
        else:
            shape = override_shape(inp.shape)
            # inp.shape is iterated as a sequence but may not be a tuple; normalize it so the
            # comparison only reports a difference when a dimension was actually overridden.
            if shape != tuple(inp.shape):
                G_LOGGER.warning("Input tensor: {:24} | Will use shape: {:} in profile (tensor shape is: {:}).\n"
                                 "If this is incorrect, please provide a profile "
                                 "that sets the shape for this input tensor.".format(inp.name, shape, inp.shape), mode=LogMode.ONCE)
            trt_profile.set_shape(inp.name, shape, shape, shape)
    return check_profile(trt_profile)
示例#2
0
def build_profile(builder, network, profile):
    """
    Build an optimization profile from user-provided min/opt/max settings.

    Args:
        builder: Object providing ``create_optimization_profile()``.
        network: Network whose inputs (``num_inputs`` / ``get_input(idx)``) are covered.
        profile: Mapping from input name to an object exposing ``min``, ``opt`` and ``max``.
                For shape-tensor inputs these are passed to ``set_shape_input`` (values);
                for other inputs they are passed to ``set_shape`` (shapes).

    Returns:
        The result of ``check_profile`` on the populated optimization profile.

    Raises:
        KeyError: If an input tensor with a dynamic shape has no entry in ``profile``.
    """
    trt_profile = builder.create_optimization_profile()
    # Names in the profile that match no network input; reported at the end.
    unused_keys = set(profile.keys())
    for idx in range(network.num_inputs):
        inp = network.get_input(idx)
        unused_keys.discard(inp.name)

        with G_LOGGER.verbosity(G_LOGGER.CRITICAL): # WAR for spam from TRT
            is_shape_tensor = inp.is_shape_tensor

        if is_shape_tensor:
            if inp.name in profile:
                shapes = profile[inp.name]
                trt_profile.set_shape_input(inp.name, shapes.min, shapes.opt, shapes.max)
                G_LOGGER.extra_verbose("Input shape-tensor: {:24} | Setting values to min: {:}, opt: {:}, max: {:}".format(inp.name, shapes.min, shapes.opt, shapes.max))
            else:
                G_LOGGER.warning("input shape-tensor: {:24} | No values provided. Assuming this is not a dynamic shape-tensor.".format(inp.name), mode=LogMode.ONCE)
        elif misc.is_shape_dynamic(inp.shape):
            # A dynamic input cannot be built without a shape range; fail with a clear
            # message instead of a bare KeyError from the dictionary lookup.
            if inp.name not in profile:
                raise KeyError("Input tensor: {:} has a dynamic shape: {:}, but no shapes were provided for it "
                               "in the profile. Please set min/opt/max shapes for this input.".format(inp.name, inp.shape))
            shapes = profile[inp.name]
            trt_profile.set_shape(inp.name, shapes.min, shapes.opt, shapes.max)
            G_LOGGER.extra_verbose("Input tensor: {:24} | Setting shape to min: {:}, opt: {:}, max: {:}".format(inp.name, shapes.min, shapes.opt, shapes.max))

    if unused_keys:
        G_LOGGER.warning("Some inputs provided in the profile were unused: {:}".format(list(unused_keys)))

    return check_profile(trt_profile)
示例#3
0
    def test_non_matching_outputs(self):
        # Two iteration results whose only output, "output", differs in every
        # element (all zeros vs. all ones) must not be reported as matching.
        zeros_result = IterationResult(outputs={"output": np.zeros((2, 2, 2, 2), dtype=np.float32)})
        ones_result = IterationResult(outputs={"output": np.ones((2, 2, 2, 2), dtype=np.float32)})

        compare = CompareFunc.basic_compare_func()

        # The comparison runs with the logger at ULTRA_VERBOSE (presumably so the
        # most verbose logging paths are exercised as well).
        with G_LOGGER.verbosity(G_LOGGER.ULTRA_VERBOSE):
            accuracy = compare(zeros_result, ones_result)

        assert not accuracy["output"]