def test_sip_dict_arg_result_success(self):
    """Round-trips named (dict-structured) arguments and results via the SIP.

    Builds three zero-filled int32 (1, 384) inputs, packs them by keyword,
    allocates results, and checks that the unpacked results hold float32
    (1, 384) arrays under the keys "start_logits" and "end_logits".
    """
    abi = rt.FunctionAbi(self.device, self.htf, ATTRS_SIP_1LEVEL_DICT)
    self.assertEqual(
        "<FunctionAbi (Buffer<sint32[1x384]>, Buffer<sint32[1x384]>, "
        "Buffer<sint32[1x384]>) -> "
        "(Buffer<float32[1x384]>, Buffer<float32[1x384]>) "
        "SIP:'I53!D49!K10!input_ids_1K11!input_mask_2K12!segment_ids_0"
        "R39!D35!K11!end_logits_0K13!start_logits_1'>",
        repr(abi))

    # Keyword inputs, created and passed in input_ids/input_mask/segment_ids
    # order.
    named_inputs = {
        key: np.zeros((1, 384), dtype=np.int32)
        for key in ("input_ids", "input_mask", "segment_ids")
    }
    packed = abi.pack_inputs(**named_inputs)
    self.assertEqual(
        "<VmVariantList(3): [HalBufferView(1x384:0x1000020), "
        "HalBufferView(1x384:0x1000020), HalBufferView(1x384:0x1000020)]>",
        repr(packed))

    allocated = abi.allocate_results(packed)
    logging.info("f_results: %s", allocated)
    self.assertEqual(
        "<VmVariantList(2): [HalBufferView(1x384:0x3000020), "
        "HalBufferView(1x384:0x3000020)]>",
        repr(allocated))

    outputs = abi.unpack_results(allocated)
    start_logits, end_logits = outputs["start_logits"], outputs["end_logits"]
    # dtype checks first, then shape checks, for both results.
    for logits in (start_logits, end_logits):
        self.assertEqual(np.float32, logits.dtype)
    for logits in (start_logits, end_logits):
        self.assertEqual((1, 384), logits.shape)
def test_dynamic_alloc_result_success(self):
    """Checks that non-static result allocation yields an empty result list.

    With ``static_alloc=False``, ``allocate_results`` is expected to return an
    empty variant list instead of allocating result buffers up front.
    """
    fabi = rt.FunctionAbi(self.device, self.htf,
                          ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1)
    arg = np.zeros((10, 128, 64), dtype=np.float32)
    f_args = fabi.raw_pack_inputs([arg])
    f_results = fabi.allocate_results(f_args, static_alloc=False)
    # Log rather than print, matching the other tests in this class.
    logging.info("f_results: %s", f_results)
    self.assertEqual("<VmVariantList(0): []>", repr(f_results))
def test_static_arg_static_dim_mismatch(self):
    """Rejects an argument whose static dimension disagrees with the ABI.

    Passing a (10, 32, 64) float32 array must raise ValueError reporting the
    mismatched dimension (received 32, expected 128).
    """
    fabi = rt.FunctionAbi(self.device, self.htf,
                          ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1)
    # Log rather than print, matching the other tests in this class.
    logging.info("fabi: %s", fabi)
    arg = np.zeros((10, 32, 64), dtype=np.float32)
    with self.assertRaisesRegex(
            ValueError,
            re.escape("Mismatched buffer dim (received: 32, expected: 128)")):
        fabi.raw_pack_inputs([arg])
def test_static_arg_dtype_mismatch(self):
    """Rejects an int32 argument where the ABI expects a float format.

    ``pack_inputs`` must raise ValueError naming the received ('i') and
    expected ('f') buffer format codes.
    """
    abi = rt.FunctionAbi(self.device, self.htf,
                         ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1)
    logging.info("fabi: %s", abi)
    wrong_dtype_arg = np.zeros((10, 128, 64), dtype=np.int32)
    expected_message = re.escape(
        "Mismatched buffer format (received: i, expected: f)")
    with self.assertRaisesRegex(ValueError, expected_message):
        abi.pack_inputs(wrong_dtype_arg)
def test_static_arg_eltsize_mismatch(self):
    """Rejects a float64 argument whose element size differs from the ABI's.

    ``raw_pack_inputs`` must raise ValueError reporting the received (8) and
    expected (4) item sizes.
    """
    abi = rt.FunctionAbi(self.device, self.htf,
                         ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1)
    logging.info("fabi: %s", abi)
    # float64 elements are 8 bytes wide; the ABI expects 4-byte items.
    wide_arg = np.zeros((10, 128, 64), dtype=np.float64)
    pattern = re.escape(
        "Mismatched buffer item size (received: 8, expected: 4)")
    with self.assertRaisesRegex(ValueError, pattern):
        abi.raw_pack_inputs([wide_arg])
def test_static_result_success(self):
    """Statically allocates the sint32 32x8x64 result and unpacks it.

    Default (static) allocation is expected to produce one 32x8x64 buffer
    view, which unpacks to an int32 numpy array of shape (32, 8, 64).
    """
    fabi = rt.FunctionAbi(self.device, self.htf,
                          ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1)
    arg = np.zeros((10, 128, 64), dtype=np.float32)
    f_args = fabi.raw_pack_inputs([arg])
    f_results = fabi.allocate_results(f_args)
    # Log rather than print, matching the other tests in this class.
    logging.info("f_results: %s", f_results)
    self.assertEqual("<VmVariantList(1): [HalBufferView(32x8x64:0x1000020)]>",
                     repr(f_results))
    # Parenthesized single-element unpack: fails if there is not exactly one
    # unpacked result.
    (py_result,) = fabi.raw_unpack_results(f_results)
    self.assertEqual(np.int32, py_result.dtype)
    self.assertEqual((32, 8, 64), py_result.shape)
def test_static_arg_success(self):
    """Checks the static ABI's repr and arities, then packs a matching arg.

    A (10, 128, 64) float32 array must pack into a single 10x128x64 buffer
    view.
    """
    fabi = rt.FunctionAbi(self.device, self.htf,
                          ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1)
    # Log rather than print, matching the other tests in this class.
    logging.info("fabi: %s", fabi)
    self.assertEqual(
        "<FunctionAbi (Buffer<float32[10x128x64]>) -> "
        "(Buffer<sint32[32x8x64]>)>", repr(fabi))
    self.assertEqual(1, fabi.raw_input_arity)
    self.assertEqual(1, fabi.raw_result_arity)

    arg = np.zeros((10, 128, 64), dtype=np.float32)
    packed = fabi.raw_pack_inputs([arg])
    logging.info("packed: %s", packed)
    self.assertEqual("<VmVariantList(1): [HalBufferView(10x128x64:0x3000020)]>",
                     repr(packed))
def test_dynamic_arg_success(self):
    """Packs a concrete (10, 128, 64) array into an ABI whose dim 0 is dynamic.

    ``pack_inputs`` is expected to succeed and produce one 10x128x64 buffer
    view.
    """
    abi = rt.FunctionAbi(self.device, self.htf,
                         ATTRS_1ARG_FLOAT32_DYNX128X64_TO_SINT32_DYNX8X64_V1)
    logging.info("fabi: %s", abi)
    expected_repr = ("<FunctionAbi (Buffer<float32[?x128x64]>) -> "
                     "(Buffer<sint32[?x8x64]>)>")
    self.assertEqual(expected_repr, repr(abi))
    self.assertEqual(1, abi.raw_input_arity)
    self.assertEqual(1, abi.raw_result_arity)

    concrete_arg = np.zeros((10, 128, 64), dtype=np.float32)
    packed_list = abi.pack_inputs(concrete_arg)
    logging.info("packed: %s", packed_list)
    self.assertEqual(
        "<VmVariantList(1): [HalBufferView(10x128x64:0x3000020)]>",
        repr(packed_list))
def test_dynamic_arg_raw_pack_not_implemented(self):
    """Expects raw_pack_inputs to reject an argument for a dynamic-dim ABI.

    Packing a (10, 128, 64) float32 array through ``raw_pack_inputs`` must
    raise NotImplementedError.

    NOTE(review): this method was previously also named
    ``test_dynamic_arg_success``. That duplicated an earlier method of the
    same name and silently shadowed it. The earlier test expects
    ``pack_inputs`` to pack the same argument successfully. If
    ``raw_pack_inputs`` now supports dynamic dimensions as well, this test
    is stale and should be removed — confirm.
    """
    fabi = rt.FunctionAbi(self.device, self.htf,
                          ATTRS_1ARG_FLOAT32_DYNX128X64_TO_SINT32_DYNX8X64_V1)
    # Log rather than print, matching the other tests in this class.
    logging.info("fabi: %s", fabi)
    self.assertEqual(
        "<FunctionAbi (Buffer<float32[?x128x64]>) -> "
        "(Buffer<sint32[?x8x64]>)>", repr(fabi))
    self.assertEqual(1, fabi.raw_input_arity)
    self.assertEqual(1, fabi.raw_result_arity)

    arg = np.zeros((10, 128, 64), dtype=np.float32)
    with self.assertRaisesRegex(NotImplementedError,
                                "Dynamic argument dimensions not implemented"):
        fabi.raw_pack_inputs([arg])
def test_sip_linear_success(self):
    """Packs two positional float32[1] args via a linear SIP; unpacks one result.

    With this single-result SIP, ``unpack_results`` is expected to return the
    result array directly (its dtype and shape are read off it below).
    """
    fabi = rt.FunctionAbi(self.device, self.htf, ATTRS_SIP_LINEAR_2ARG)
    self.assertEqual(
        "<FunctionAbi (Buffer<float32[1]>, Buffer<float32[1]>) -> "
        "(Buffer<float32[1]>) SIP:'I12!S9!k0_0k1_1R3!_0'>",
        repr(fabi))
    arg0 = np.zeros((1,), dtype=np.float32)
    arg1 = np.zeros((1,), dtype=np.float32)
    f_args = fabi.pack_inputs(arg0, arg1)
    self.assertEqual(
        "<VmVariantList(2): [HalBufferView(1:0x3000020), "
        "HalBufferView(1:0x3000020)]>",
        repr(f_args))
    f_results = fabi.allocate_results(f_args)
    logging.info("f_results: %s", f_results)
    self.assertEqual("<VmVariantList(1): [HalBufferView(1:0x3000020)]>",
                     repr(f_results))
    result = fabi.unpack_results(f_results)
    # Log rather than print, matching the other tests in this class.
    logging.info("SINGLE RESULT: %s", result)
    self.assertEqual(np.float32, result.dtype)
    self.assertEqual((1,), result.shape)