import os
import builtins
import numbers
from distutils.version import LooseVersion

import numpy as np
import mxnet as mx
import mxnet.ndarray as nd

from ... import ndarray as dglnd
from ... import kernel as K
from ...function.base import TargetCode

# NOTE(review): distutils (and with it LooseVersion) was removed from the
# standard library in Python 3.12; this import needs a replacement there.
MX_VERSION = LooseVersion(mx.__version__)

# Compare the leading (major, minor) components so that every release older
# than 1.5 is rejected -- including 0.x releases, which a `major == 1` check
# would silently let through.
if MX_VERSION.version[:2] < [1, 5]:
    raise Exception("DGL has to work with MXNet version >= 1.5")


def _env_flag(name, default):
    """Read a boolean flag from the environment variable ``name``.

    Returns ``default`` when the variable is unset. Otherwise the value is
    treated as false when, after stripping and lower-casing, it is empty or
    one of ``"0"``, ``"false"``, ``"no"``, ``"off"``; any other value is true.
    """
    value = os.environ.get(name)
    if value is None:
        return default
    return value.strip().lower() not in ('', '0', 'false', 'no', 'off')


# After MXNet 1.5, empty tensors aren't supported by default.
# After we turn on the numpy compatible flag, MXNet supports empty NDArray.
# The flag stays on unless DGL_MXNET_SET_NP_SHAPE holds a false-like value;
# calling bool() on the raw string would make "0" or "False" count as true.
mx.set_np_shape(_env_flag('DGL_MXNET_SET_NP_SHAPE', True))


def data_type_dict():
    """Return a mapping from dtype name strings to numpy scalar types."""
    return {
        'float16': np.float16,
        'float32': np.float32,
        'float64': np.float64,
        'uint8': np.uint8,
        'int8': np.int8,
        'int16': np.int16,
        'int32': np.int32,
        'int64': np.int64,
    }
# NOTE(review): a `from __future__` import is only legal as the first
# statement of a module (after an optional docstring/comments).
from __future__ import absolute_import

from distutils.version import LooseVersion
import numbers

import numpy as np
import mxnet as mx
import mxnet.ndarray as nd

MX_VERSION = LooseVersion(mx.__version__)

# After MXNet 1.5, empty tensors aren't supported by default.
# After we turn on the numpy compatible flag, MXNet supports empty NDArray.
# Compare the leading (major, minor) components so every release from 1.5 on
# -- including 2.x -- enables the flag; a `major == 1` check skipped 2.x.
if MX_VERSION.version[:2] >= [1, 5]:
    mx.set_np_shape(True)


def data_type_dict():
    """Return a mapping from dtype name strings to numpy scalar types."""
    return {
        'float16': np.float16,
        'float32': np.float32,
        'float64': np.float64,
        'uint8': np.uint8,
        'int8': np.int8,
        'int16': np.int16,
        'int32': np.int32,
        'int64': np.int64,
    }


def cpu():
    """Return the MXNet CPU context object (``mx.cpu()``)."""
    return mx.cpu()