Beispiel #1
0
    def _get_data_from_ndarray(array):
        """Transfer a NumPy array into a flat ThrustRTC device vector.

        The storage element type is chosen from the array's dtype *kind*:
        signed integers map to ``Storage.INT``, floating point to
        ``Storage.FLOAT`` and booleans to ``Storage.BOOL``. The array is cast
        to that type and flattened before the transfer.

        Args:
            array: NumPy array of any shape.

        Returns:
            Tuple ``(data, shape, dtype)``: the device vector, the original
            ``array.shape``, and the storage dtype that was selected.

        Raises:
            NotImplementedError: if the dtype kind is not signed integer,
                floating point or boolean (e.g. unsigned, complex, object).
        """
        # Dispatch on ``dtype.kind`` instead of prefix-matching ``str(dtype)``:
        # for non-native byte order the string form is e.g. '>i8' / '>f8',
        # which a 'int'/'float' prefix test wrongly rejects.
        kind = array.dtype.kind
        if kind == 'i':
            dtype = Storage.INT
        elif kind == 'f':
            dtype = Storage.FLOAT
        elif kind == 'b':
            dtype = Storage.BOOL
        else:
            raise NotImplementedError(f"Unsupported dtype: {array.dtype}")

        data = trtc.device_vector_from_numpy(array.astype(dtype).ravel())
        return data, array.shape, dtype
Beispiel #2
0
 def __setitem__(self, key, value):
     """Assign ``value`` to the whole storage.

     Values exposing ``data`` and a non-empty ``shape`` are copied into
     ``self.data``. NumPy arrays are first cast to ``self.dtype`` and
     flattened, as ``upload`` does. Integer and floating scalars (Python or
     NumPy) fill every element.

     NOTE(review): ``key`` is ignored -- every assignment overwrites the
     entire buffer, so presumably callers only write ``storage[:] = value``.

     Raises:
         TypeError: if a scalar ``value`` is neither an integer nor a float.
     """
     if hasattr(value, 'data') and hasattr(value, 'shape') and len(value.shape) != 0:
         if isinstance(value, np.ndarray):
             # Match ``upload``: convert to this storage's element type and
             # flatten so the host data agrees with the device vector.
             vector = trtc.device_vector_from_numpy(value.astype(self.dtype).ravel())
             trtc.Copy(vector, self.data)
         else:
             trtc.Copy(value.data, self.data)
     else:
         # NumPy scalars such as np.int64 / np.float32 are not ``int`` /
         # ``float`` subclasses, so accept the NumPy numeric base classes too
         # and hand plain Python scalars to the device-value wrappers.
         if isinstance(value, (int, np.integer)):
             dvalue = trtc.DVInt64(int(value))
         elif isinstance(value, (float, np.floating)):
             dvalue = PrecisionResolver.get_floating_point(float(value))
         else:
             raise TypeError("Only Storage, int and float are supported.")
         trtc.Fill(self.data, dvalue)
     # ``obj[key] = v`` discards this; kept for callers invoking it directly.
     return self
Beispiel #3
0
 def upload(self, data):
     """Overwrite this storage's device buffer with a host array.

     ``data`` is converted to ``self.dtype`` and flattened. It is then
     transferred to a temporary device vector, which is copied into
     ``self.data``.
     """
     host_flat = data.astype(self.dtype).ravel()
     staging = trtc.device_vector_from_numpy(host_flat)
     trtc.Copy(staging, self.data)
Beispiel #4
0
 def ravel(self, other):
     """Fill this storage with the flattened contents of ``other``.

     If ``other`` is a ``Storage``, its device data is copied into the
     existing ``self.data`` buffer. Otherwise ``other`` is treated as a
     NumPy array: it is cast to ``self.dtype`` and flattened, and the
     resulting new device vector replaces ``self.data``.
     """
     if isinstance(other, Storage):
         trtc.Copy(other.data, self.data)
     else:
         # Cast like ``upload`` does, so the replacement device vector has
         # the element type this storage advertises in ``self.dtype``.
         self.data = trtc.device_vector_from_numpy(other.astype(self.dtype).ravel())