예제 #1
0
 def test_sip_dict_arg_result_success(self):
   # Round trip through a FunctionAbi built from ATTRS_SIP_1LEVEL_DICT:
   # inputs are packed by keyword, results are allocated from the packed
   # inputs, and the unpacked results are looked up by name.
   abi = rt.FunctionAbi(self.device, self.htf, ATTRS_SIP_1LEVEL_DICT)
   expected_abi_repr = (
       "<FunctionAbi (Buffer<sint32[1x384]>, Buffer<sint32[1x384]>, Buffer<sint32[1x384]>) -> (Buffer<float32[1x384]>, Buffer<float32[1x384]>) SIP:'I53!D49!K10!input_ids_1K11!input_mask_2K12!segment_ids_0R39!D35!K11!end_logits_0K13!start_logits_1'>"
   )
   self.assertEqual(expected_abi_repr, repr(abi))

   # Three independent zero-filled int32 arrays, keyed by argument name.
   named_inputs = {
       "input_ids": np.zeros((1, 384), dtype=np.int32),
       "input_mask": np.zeros((1, 384), dtype=np.int32),
       "segment_ids": np.zeros((1, 384), dtype=np.int32),
   }
   packed_args = abi.pack_inputs(**named_inputs)
   expected_args_repr = (
       "<VmVariantList(3): [HalBufferView(1x384:0x1000020), HalBufferView(1x384:0x1000020), HalBufferView(1x384:0x1000020)]>"
   )
   self.assertEqual(expected_args_repr, repr(packed_args))

   allocated = abi.allocate_results(packed_args)
   logging.info("f_results: %s", allocated)
   expected_results_repr = (
       "<VmVariantList(2): [HalBufferView(1x384:0x3000020), HalBufferView(1x384:0x3000020)]>"
   )
   self.assertEqual(expected_results_repr, repr(allocated))

   # Fetch both named results before asserting on either of them.
   unpacked = abi.unpack_results(allocated)
   start_logits = unpacked["start_logits"]
   end_logits = unpacked["end_logits"]
   for logits in (start_logits, end_logits):
     self.assertEqual(np.float32, logits.dtype)
   for logits in (start_logits, end_logits):
     self.assertEqual((1, 384), logits.shape)
예제 #2
0
 def test_dynamic_alloc_result_success(self):
   # Requesting results with static_alloc=False is expected to produce an
   # empty variant list for this 10x128x64 -> 32x8x64 signature.
   abi = rt.FunctionAbi(
       self.device, self.htf,
       ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1)
   zeros = np.zeros((10, 128, 64), dtype=np.float32)
   packed_args = abi.raw_pack_inputs([zeros])
   allocated = abi.allocate_results(packed_args, static_alloc=False)
   print(allocated)
   expected_repr = "<VmVariantList(0): []>"
   self.assertEqual(expected_repr, repr(allocated))
예제 #3
0
 def test_static_arg_static_dim_mismatch(self):
   # A float32 input whose middle dimension is 32 instead of 128 must be
   # rejected by raw_pack_inputs with a ValueError naming both sizes.
   abi = rt.FunctionAbi(
       self.device, self.htf,
       ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1)
   print(abi)
   wrong_dim_input = np.zeros((10, 32, 64), dtype=np.float32)
   # The message contains regex metacharacters, so match it literally.
   expected_message = re.escape(
       "Mismatched buffer dim (received: 32, expected: 128)")
   with self.assertRaisesRegex(ValueError, expected_message):
     abi.raw_pack_inputs([wrong_dim_input])
예제 #4
0
 def test_static_arg_dtype_mismatch(self):
   # An int32 input of the right shape must be rejected by pack_inputs
   # with a ValueError reporting the received and expected format codes.
   abi = rt.FunctionAbi(
       self.device, self.htf,
       ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1)
   logging.info("fabi: %s", abi)
   wrong_dtype_input = np.zeros((10, 128, 64), dtype=np.int32)
   # The message contains regex metacharacters, so match it literally.
   expected_message = re.escape(
       "Mismatched buffer format (received: i, expected: f)")
   with self.assertRaisesRegex(ValueError, expected_message):
     abi.pack_inputs(wrong_dtype_input)
예제 #5
0
 def test_static_arg_eltsize_mismatch(self):
     # A float64 input of the right shape must be rejected by
     # raw_pack_inputs with a ValueError reporting item sizes 8 vs 4.
     abi = rt.FunctionAbi(self.device, self.htf,
                          ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1)
     logging.info("fabi: %s", abi)
     wide_input = np.zeros((10, 128, 64), dtype=np.float64)
     # The message contains regex metacharacters, so match it literally.
     expected_message = re.escape(
         "Mismatched buffer item size (received: 8, expected: 4)")
     with self.assertRaisesRegex(ValueError, expected_message):
         abi.raw_pack_inputs([wide_input])
예제 #6
0
 def test_static_result_success(self):
   # Allocates results for a packed 10x128x64 float32 input and checks
   # that exactly one int32 array of shape (32, 8, 64) is unpacked.
   abi = rt.FunctionAbi(
       self.device, self.htf,
       ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1)
   zeros = np.zeros((10, 128, 64), dtype=np.float32)
   packed_args = abi.raw_pack_inputs([zeros])
   allocated = abi.allocate_results(packed_args)
   print(allocated)
   expected_repr = "<VmVariantList(1): [HalBufferView(32x8x64:0x1000020)]>"
   self.assertEqual(expected_repr, repr(allocated))

   # The unpacking pattern itself requires a single-element result.
   [only_result] = abi.raw_unpack_results(allocated)
   self.assertEqual(np.int32, only_result.dtype)
   self.assertEqual((32, 8, 64), only_result.shape)
예제 #7
0
  def test_static_arg_success(self):
    # Checks the repr and raw arities of the ABI, then packs a matching
    # float32 input and checks the repr of the packed list.
    abi = rt.FunctionAbi(
        self.device, self.htf,
        ATTRS_1ARG_FLOAT32_10X128X64_TO_SINT32_32X8X64_V1)
    print(abi)
    expected_abi_repr = ("<FunctionAbi (Buffer<float32[10x128x64]>) -> "
                         "(Buffer<sint32[32x8x64]>)>")
    self.assertEqual(expected_abi_repr, repr(abi))
    self.assertEqual(1, abi.raw_input_arity)
    self.assertEqual(1, abi.raw_result_arity)

    zeros = np.zeros((10, 128, 64), dtype=np.float32)
    packed_args = abi.raw_pack_inputs([zeros])
    print(packed_args)
    expected_packed_repr = (
        "<VmVariantList(1): [HalBufferView(10x128x64:0x3000020)]>")
    self.assertEqual(expected_packed_repr, repr(packed_args))
예제 #8
0
  def test_dynamic_arg_success(self):
    # The ABI repr shows '?' for the leading dimension; a concrete
    # 10x128x64 float32 input is then packed through pack_inputs.
    abi = rt.FunctionAbi(
        self.device, self.htf,
        ATTRS_1ARG_FLOAT32_DYNX128X64_TO_SINT32_DYNX8X64_V1)
    logging.info("fabi: %s", abi)
    expected_abi_repr = ("<FunctionAbi (Buffer<float32[?x128x64]>) -> "
                         "(Buffer<sint32[?x8x64]>)>")
    self.assertEqual(expected_abi_repr, repr(abi))
    self.assertEqual(1, abi.raw_input_arity)
    self.assertEqual(1, abi.raw_result_arity)

    zeros = np.zeros((10, 128, 64), dtype=np.float32)
    packed_args = abi.pack_inputs(zeros)
    logging.info("packed: %s", packed_args)
    expected_packed_repr = (
        "<VmVariantList(1): [HalBufferView(10x128x64:0x3000020)]>")
    self.assertEqual(expected_packed_repr, repr(packed_args))
예제 #9
0
    def test_dynamic_arg_success(self):
        # The ABI repr shows '?' for the leading dimension; in this
        # variant raw_pack_inputs is expected to raise
        # NotImplementedError for the input.
        abi = rt.FunctionAbi(
            self.device,
            self.htf,
            ATTRS_1ARG_FLOAT32_DYNX128X64_TO_SINT32_DYNX8X64_V1,
        )
        print(abi)
        expected_abi_repr = ("<FunctionAbi (Buffer<float32[?x128x64]>) -> "
                             "(Buffer<sint32[?x8x64]>)>")
        self.assertEqual(expected_abi_repr, repr(abi))
        self.assertEqual(1, abi.raw_input_arity)
        self.assertEqual(1, abi.raw_result_arity)

        zeros = np.zeros((10, 128, 64), dtype=np.float32)
        expected_message = "Dynamic argument dimensions not implemented"
        with self.assertRaisesRegex(NotImplementedError, expected_message):
            # The call is expected to raise, so its result is not bound.
            abi.raw_pack_inputs([zeros])
예제 #10
0
 def test_sip_linear_success(self):
   # Round trip through a FunctionAbi built from ATTRS_SIP_LINEAR_2ARG:
   # two positional inputs are packed, results are allocated, and the
   # unpacked result is inspected directly for dtype and shape.
   abi = rt.FunctionAbi(self.device, self.htf, ATTRS_SIP_LINEAR_2ARG)
   expected_abi_repr = (
       "<FunctionAbi (Buffer<float32[1]>, Buffer<float32[1]>) -> (Buffer<float32[1]>) SIP:'I12!S9!k0_0k1_1R3!_0'>"
   )
   self.assertEqual(expected_abi_repr, repr(abi))

   first = np.zeros((1,), dtype=np.float32)
   second = np.zeros((1,), dtype=np.float32)
   packed_args = abi.pack_inputs(first, second)
   expected_args_repr = (
       "<VmVariantList(2): [HalBufferView(1:0x3000020), HalBufferView(1:0x3000020)]>"
   )
   self.assertEqual(expected_args_repr, repr(packed_args))

   allocated = abi.allocate_results(packed_args)
   logging.info("f_results: %s", allocated)
   expected_results_repr = "<VmVariantList(1): [HalBufferView(1:0x3000020)]>"
   self.assertEqual(expected_results_repr, repr(allocated))

   unpacked = abi.unpack_results(allocated)
   print("SINGLE RESULT:", unpacked)
   self.assertEqual(np.float32, unpacked.dtype)
   self.assertEqual((1,), unpacked.shape)