def split_params(sym, params):
    """Split a flat parameter dict into MXNet argument and auxiliary params.

    Every name reported by ``sym.list_arguments()`` that also appears in
    ``params`` goes into the first returned dict. Every name reported by
    ``sym.list_auxiliary_states()`` that also appears in ``params`` goes into
    the second. Each selected value is wrapped with ``nd.array``. Entries of
    ``params`` that the symbol does not list are not returned.

    Parameters
    ----------
    sym : :class:`~mxnet.symbol.Symbol`
        MXNet symbol object
    params : dict of ``str`` to :class:`~mxnet.ndarray.NDArray`
        Dict of converted parameters stored in ``mxnet.ndarray.NDArray`` format

    Returns
    -------
    arg_params : dict of ``str`` to :class:`~mxnet.ndarray.NDArray`
        Parameters whose names are arguments of ``sym``
    aux_params : dict of ``str`` to :class:`~mxnet.ndarray.NDArray`
        Parameters whose names are auxiliary states of ``sym``
    """
    def _select(names):
        # Keep the symbol's listing order; a repeated name just overwrites.
        return {name: nd.array(params[name]) for name in names if name in params}

    arg_params = _select(sym.list_arguments())
    aux_params = _select(sym.list_auxiliary_states())
    return arg_params, aux_params
def from_onnx(self, graph):
    """Construct an MXNet symbol from an ONNX graph.

    Parameters
    ----------
    graph : onnx protobuf object
        The loaded onnx graph

    Returns
    -------
    sym : symbol.Symbol
        The returned mxnet symbol; a ``symbol.Group`` when the graph has more
        than one output
    arg_params : dict
        ``self.arg_dict``: name -> nd.array pairs for argument parameters
    aux_params : dict
        ``self.aux_dict``: name -> nd.array pairs for auxiliary states

    Raises
    ------
    ValueError
        If an initializer tensor has an empty or whitespace-only name.
    """
    # Record input/output shape metadata of the graph.
    self.model_metadata = self.get_graph_metadata(graph)

    # Initializers hold the stored weights; parse each into self._params.
    for tensor in graph.initializer:
        if not tensor.name.strip():
            raise ValueError("Tensor's name is required.")
        self._params[tensor.name] = self._parse_array(tensor)

    # Each graph input becomes a Variable. An input that is also an
    # initializer is a parameter, so its shape is already known.
    for graph_input in graph.input:
        name = graph_input.name
        if name in self._params:
            variable = symbol.Variable(name=name, shape=self._params[name].shape)
        else:
            variable = symbol.Variable(name=name)
        self._nodes[name] = variable

    # Convert NodeProto messages in graph order; each node's outputs are
    # registered in self._nodes so later nodes can consume them.
    for node in graph.node:
        op_name = node.op_type
        node_name = node.name.strip() or None
        onnx_attr = self._parse_attr(node.attribute)
        inputs = [self._nodes[input_name] for input_name in node.input]
        mxnet_sym = self._convert_operator(node_name, op_name, onnx_attr, inputs)

        # Pair ONNX output names with MXNet output indices, stopping at
        # whichever list is shorter.
        num_outputs = len(mxnet_sym.list_outputs())
        for index, output_name in zip(range(num_outputs), node.output):
            self._nodes[output_name] = mxnet_sym[index]

        # Split known parameters into argument and auxiliary dicts.
        # NOTE(review): list_arguments() presumably covers this node's whole
        # upstream subgraph, so shared weights get re-wrapped with nd.array
        # once per downstream node (later wraps overwrite earlier ones).
        for arg_name in mxnet_sym.list_arguments():
            if arg_name in self._params:
                self.arg_dict[arg_name] = nd.array(self._params[arg_name])
        for aux_name in mxnet_sym.list_auxiliary_states():
            if aux_name in self._params:
                self.aux_dict[aux_name] = nd.array(self._params[aux_name])

    # Gather graph outputs; several outputs are bundled into one Group.
    outputs = [self._nodes[output.name] for output in graph.output]
    out = symbol.Group(outputs) if len(outputs) > 1 else outputs[0]
    return out, self.arg_dict, self.aux_dict
def _prepare_image(img, nrow=8, padding=2, square_image=False):
    """Given an image of format HW, CHW, or NCHW, return an image of format HWC.

    If the input is a batch of images, a grid of images is made by stitching
    them together via ``make_image_grid``. If the data type is float32 or
    float64, values must be in the range [0, 1]; they are multiplied by 255
    and cast to ``uint8``. If the data type is ``uint8``, values are unchanged.

    Parameters
    ----------
    img : :class:`~mxnet.ndarray.NDArray` or ``numpy.ndarray``
        Image or batch of images with 2, 3 or 4 dimensions. numpy input is
        converted to an NDArray on the current context first.
    nrow, padding, square_image
        Forwarded to ``make_image_grid``.

    Returns
    -------
    The grid image from ``make_image_grid`` transposed to HWC layout.

    Raises
    ------
    TypeError
        If ``img`` is neither an MXNet NDArray nor a numpy.ndarray.
    ValueError
        If ``img`` does not have 2, 3 or 4 dimensions, if a float image has
        values outside [0, 1], or if its dtype is not uint8/float32/float64.
    """
    if isinstance(img, np.ndarray):
        img = nd.array(img, dtype=img.dtype, ctx=current_context())
    if not isinstance(img, NDArray):
        raise TypeError('expected MXNet NDArray or numpy.ndarray, '
                        'while received type {}'.format(str(type(img))))
    # Validate explicitly: an `assert` here would be stripped under `python -O`.
    if img.ndim not in (2, 3, 4):
        raise ValueError('expected image with 2, 3 or 4 dimensions (HW, CHW or NCHW), '
                         'while received ndim {}'.format(img.ndim))
    if img.dtype == np.uint8:
        return make_image_grid(
            img, nrow=nrow, padding=padding, square_image=square_image).transpose((1, 2, 0))
    elif img.dtype == np.float32 or img.dtype == np.float64:
        min_val = img.min().asscalar()
        max_val = img.max().asscalar()
        if min_val < 0.0:
            raise ValueError('expected non-negative min value from img, '
                             'while received {}'.format(min_val))
        if max_val > 1.0:
            raise ValueError('expected max value from img not greater than 1, '
                             'while received {}'.format(max_val))
        # Rescale [0, 1] floats to the [0, 255] uint8 range.
        img = make_image_grid(img, nrow=nrow, padding=padding, square_image=square_image) * 255.0
        return img.astype(np.uint8).transpose((1, 2, 0))
    else:
        raise ValueError('expected input image dtype is one of uint8, float32, '
                         'and float64, received dtype {}'.format(str(img.dtype)))
def _parse_array(self, tensor_proto): """Grab data in TensorProto and convert to numpy array.""" try: from onnx.numpy_helper import to_array except ImportError: raise ImportError("Onnx and protobuf need to be installed. " + "Instructions to install - https://github.com/onnx/onnx") if len(tuple(tensor_proto.dims)) > 0: np_array = to_array(tensor_proto).reshape(tuple(tensor_proto.dims)) else: # If onnx's params are scalar values without dims mentioned. np_array = np.array([to_array(tensor_proto)]) return nd.array(np_array)
def _make_sprite_image(images, save_path):
    """Save a batch of images as one sprite image named ``sprite.png``.

    The batch is laid out with ``ceil(sqrt(N))`` images per row and no
    padding, following the rule defined in
    https://www.tensorflow.org/programmers_guide/embedding. The result is
    passed to ``_save_image`` with the path ``os.path.join(save_path, 'sprite.png')``.

    Parameters
    ----------
    images : :class:`~mxnet.ndarray.NDArray` or ``numpy.ndarray``
        Batch of images; its first axis is the number of images N. numpy input
        is converted to an NDArray on the current context first.
    save_path : str
        Path under which ``sprite.png`` is written.

    Raises
    ------
    TypeError
        If ``images`` is neither an MXNet NDArray nor a numpy.ndarray.
    """
    if isinstance(images, np.ndarray):
        images = nd.array(images, dtype=images.dtype, ctx=current_context())
    if not isinstance(images, NDArray):
        raise TypeError('images must be an MXNet NDArray or numpy.ndarray,'
                        ' while received type {}'.format(str(type(images))))
    # ceil(sqrt(N)) images per row gives a roughly square sprite.
    images_per_row = int(np.ceil(np.sqrt(images.shape[0])))
    sprite_file = os.path.join(save_path, 'sprite.png')
    _save_image(images, sprite_file, nrow=images_per_row, padding=0)