def shape_list(self):
    """Return the shapes of all tensors visited by the blob's tensor iterator.

    The iterator is reset, then advanced one tensor at a time until
    ``OfBlob_CurTensorIteratorEqEnd`` reports the end. For each position an
    int64 buffer with one slot per axis is handed to
    ``OfBlob_CurTensorCopyShapeTo`` and then read back.

    Returns:
        list[tuple[int, ...]]: one tuple of Python ints per visited tensor,
        in iteration order. Empty if the iterator is already at its end
        right after the reset.

    Raises:
        AssertionError: if the buffer is no longer 1-D with ``num_axes``
            elements after the native call.
    """
    blob_ptr = self.of_blob_ptr_
    num_axes = oneflow_api.OfBlob_NumAxes(blob_ptr)
    tensor_shape_list = []

    oneflow_api.OfBlob_ResetTensorIterator(blob_ptr)
    while not oneflow_api.OfBlob_CurTensorIteratorEqEnd(blob_ptr):
        # Size the buffer with the local ``num_axes`` queried above, not
        # ``self.num_axes``: that attribute may be a bound method (which
        # np.zeros rejects) and, even as a property, would repeat the native
        # query on every iteration.
        shape_tensor = np.zeros(num_axes, dtype=np.int64)
        # NOTE(review): presumably fills ``shape_tensor`` in place -- the
        # buffer is read back immediately afterwards; confirm against the
        # native binding.
        oneflow_api.OfBlob_CurTensorCopyShapeTo(blob_ptr, shape_tensor)
        # Internal sanity checks that the buffer layout was left intact.
        assert len(shape_tensor.shape) == 1
        assert shape_tensor.size == num_axes
        tensor_shape_list.append(tuple(shape_tensor.tolist()))
        oneflow_api.OfBlob_IncTensorIterator(blob_ptr)
    return tensor_shape_list
def num_axes(self):
    """Return the axis count that ``OfBlob_NumAxes`` reports for this blob."""
    axis_count = oneflow_api.OfBlob_NumAxes(self.of_blob_ptr_)
    return axis_count
def set_shape(self, shape):
    """Pass ``shape`` to the blob through ``OfBlob_CopyShapeFrom``.

    Args:
        shape: a ``list`` or ``tuple`` with exactly as many entries as
            ``OfBlob_NumAxes`` reports for this blob. It is converted to an
            int64 NumPy array before being handed to the native call.

    Raises:
        AssertionError: if ``shape`` is not a list/tuple, or its length does
            not match the blob's axis count.
    """
    assert isinstance(shape, (list, tuple))
    blob_ptr = self.of_blob_ptr_
    # The axis-count query stays inside the assert so it runs exactly when
    # the check itself runs.
    assert oneflow_api.OfBlob_NumAxes(blob_ptr) == len(shape)
    dims = np.array(shape, dtype=np.int64)
    oneflow_api.OfBlob_CopyShapeFrom(blob_ptr, dims)
def shape(self):
    """Return the blob's shape as a tuple of Python ints.

    An int64 buffer with one slot per axis (as reported by
    ``OfBlob_NumAxes``) is handed to ``OfBlob_CopyShapeTo`` and then
    converted to a tuple.
    """
    blob_ptr = self.of_blob_ptr_
    rank = oneflow_api.OfBlob_NumAxes(blob_ptr)
    dims = np.zeros(rank, dtype=np.int64)
    # NOTE(review): presumably fills ``dims`` in place -- the buffer is read
    # back right after the call; confirm against the native binding.
    oneflow_api.OfBlob_CopyShapeTo(blob_ptr, dims)
    as_python_ints = dims.tolist()
    return tuple(as_python_ints)