示例#1
0
文件: ofblob.py 项目: zyg11/oneflow
    def shape_list(self):
        """Return the shape of every tensor visited by the blob's tensor iterator.

        The native tensor iterator is reset first, then advanced once per
        collected shape until ``OfBlob_CurTensorIteratorEqEnd`` reports the
        end, so the iterator is left at that end position on return.
        (Presumably one entry per tensor of a tensor-list blob -- confirm
        against the native implementation.)

        Returns:
            list of tuple of int: one shape tuple per visited tensor, in
            iteration order; each tuple has ``OfBlob_NumAxes`` entries.
        """
        tensor_shape_list = []

        # Query the axis count once and reuse it for both the buffer size and
        # the sanity check below. Using ``self.num_axes`` here instead would
        # issue a second native call, and it breaks outright when ``num_axes``
        # is a plain method rather than a property.
        num_axes = oneflow_api.OfBlob_NumAxes(self.of_blob_ptr_)
        oneflow_api.OfBlob_ResetTensorIterator(self.of_blob_ptr_)
        while not oneflow_api.OfBlob_CurTensorIteratorEqEnd(self.of_blob_ptr_):
            # Fresh int64 buffer that the native call is expected to fill with
            # the current tensor's shape.
            shape_tensor = np.zeros(num_axes, dtype=np.int64)
            oneflow_api.OfBlob_CurTensorCopyShapeTo(self.of_blob_ptr_,
                                                    shape_tensor)
            # Sanity checks that the buffer is still a flat vector of the
            # expected length after the native copy.
            assert len(shape_tensor.shape) == 1
            assert shape_tensor.size == num_axes
            tensor_shape_list.append(tuple(shape_tensor.tolist()))
            oneflow_api.OfBlob_IncTensorIterator(self.of_blob_ptr_)

        return tensor_shape_list
示例#2
0
 def num_axes(self):
     """Return the axis count the native layer reports for this blob.

     Returns:
         Whatever ``oneflow_api.OfBlob_NumAxes`` returns for
         ``self.of_blob_ptr_`` (used elsewhere as an integer length).
     """
     blob_ptr = self.of_blob_ptr_
     axis_count = oneflow_api.OfBlob_NumAxes(blob_ptr)
     return axis_count
示例#3
0
 def set_shape(self, shape):
     """Copy ``shape`` into the native blob as an int64 vector.

     Args:
         shape: list or tuple of ints whose length must equal the blob's
             current axis count (``OfBlob_NumAxes``).

     Raises:
         AssertionError: if ``shape`` is not a list/tuple, or its length does
             not match the blob's axis count. These checks are ``assert``
             statements and are therefore skipped under ``python -O``.
     """
     assert isinstance(shape, (list, tuple)), (
         "shape must be a list or tuple, got %s" % type(shape).__name__)
     num_axes = oneflow_api.OfBlob_NumAxes(self.of_blob_ptr_)
     assert len(shape) == num_axes, (
         "shape %s has %d axes but the blob has %d"
         % (shape, len(shape), num_axes))
     oneflow_api.OfBlob_CopyShapeFrom(self.of_blob_ptr_,
                                      np.array(shape, dtype=np.int64))
示例#4
0
 def shape(self):
     """Return the blob's shape as a tuple of Python ints.

     A zero-filled int64 buffer, sized by ``OfBlob_NumAxes``, is handed to
     ``OfBlob_CopyShapeTo``, which is expected to write the shape into it.
     """
     blob_ptr = self.of_blob_ptr_
     shape_buf = np.zeros(oneflow_api.OfBlob_NumAxes(blob_ptr), dtype=np.int64)
     oneflow_api.OfBlob_CopyShapeTo(blob_ptr, shape_buf)
     # int() turns each numpy int64 element into a plain Python int.
     return tuple(map(int, shape_buf))