Exemplo n.º 1
0
 def subtractTensor(self, tensor, nthreads=0, testing=False):
     """Subtract ``tensor`` from this tensor using a PTI kernel.

     The kernel is chosen from ``nthreads``:

     * ``1`` runs ``PTI.ptiSparseTensorSub`` into a freshly allocated
       ``COOTensor``.
     * any other value runs ``PTI.ptiSparseTensorSubOMP``.
     * ``0`` (the default) first prints a warning, then uses the OpenMP kernel
       with ``self.nthreads`` threads.

     Args:
         tensor: The other operand; must expose ``.address``.
         nthreads: Thread count selecting the kernel (see above).
         testing: When true, nothing is returned. On the serial path the
             freshly allocated result is freed.

     Returns:
         On the serial path, the new ``COOTensor``. On the OpenMP path,
         ``tensor`` itself, which is passed to the PTI call as its first
         (presumably output) operand. ``None`` when ``testing`` is true.
     """
     if nthreads == 0:
         print(
             "Default Cases not learned yet.  Make sure you specify your run type."
         )
     if nthreads == 1:
         result = COOTensor()
         PTI.ptiSparseTensorSub(result.address, self.address, tensor.address)
         if testing:
             # Test runs only exercise the kernel; release the native result.
             PTI.ptiFreeSparseTensor(result.address)
             return None
         return result
     # Every other thread count, including the 0 default, uses OpenMP.
     if nthreads == 0:
         nthreads = self.nthreads
     # NOTE(review): operand order here is (tensor, self), the reverse of the
     # serial call, and `tensor` is what gets returned. This suggests `tensor`
     # is overwritten in place. Confirm the sign of the difference against the
     # PTI signature of ptiSparseTensorSubOMP.
     PTI.ptiSparseTensorSubOMP(tensor.address, self.address, nthreads)
     if testing:
         return None
     return tensor
Exemplo n.º 2
0
 def dotMulTensor(self, tensor, type="default", testing=False):
     """Element-wise multiply this tensor by ``tensor`` using a PTI backend.

     Args:
         tensor: The other operand; must expose ``.address``.
         type: Backend selector: ``"serial"``, ``"serial_EQ"``, ``"CPU"`` or
             ``"GPU"``. ``"default"`` prints a warning and then, like any
             other unrecognised value, terminates with ``SystemExit``.
         testing: When true, the result is freed instead of returned.

     Returns:
         A new ``COOTensor`` holding the product, or ``None`` when
         ``testing`` is true.

     Raises:
         SystemExit: ``type`` does not name a known backend.
     """
     import sys

     # Backend name -> PTI entry point. Names are looked up lazily with
     # getattr, so a build lacking one backend (e.g. CUDA) still serves the
     # others.
     kernels = {
         "serial": "ptiSparseTensorDotMul",
         "serial_EQ": "ptiSparseTensorDotMulEq",
         "CPU": "ptiOmpSparseTensorDotMulEq",
         # Mirrors the OpenMP DotMulEq kernel; a DotDiv kernel here would
         # divide instead of multiply. NOTE(review): confirm the PTI binding
         # exposes this CUDA symbol.
         "GPU": "ptiCudaSparseTensorDotMulEq",
     }
     if type == "default":
         print(
             "Default Cases not learned yet.  Make sure you specify your run type."
         )
     kernel_name = kernels.get(type)
     if kernel_name is None:
         # Validate before allocating so an invalid call leaks no native tensor.
         sys.exit("Invalid Type")
     result = COOTensor()
     getattr(PTI, kernel_name)(result.address, self.address, tensor.address)
     if testing:
         # Test runs only exercise the kernel; release the native result.
         PTI.ptiFreeSparseTensor(result.address)
         return None
     return result
Exemplo n.º 3
0
 def free(self):
     """Release the native sparse tensor referenced by ``self.address``."""
     native_handle = self.address
     PTI.ptiFreeSparseTensor(native_handle)
Exemplo n.º 4
0
 def dotDivTensor(self, tensor, type="default", testing=False):
     """Element-wise divide this tensor by ``tensor`` via ``ptiSparseTensorDotDiv``.

     ``type`` is accepted but not consulted; the same PTI kernel always runs.

     Returns:
         A new ``COOTensor`` holding the quotient, or ``None`` when
         ``testing`` is true (the quotient is freed in that case).
     """
     quotient = COOTensor()
     PTI.ptiSparseTensorDotDiv(quotient.address, self.address, tensor.address)
     if testing:
         PTI.ptiFreeSparseTensor(quotient.address)
         return None
     return quotient