def __setitem__(self, key, value):
    """Overwrite the contents of ``self.data`` with ``value``.

    ``key`` is accepted but never consulted: the destination is always
    ``self.data`` as a whole.

    Args:
        key: Ignored.
        value: Either an object exposing ``data`` and a non-empty
            ``shape`` (a NumPy array is uploaded to the device first, any
            other such object has its ``data`` copied directly), or an
            ``int`` / ``float`` used to fill the buffer.

    Returns:
        This instance.

    Raises:
        TypeError: If ``value`` is not array-like and is neither ``int``
            nor ``float``.
    """
    # Objects with an empty shape (0-d) deliberately fall through to the
    # scalar handling below.
    array_like = (
        hasattr(value, 'data')
        and hasattr(value, 'shape')
        and len(value.shape) != 0
    )

    if array_like:
        if isinstance(value, np.ndarray):
            source = trtc.device_vector_from_numpy(value)
        else:
            source = value.data
        trtc.Copy(source, self.data)
        return self

    if isinstance(value, int):
        fill_value = trtc.DVInt64(value)
    elif isinstance(value, float):
        fill_value = PrecisionResolver.get_floating_point(value)
    else:
        raise TypeError("Only Storage, int and float are supported.")

    trtc.Fill(self.data, fill_value)
    return self
def _to_host(self):
    """Return the result of ``to_host()`` for this storage's data.

    When ``self.data`` is a ``trtc.DVVector.DVRange`` its contents are
    first copied into a newly allocated device vector whose element type
    is chosen from ``self.dtype``; ``to_host()`` is then called on that
    vector. Otherwise ``to_host()`` is called on ``self.data`` directly.

    Raises:
        NotImplementedError: If ``self.data`` is a ``DVRange`` and
            ``self.dtype`` is none of ``Storage.FLOAT``, ``Storage.INT``
            or ``Storage.BOOL``.
    """
    device_data = self.data
    if not isinstance(device_data, trtc.DVVector.DVRange):
        return device_data.to_host()

    # Choose the element type name for the staging vector.
    dtype = self.dtype
    if dtype is Storage.FLOAT:
        element_type = PrecisionResolver.get_C_type()
    elif dtype is Storage.INT:
        element_type = 'int64_t'
    elif dtype is Storage.BOOL:
        element_type = 'bool'
    else:
        raise NotImplementedError()

    # NOTE(review): presumably a DVRange cannot be transferred to the host
    # directly, hence the intermediate copy - confirm against ThrustRTC.
    staging = trtc.device_vector(element_type, device_data.size())
    trtc.Copy(device_data, staging)
    return staging.to_host()
def upload(self, data):
    """Copy ``data`` into ``self.data``.

    ``data`` is cast to ``self.dtype`` and flattened, turned into a device
    vector, and then copied over the existing ``self.data``.

    Args:
        data: Object providing ``astype`` whose result provides ``ravel``
            (presumably a NumPy array - verify against callers).
    """
    flattened = data.astype(self.dtype).ravel()
    staged = trtc.device_vector_from_numpy(flattened)
    trtc.Copy(staged, self.data)
def ravel(self, other):
    """Load the contents of ``other`` into this storage.

    Args:
        other: A ``Storage``, whose ``data`` is copied into the existing
            ``self.data``; or any other object providing ``ravel()``,
            whose flattened form becomes a new device vector that
            replaces ``self.data``.
    """
    if not isinstance(other, Storage):
        # Non-Storage input rebinds self.data instead of copying in place.
        self.data = trtc.device_vector_from_numpy(other.ravel())
        return

    trtc.Copy(other.data, self.data)