Exemple #1
0
    def split_params(sym, params):
        """Split a flat parameter dictionary into argument and auxiliary parameters.

        A name is kept only if it is present in ``params`` AND is reported by
        ``sym`` (``list_arguments`` for the first result,
        ``list_auxiliary_states`` for the second). Every kept value is passed
        through ``nd.array``.

        Parameters
        ----------
        sym : :class:`~mxnet.symbol.Symbol`
            MXNet symbol whose argument / auxiliary-state names drive the split.
        params : dict of ``str`` to :class:`~mxnet.ndarray.NDArray`
            Converted parameters keyed by name.

        Returns
        -------
        arg_params : dict of ``str`` to :class:`~mxnet.ndarray.NDArray`
            Entries of ``params`` whose names are arguments of ``sym``.
        aux_params : dict of ``str`` to :class:`~mxnet.ndarray.NDArray`
            Entries of ``params`` whose names are auxiliary states of ``sym``.
        """
        arg_params = {name: nd.array(params[name])
                      for name in sym.list_arguments()
                      if name in params}
        aux_params = {name: nd.array(params[name])
                      for name in sym.list_auxiliary_states()
                      if name in params}
        return arg_params, aux_params
Exemple #2
0
    def from_onnx(self, graph):
        """Construct an MXNet symbol from an ONNX graph.

        Parameters
        ----------
        graph : onnx protobuf object
            The loaded onnx graph; its ``initializer``, ``input``, ``node`` and
            ``output`` fields are read.

        Returns
        -------
        sym : symbol.Symbol
            The returned mxnet symbol; a ``symbol.Group`` when the graph has
            more than one output.
        arg_params : dict
            ``self.arg_dict``: name -> nd.array pairs for initializers that are
            arguments of at least one converted node.
        aux_params : dict
            ``self.aux_dict``: name -> nd.array pairs for initializers that are
            auxiliary states of at least one converted node.

        Raises
        ------
        ValueError
            If an initializer tensor has an empty (or whitespace-only) name.
        """
        # get input, output shapes
        self.model_metadata = self.get_graph_metadata(graph)

        # parse network inputs, aka parameters; they must be named so they can
        # be matched against graph inputs and node inputs below
        for init_tensor in graph.initializer:
            if not init_tensor.name.strip():
                raise ValueError("Tensor's name is required.")
            self._params[init_tensor.name] = self._parse_array(init_tensor)

        # converting GraphProto message: every graph input becomes a Variable;
        # inputs backed by an initializer also carry that initializer's shape
        for graph_input in graph.input:
            if graph_input.name in self._params:
                # graph_input is a param instead of input
                self._nodes[graph_input.name] = symbol.Variable(
                    name=graph_input.name,
                    shape=self._params[graph_input.name].shape)
            else:
                self._nodes[graph_input.name] = symbol.Variable(name=graph_input.name)

        # constructing nodes, nodes are stored as directed acyclic graph
        # converting NodeProto message
        for node in graph.node:
            op_name = node.op_type
            # an empty / whitespace-only node name is passed on as None
            node_name = node.name.strip() or None
            onnx_attr = self._parse_attr(node.attribute)
            inputs = [self._nodes[input_name] for input_name in node.input]
            mxnet_sym = self._convert_operator(node_name, op_name, onnx_attr, inputs)

            # zip stops at the shorter sequence: only outputs that exist on both
            # the ONNX node and the converted symbol are registered
            num_outputs = len(mxnet_sym.list_outputs())
            for output_name, idx in zip(node.output, range(num_outputs)):
                self._nodes[output_name] = mxnet_sym[idx]

            # splitting params into args and aux params
            for arg_name in mxnet_sym.list_arguments():
                if arg_name in self._params:
                    self.arg_dict[arg_name] = nd.array(self._params[arg_name])
            for aux_name in mxnet_sym.list_auxiliary_states():
                if aux_name in self._params:
                    self.aux_dict[aux_name] = nd.array(self._params[aux_name])

        # now return the outputs
        out = [self._nodes[graph_output.name] for graph_output in graph.output]
        if len(out) > 1:
            out = symbol.Group(out)
        else:
            out = out[0]
        return out, self.arg_dict, self.aux_dict
Exemple #3
0
def _prepare_image(img, nrow=8, padding=2, square_image=False):
    """Given an image of format HW, CHW, or NCHW, return an image of format HWC.

    If the input is a batch of images, a grid of images is made by stitching them together.
    If the data type is float, values must be in the range [0, 1], and they are rescaled to
    the range [0, 255]. If the data type is `uint8`, values are unchanged.

    Parameters
    ----------
    img : NDArray or numpy.ndarray
        Image data with 2, 3 or 4 dimensions. numpy arrays are first converted to an
        NDArray (same dtype) on the current context.
    nrow, padding, square_image
        Forwarded unchanged to ``make_image_grid``.

    Returns
    -------
    NDArray
        The grid produced by ``make_image_grid`` transposed with axes (1, 2, 0), i.e. to
        HWC. Float inputs are scaled by 255 and cast to ``uint8`` before returning.

    Raises
    ------
    TypeError
        If ``img`` is neither an MXNet NDArray nor a numpy.ndarray.
    ValueError
        If ``img`` does not have 2, 3 or 4 dimensions, if a float image has values
        outside [0, 1], or if its dtype is not one of uint8, float32, float64.
    """
    if isinstance(img, np.ndarray):
        img = nd.array(img, dtype=img.dtype, ctx=current_context())
    if not isinstance(img, NDArray):
        raise TypeError('expected MXNet NDArray or numpy.ndarray, '
                        'while received type {}'.format(str(type(img))))
    # Explicit check rather than ``assert``: asserts are removed under ``python -O``.
    if img.ndim not in (2, 3, 4):
        raise ValueError('expected image with 2, 3 or 4 dimensions (HW, CHW or NCHW), '
                         'while received {} dimensions'.format(img.ndim))

    if img.dtype == np.uint8:
        return make_image_grid(
            img, nrow=nrow, padding=padding, square_image=square_image).transpose((1, 2, 0))
    elif img.dtype == np.float32 or img.dtype == np.float64:
        # Float images must already be normalized to [0, 1] so that scaling by 255
        # fits into uint8 without wrapping.
        min_val = img.min().asscalar()
        max_val = img.max().asscalar()
        if min_val < 0.0:
            raise ValueError('expected non-negative min value from img, '
                             'while received {}'.format(min_val))
        if max_val > 1.0:
            raise ValueError('expected max value from img not greater than 1, '
                             'while received {}'.format(max_val))
        img = make_image_grid(img, nrow=nrow, padding=padding, square_image=square_image) * 255.0
        return img.astype(np.uint8).transpose((1, 2, 0))
    else:
        raise ValueError('expected input image dtype is one of uint8, float32, '
                         'and float64, received dtype {}'.format(str(img.dtype)))
Exemple #4
0
 def _parse_array(self, tensor_proto):
     """Convert the data held in an ONNX TensorProto into an MXNet NDArray.

     Parameters
     ----------
     tensor_proto : onnx TensorProto
         Tensor whose data is extracted with ``onnx.numpy_helper.to_array``.

     Returns
     -------
     NDArray
         The tensor data reshaped to ``tensor_proto.dims``. When the proto has
         no dims (a scalar parameter), the extracted value is wrapped in a
         one-element list before conversion.

     Raises
     ------
     ImportError
         If ``onnx.numpy_helper`` cannot be imported.
     """
     try:
         from onnx.numpy_helper import to_array
     except ImportError:
         raise ImportError("Onnx and protobuf need to be installed. "
                           "Instructions to install - https://github.com/onnx/onnx")
     dims = tuple(tensor_proto.dims)
     if not dims:
         # If onnx's params are scalar values without dims mentioned.
         return nd.array(np.array([to_array(tensor_proto)]))
     return nd.array(to_array(tensor_proto).reshape(dims))
Exemple #5
0
def _make_sprite_image(images, save_path):
    """Make a sprite image out of a batch of images and save it as ``sprite.png``.

    The layout follows the rule defined in
    https://www.tensorflow.org/programmers_guide/embedding: the batch is passed to
    ``_save_image`` with ``nrow = ceil(sqrt(batch_size))`` and no padding, so the
    grid is roughly square.

    Parameters
    ----------
    images : NDArray or numpy.ndarray
        Batch of images; ``images.shape[0]`` is taken as the batch size. numpy arrays
        are first converted to an NDArray (same dtype) on the current context.
    save_path : str
        Directory in which ``sprite.png`` is written.

    Raises
    ------
    TypeError
        If ``images`` is neither an MXNet NDArray nor a numpy.ndarray.
    """
    if isinstance(images, np.ndarray):
        images = nd.array(images, dtype=images.dtype, ctx=current_context())
    elif not isinstance(images, NDArray):
        # numpy arrays were already handled above, so only NDArray is acceptable here.
        raise TypeError('images must be an MXNet NDArray or numpy.ndarray,'
                        ' while received type {}'.format(str(type(images))))

    # Smallest n with n * n >= batch size, giving a (roughly) square sprite.
    nrow = int(np.ceil(np.sqrt(images.shape[0])))
    _save_image(images, os.path.join(save_path, 'sprite.png'), nrow=nrow, padding=0)