Пример #1
0
def load(fname):
    """Load NDArrays from a file.

    The on-disk format is described in ``save``.

    Parameters
    ----------
    fname : str
        Path of the file to read.

    Returns
    -------
    list of NDArray or dict of str to NDArray
        A list when the file carries no names, otherwise a dict mapping
        each stored name to its array.

    Raises
    ------
    TypeError
        If ``fname`` is not a string.
    """
    if not isinstance(fname, string_types):
        raise TypeError('fname required to be a string')

    # Output slots filled in by the native loader.
    num_arrays = mx_uint()
    num_names = mx_uint()
    array_handles = ctypes.POINTER(NDArrayHandle)()
    array_names = ctypes.POINTER(ctypes.c_char_p)()
    check_call(_LIB.MXNDArrayLoad(
        c_str(fname),
        ctypes.byref(num_arrays),
        ctypes.byref(array_handles),
        ctypes.byref(num_names),
        ctypes.byref(array_names)))

    indices = range(num_arrays.value)
    if num_names.value == 0:
        # Unnamed arrays: return them positionally.
        return [_ndarray_cls(NDArrayHandle(array_handles[i])) for i in indices]

    # Named arrays: there must be exactly one name per array.
    assert num_names.value == num_arrays.value
    return {
        py_str(array_names[i]): _ndarray_cls(NDArrayHandle(array_handles[i]))
        for i in indices
    }
Пример #2
0
def _init_symbol_module(root_namespace):
    """Attach a generated function for every registered operator to the symbol modules.

    Parameters
    ----------
    root_namespace : str
        Package name used to look up the ``symbol``, ``symbol.sparse``,
        ``symbol._internal`` and ``contrib.symbol`` modules in ``sys.modules``.
    """
    names_ptr = ctypes.POINTER(ctypes.c_char_p)()
    count = ctypes.c_uint()
    check_call(_LIB.MXListAllOpNames(ctypes.byref(count),
                                     ctypes.byref(names_ptr)))
    op_names = [py_str(names_ptr[idx]) for idx in range(count.value)]

    symbol_mod = _sys.modules["%s.symbol" % root_namespace]
    sparse_mod = _sys.modules["%s.symbol.sparse" % root_namespace]
    internal_mod = _sys.modules["%s.symbol._internal" % root_namespace]
    contrib_mod = _sys.modules["%s.contrib.symbol" % root_namespace]

    contrib_prefix = '_contrib_'
    sparse_prefix = '_sparse_'
    for op_name in op_names:
        op_handle = OpHandle()
        check_call(_LIB.NNGetOpHandle(c_str(op_name), ctypes.byref(op_handle)))
        func = _make_atomic_symbol_function(op_handle, op_name)

        # Pick the destination module from the (possibly prefixed) name.
        if func.__name__.startswith(contrib_prefix):
            func.__name__ = func.__name__[len(contrib_prefix):]
            func.__module__ = 'mxnet.contrib.symbol'
            target_mod = contrib_mod
        elif func.__name__.startswith('_'):
            target_mod = internal_mod
        else:
            target_mod = symbol_mod
        setattr(target_mod, func.__name__, func)

        # Sparse ops are additionally exposed, without their prefix,
        # under the sparse module. The check uses the current __name__,
        # i.e. after any contrib prefix has been stripped.
        if func.__name__.startswith(sparse_prefix):
            func.__name__ = func.__name__[len(sparse_prefix):]
            func.__module__ = 'mxnet.symbol.sparse'
            setattr(sparse_mod, func.__name__, func)
Пример #3
0
            def collect(self, name, arr):
                """Callback that stores the output of each included layer as a NumPy array.

                Layers whose name is not in ``self.include_layer`` are ignored;
                otherwise the array is appended to ``self.tensor_dict[name]``.
                """
                name = py_str(name)
                if name in self.include_layer:
                    nd_handle = ctypes.cast(arr, NDArrayHandle)
                    values = mx.ndarray.NDArray(nd_handle, writable=False).asnumpy()  # pylint: disable=no-member
                    # Start a new list on first sight of this layer, then append.
                    self.tensor_dict.setdefault(name, []).append(values)
Пример #4
0
 def collect(self, name, arr):
     """Callback that records a CPU copy of a layer's output NDArray.

     When ``self.include_layer`` is set, it is called with the layer name
     and the layer is skipped if it returns a falsy value. Collected
     arrays are appended to ``self.nd_dict[name]``.
     """
     name = py_str(name)
     selected = self.include_layer is None or self.include_layer(name)
     if not selected:
         return
     layer_out = NDArray(ctypes.cast(arr, NDArrayHandle),
                         writable=False).copyto(cpu())
     if self.logger is not None:
         self.logger.info("Collecting layer %s output of shape %s" %
                          (name, layer_out.shape))
     # Create the per-layer list on first use, then append.
     self.nd_dict.setdefault(name, []).append(layer_out)
Пример #5
0
def _get_op_arguments(op_handle):
    """Fetch the argument count, argument names and argument types of an operator.

    Parameters
    ----------
    op_handle: OpHandle
        Handle for the operator

    Returns
    -------
    (narg, arg_names, arg_types)
        ``narg`` is an int; ``arg_names`` and ``arg_types`` are lists of
        ``narg`` strings each.
    """
    # Output slots for MXSymbolGetAtomicSymbolInfo; only the count, names
    # and types are used afterwards.
    op_name = ctypes.c_char_p()
    op_desc = ctypes.c_char_p()
    arg_count = mx_uint()
    names_ptr = ctypes.POINTER(ctypes.c_char_p)()
    types_ptr = ctypes.POINTER(ctypes.c_char_p)()
    descs_ptr = ctypes.POINTER(ctypes.c_char_p)()
    var_num_args_key = ctypes.c_char_p()
    return_type = ctypes.c_char_p()

    check_call(_LIB.MXSymbolGetAtomicSymbolInfo(
        op_handle,
        ctypes.byref(op_name),
        ctypes.byref(op_desc),
        ctypes.byref(arg_count),
        ctypes.byref(names_ptr),
        ctypes.byref(types_ptr),
        ctypes.byref(descs_ptr),
        ctypes.byref(var_num_args_key),
        ctypes.byref(return_type)))

    narg = int(arg_count.value)
    indices = range(narg)
    return (narg,
            [py_str(names_ptr[i]) for i in indices],
            [py_str(types_ptr[i]) for i in indices])
Пример #6
0
def _get_all_registered_ops():
    """Return the names of all operators reported by ``MXListAllOpNames``.

    Returns
    -------
    ["operator_name"]
    """
    names_ptr = ctypes.POINTER(ctypes.c_char_p)()
    count = ctypes.c_uint()
    check_call(_LIB.MXListAllOpNames(ctypes.byref(count),
                                     ctypes.byref(names_ptr)))

    # Decode each C string returned by the library into a Python str.
    return [py_str(names_ptr[idx]) for idx in range(count.value)]
Пример #7
0
 def collect(self, name, arr):
     """Callback that tracks the running min and max of each layer's output.

     When ``self.include_layer`` is set, it is called with the layer name
     and the layer is skipped if it returns a falsy value. The range
     stored in ``self.min_max_dict[name]`` is widened to cover the new
     array; the log line reports the values of the current array only.
     """
     name = py_str(name)
     if not (self.include_layer is None or self.include_layer(name)):
         return
     layer_out = NDArray(ctypes.cast(arr, NDArrayHandle), writable=False)
     min_range = ndarray.min(layer_out).asscalar()
     max_range = ndarray.max(layer_out).asscalar()

     # Merge with any range already recorded for this layer.
     merged_min, merged_max = min_range, max_range
     if name in self.min_max_dict:
         seen = self.min_max_dict[name]
         merged_min = min(seen[0], min_range)
         merged_max = max(seen[1], max_range)
     self.min_max_dict[name] = (merged_min, merged_max)

     if self.logger is not None:
         self.logger.info("Collecting layer %s min_range=%f, max_range=%f" %
                          (name, min_range, max_range))
Пример #8
0
def str_updater(key, recv, local):
    """Updater for string keys: accumulate ``recv`` into ``local`` with ``+=``.

    A ``bytes`` key is decoded with ``py_str`` first; the key must then be
    a ``str``. ``local`` is updated through augmented assignment.
    """
    key = py_str(key) if isinstance(key, bytes) else key
    assert isinstance(key, str)
    local += recv
Пример #9
0
def _make_atomic_symbol_function(handle, name):
    """Create an atomic symbol function by handle and function name.

    The operator's metadata is queried through ``MXSymbolGetAtomicSymbolInfo``
    and used to generate Python source text for a wrapper function, which is
    then compiled with ``exec`` and returned.

    Two wrapper shapes are generated:

    * If the operator has an ``NDArray``/``Symbol`` argument whose type ends
      in ``[]``, the wrapper takes ``*<that argument>`` and ``**kwargs``.
    * Otherwise the wrapper gets an explicit signature listing every
      argument, followed by ``name``, ``attr``, ``out`` and ``**kwargs``.

    Parameters
    ----------
    handle : OpHandle
        Handle of the operator. ``handle.value`` is formatted with ``%d``
        into the generated source.
    name : str
        Operator name; used as the name of the generated function and,
        lower-cased, as the hint passed to ``NameManager.current.get``.

    Returns
    -------
    function
        The generated wrapper, with ``__name__`` set to ``name``,
        ``__doc__`` set to the text built by ``_build_doc`` and
        ``__module__`` set to ``'mxnet.symbol'``.
    """
    # Output slots filled in by the native call below.
    real_name = ctypes.c_char_p()
    desc = ctypes.c_char_p()
    num_args = mx_uint()
    arg_names = ctypes.POINTER(ctypes.c_char_p)()
    arg_types = ctypes.POINTER(ctypes.c_char_p)()
    arg_descs = ctypes.POINTER(ctypes.c_char_p)()
    key_var_num_args = ctypes.c_char_p()
    ret_type = ctypes.c_char_p()

    check_call(_LIB.MXSymbolGetAtomicSymbolInfo(
        handle, ctypes.byref(real_name), ctypes.byref(desc),
        ctypes.byref(num_args),
        ctypes.byref(arg_names),
        ctypes.byref(arg_types),
        ctypes.byref(arg_descs),
        ctypes.byref(key_var_num_args),
        ctypes.byref(ret_type)))
    # Convert the C outputs into Python values (the ctypes pointers are
    # rebound to plain lists/strings from here on).
    narg = int(num_args.value)
    arg_names = [py_str(arg_names[i]) for i in range(narg)]
    arg_types = [py_str(arg_types[i]) for i in range(narg)]
    # Keep the operator name under a separate local: ``name`` is rebound by
    # the loops further down.
    func_name = name
    key_var_num_args = py_str(key_var_num_args.value)
    ret_type = py_str(ret_type.value) if ret_type.value is not None else ''
    doc_str = _build_doc(func_name,
                         py_str(desc.value),
                         arg_names,
                         arg_types,
                         [py_str(arg_descs[i]) for i in range(narg)],
                         key_var_num_args,
                         ret_type)

    # Classify every argument:
    #   dtype_name  - set when an argument is literally called 'dtype'
    #   arr_name    - the variable-length ('[]') NDArray/Symbol argument, if any
    #   ndarg_names - fixed NDArray/Symbol arguments (default None)
    #   kwarg_names - all remaining arguments (default _Null)
    # ``ndsignature``/``signature`` collect the matching parameter strings.
    dtype_name = None
    arr_name = None
    ndsignature = []
    signature = []
    ndarg_names = []
    kwarg_names = []
    for i in range(narg):
        name, atype = arg_names[i], arg_types[i]
        if name == 'dtype':
            dtype_name = name
            signature.append('%s=_Null'%name)
        elif atype.startswith('NDArray') or atype.startswith('Symbol'):
            # Fires when another NDArray/Symbol argument follows one that
            # was already recorded as the variable-length argument.
            assert not arr_name, \
                "Op can only have one argument with variable " \
                "size and it must be the last argument."
            if atype.endswith('[]'):
                ndsignature.append('*%s'%name)
                arr_name = name
            else:
                ndsignature.append('%s=None'%name)
                ndarg_names.append(name)
        else:
            signature.append('%s=_Null'%name)
            kwarg_names.append(name)
    #signature.append('is_train=False')
    signature.append('name=None')
    signature.append('attr=None')
    signature.append('out=None')
    signature.append('**kwargs')
    # NDArray/Symbol parameters come first, then the keyword parameters.
    signature = ndsignature + signature

    # Assemble the wrapper's source text piece by piece.
    code = []
    if arr_name:
        # Variadic form: the explicit ``signature`` is not used here; all
        # non-positional inputs arrive through **kwargs.
        code.append("""
def %s(*%s, **kwargs):"""%(func_name, arr_name))
        code.append("""
    sym_args = []
    for i in {}:
        assert isinstance(i, SymbolBase), \\
            "Positional arguments must be Symbol instances, " \\
            "but got %s"%str(i)
        sym_args.append(i)""".format(arr_name))
        if dtype_name is not None:
            # Normalize a dtype keyword to its numpy dtype name.
            code.append("""
    if '%s' in kwargs:
        kwargs['%s'] = _numpy.dtype(kwargs['%s']).name"""%(
            dtype_name, dtype_name, dtype_name))
        # Pop the bookkeeping keywords ('out' is popped and discarded), then
        # split the remaining kwargs into symbol inputs and key/value params.
        code.append("""
    attr = kwargs.pop('attr', None)
    kwargs.update(AttrScope.current.get(attr))
    name = kwargs.pop('name', None)
    name = NameManager.current.get(name, '%s')
    _ = kwargs.pop('out', None)
    keys = []
    vals = []
    sym_kwargs = dict()
    for k, v in kwargs.items():
        if isinstance(v, SymbolBase):
            sym_kwargs[k] = v
        else:
            keys.append(k)
            vals.append(v)"""%(func_name.lower()))
        if key_var_num_args:
            # When the caller did not pass the count parameter, supply it as
            # the number of symbol inputs (positional plus keyword).
            code.append("""
    if '%s' not in kwargs:
        keys.append('%s')
        vals.append(len(sym_args) + len(sym_kwargs))"""%(
            key_var_num_args, key_var_num_args))

        # The handle's integer value is embedded directly in the source.
        code.append("""
    return _symbol_creator(%d, sym_args, sym_kwargs, keys, vals, name)"""%(
        handle.value))
    else:
        # Fixed-arity form: every argument is an explicit parameter.
        code.append("""
def %s(%s):
    kwargs.update(AttrScope.current.get(attr))
    sym_kwargs = dict()
    keys = []
    vals = []"""%(func_name, ', '.join(signature)))
        code.append("""
    for k, v in kwargs.items():
        if isinstance(v, SymbolBase):
            sym_kwargs[k] = v
        else:
            keys.append(k)
            vals.append(v)""")
        # NDArray args
        for name in ndarg_names: # pylint: disable=redefined-argument-from-local
            code.append("""
    if {name} is not None:
        assert isinstance({name}, SymbolBase), \\
            "Argument {name} must be Symbol instances, but got %s"%str({name})
        sym_kwargs['{name}'] = {name}""".format(name=name))
        # kwargs
        for name in kwarg_names: # pylint: disable=redefined-argument-from-local
            code.append("""
    if %s is not _Null:
        keys.append('%s')
        vals.append(%s)"""%(name, name, name))
        # dtype
        if dtype_name is not None:
            code.append("""
    if %s is not _Null:
        keys.append('%s')
        vals.append(_numpy.dtype(%s).name)"""%(dtype_name, dtype_name, dtype_name))

        code.append("""
    name = NameManager.current.get(name, '%s')
    return _symbol_creator(%d, None, sym_kwargs, keys, vals, name)"""%(
        func_name.lower(), handle.value))

    # Compile the generated source. The new function is defined into
    # ``local``; globals are passed as None, so names used by the generated
    # code (SymbolBase, _Null, _symbol_creator, ...) are expected to resolve
    # against this module's globals.
    local = {}
    exec(''.join(code), None, local)  # pylint: disable=exec-used
    symbol_function = local[func_name]
    symbol_function.__name__ = func_name
    symbol_function.__doc__ = doc_str
    symbol_function.__module__ = 'mxnet.symbol'
    return symbol_function
Пример #10
0
 def stat_helper(name, array):
     """Queue ``(step, name, stat(array))`` for outputs whose name matches.

     The raw handle is wrapped in a non-writable NDArray unconditionally,
     before any filtering. Nothing is queued unless ``self.activated`` is
     truthy and ``self.re_prog`` matches the decoded name.
     """
     nd_handle = ctypes.cast(array, NDArrayHandle)
     nd_array = NDArray(nd_handle, writable=False)
     if not self.activated:
         return
     layer_name = py_str(name)
     if not self.re_prog.match(layer_name):
         return
     self.queue.append((self.step, layer_name, stat(nd_array)))