def test_invalid_parameters(in_dtype, out_dtype, zero_point, scale, out_zero_point, out_scale):
    """Partition a model built from the given parameters and expect no external function.

    The arguments are forwarded unchanged to ``make_model``; they are presumably
    supplied by a parametrize decorator outside this view -- TODO confirm.
    """
    # The leading list is presumably the input shape -- confirm against make_model.
    relay_model = make_model(
        [1, 16, 16, 3],
        in_dtype,
        out_dtype,
        zero_point,
        scale,
        out_zero_point,
        out_scale,
    )
    partitioned_mod = cmsisnn.partition_for_cmsisnn(make_module(relay_model))
    assert_no_external_function(partitioned_mod)
def test_invalid_batch_size(op):
    """Partition a pooling model with shape (2, 28, 28, 12) and expect no external function.

    ``op`` is passed to ``make_model`` as ``pool_op``.
    """
    # The leading 2 is presumably the batch dimension this test targets -- TODO confirm layout.
    relay_model = make_model(pool_op=op, shape=(2, 28, 28, 12))
    partitioned_mod = cmsisnn.partition_for_cmsisnn(make_module(relay_model))
    assert_no_external_function(partitioned_mod)
def test_invalid_parameters(
    in_dtype,
    kernel_dtype,
    kernel_zero_point,
):
    """Partition a quantized model built from the given dtypes and expect no external function.

    ``in_dtype``, ``kernel_dtype`` and ``kernel_zero_point`` are forwarded to
    ``get_conv2d_qnn_params`` and ``make_model``; they are presumably supplied
    by a parametrize decorator outside this view -- TODO confirm.
    """
    ifm_shape = (1, 28, 28, 12)
    out_channels = 2
    input_scale = 1
    input_zero_point = 24
    kernel_scale = [0.11, 0.0237]

    kernel_layout = "HWIO"
    # Kernel is 3x3 over the last ifm dimension, producing `out_channels` outputs.
    kernel_shape = [3, 3, ifm_shape[3], out_channels]

    # Output quantization parameters are derived from the input/kernel ones;
    # the output dtype passed here is the same as the input dtype.
    output_scale, output_zero_point = get_conv2d_qnn_params(
        kernel_shape,
        input_scale,
        input_zero_point,
        kernel_scale,
        kernel_zero_point,
        in_dtype,
        kernel_dtype,
        in_dtype,
        False,
    )

    model, params = make_model(
        shape=ifm_shape,
        kernel_shape=kernel_shape,
        input_zero_point=input_zero_point,
        input_scale=input_scale,
        kernel_zero_point=kernel_zero_point,
        kernel_scale=kernel_scale,
        output_zero_point=output_zero_point,
        output_scale=output_scale,
        padding="SAME",
        strides=(1, 1),
        dilation=(1, 1),
        groups=1,
        dtype=in_dtype,
        kernel_dtype=kernel_dtype,
        out_channels=out_channels,
        weight_format=kernel_layout,
        enable_bias=True,
        relu_type="NONE",
    )
    orig_mod = make_module(model)
    cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod, params)
    assert_no_external_function(cmsisnn_mod)
def test_invalid_parameters():
    """Partition a uint8 avg_pool2d model and expect no external function."""
    # Note: zero_point=-33 lies outside the uint8 value range [0, 255];
    # which of these settings makes the configuration invalid is not visible
    # here -- TODO confirm against the partitioner's checks.
    relay_model = make_model(
        pool_op=relay.nn.avg_pool2d,
        shape=(1, 28, 28, 12),
        pool_size=(1, 1),
        strides=(1, 1),
        padding="VALID",
        dtype="uint8",
        scale=1,
        zero_point=-33,
        relu_type="RELU",
    )
    partitioned_mod = cmsisnn.partition_for_cmsisnn(make_module(relay_model))
    assert_no_external_function(partitioned_mod)
def test_both_scalar_inputs_int8(
    op,
):
    """Partition a binary-op model whose two inputs are scalar constants; expect no external function.

    ``op`` is forwarded as the first argument of ``make_model``. Both inputs
    come from ``generate_scalar_constant()`` and share the same scale and
    zero point.
    """
    input_scale = 0.256
    input_zero_point = 33
    model = make_model(
        op,
        generate_scalar_constant(),
        generate_scalar_constant(),
        input_scale,
        input_zero_point,
        input_scale,
        input_zero_point,
    )

    orig_mod = make_module(model)
    cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)
    assert_no_external_function(cmsisnn_mod)
def test_invalid_parameters(
    op,
    input_dtype,
):
    """Partition a two-variable model of ``input_dtype`` and expect no external function.

    Both inputs are created with ``generate_variable`` and share the same
    scale and zero point.
    """
    scale = 0.256
    zero_point = 33
    lhs = generate_variable("input_0", input_dtype)
    rhs = generate_variable("input_1", input_dtype)
    relay_model = make_model(op, lhs, rhs, scale, zero_point, scale, zero_point)

    partitioned_mod = cmsisnn.partition_for_cmsisnn(make_module(relay_model))
    assert_no_external_function(partitioned_mod)
def test_invalid_parameters(
    in_dtype,
    kernel_dtype,
    kernel_zero_point,
):
    """Partition a quantized model with a 2-D input and expect no external function.

    ``in_dtype``, ``kernel_dtype`` and ``kernel_zero_point`` are forwarded to
    ``get_conv2d_qnn_params`` and ``make_model``; they are presumably supplied
    by a parametrize decorator outside this view -- TODO confirm.
    """
    in_shape = (2, 28)
    out_channels = 2
    input_scale = 1
    input_zero_point = 24
    kernel_scale = [0.11, 0.0237]

    kernel_shape = [out_channels, in_shape[1]]
    # The 2-D kernel shape is expanded with two leading unit dimensions so the
    # conv2d helper can be reused to derive the output quantization parameters.
    conv2d_kernel_shape = [1, 1, kernel_shape[0], kernel_shape[1]]
    output_scale, output_zero_point = get_conv2d_qnn_params(
        conv2d_kernel_shape,
        input_scale,
        input_zero_point,
        kernel_scale,
        kernel_zero_point,
        in_dtype,
        kernel_dtype,
        in_dtype,
    )

    model, params = make_model(
        in_shape=in_shape,
        kernel_shape=kernel_shape,
        input_zero_point=input_zero_point,
        kernel_zero_point=kernel_zero_point,
        input_scale=input_scale,
        kernel_scale=kernel_scale,
        output_zero_point=output_zero_point,
        output_scale=output_scale,
        dtype=in_dtype,
        kernel_dtype=kernel_dtype,
        out_channels=out_channels,
        enable_bias=True,
    )
    orig_mod = make_module(model)
    cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod, params)

    # validate pattern matching
    assert_no_external_function(cmsisnn_mod)
def test_invalid_layout(op):
    """Partition a pooling model built with layout "NCHW" and expect no external function.

    ``op`` is passed to ``make_model`` as ``pool_op``.
    """
    relay_model = make_model(pool_op=op, layout="NCHW")
    partitioned_mod = cmsisnn.partition_for_cmsisnn(make_module(relay_model))
    assert_no_external_function(partitioned_mod)
def test_invalid_datatype(op):
    """Partition a pooling model built with dtype "int64" and expect no external function.

    ``op`` is passed to ``make_model`` as ``pool_op``.
    """
    relay_model = make_model(pool_op=op, dtype="int64")
    partitioned_mod = cmsisnn.partition_for_cmsisnn(make_module(relay_model))
    assert_no_external_function(partitioned_mod)