Ejemplo n.º 1
0
    def _get_empty_data(shape, dtype):
        """Allocate a device vector sized for ``shape`` and normalise ``dtype``.

        Parameters
        ----------
        shape
            Target shape. Only its element count, ``int(np.prod(shape))``, is
            used for the allocation. The value itself is returned unchanged.
        dtype
            One of the Python builtins ``float`` / ``int`` / ``bool``, or the
            matching ``Storage.FLOAT`` / ``Storage.INT`` / ``Storage.BOOL``.

        Returns
        -------
        tuple
            ``(data, shape, dtype)``. ``data`` is the new
            ``trtc.device_vector``, ``shape`` is passed through as given, and
            ``dtype`` is always the corresponding ``Storage`` constant.

        Raises
        ------
        NotImplementedError
            If ``dtype`` is none of the supported types.
        """
        if dtype in (float, Storage.FLOAT):
            # The C element type for floats is chosen by PrecisionResolver
            # (presumably single vs. double precision; confirm there).
            elem_cls = PrecisionResolver.get_C_type()
            dtype = Storage.FLOAT
        elif dtype in (int, Storage.INT):
            elem_cls = 'int64_t'
            dtype = Storage.INT
        elif dtype in (bool, Storage.BOOL):
            elem_cls = 'bool'
            dtype = Storage.BOOL
        else:
            # Name the offending dtype so callers can diagnose the failure.
            raise NotImplementedError(
                "unsupported dtype: {!r}".format(dtype)
            )

        # int() because np.prod returns a NumPy scalar, and returns 1.0 for
        # an empty shape.
        data = trtc.device_vector(elem_cls, int(np.prod(shape)))
        return data, shape, dtype
Ejemplo n.º 2
0
    def _to_host(self):
        """Return the result of ``.to_host()`` on this storage's device data.

        When ``self.data`` is a ``trtc.DVVector.DVRange``, its contents are
        first copied (via ``trtc.Copy``) into a freshly allocated
        ``trtc.device_vector`` of the element type matching ``self.dtype``.
        ``.to_host()`` is then called on that copy. Presumably a range cannot
        be transferred directly; confirm against the ThrustRTC API. Otherwise
        ``self.data.to_host()`` is returned as is.

        ``self.dtype`` may be a ``Storage`` constant or, as in
        ``_get_empty_data``, the equivalent Python builtin type.

        Raises
        ------
        NotImplementedError
            If a copy is needed and ``self.dtype`` is not a supported type.
        """
        data = self.data
        if isinstance(data, trtc.DVVector.DVRange):
            # Same dtype -> C element type mapping as _get_empty_data.
            if self.dtype in (float, Storage.FLOAT):
                elem_cls = PrecisionResolver.get_C_type()
            elif self.dtype in (int, Storage.INT):
                elem_cls = 'int64_t'
            elif self.dtype in (bool, Storage.BOOL):
                elem_cls = 'bool'
            else:
                raise NotImplementedError(
                    "unsupported dtype: {!r}".format(self.dtype)
                )

            contiguous = trtc.device_vector(elem_cls, data.size())
            trtc.Copy(data, contiguous)
            data = contiguous
        return data.to_host()