def _rewrite(cls, node: OnnxNode):
    """Normalize Split attributes: default axis, tuple-ize split, record output count."""
    node.attrs.setdefault('axis', 0)
    if 'split' in node.attrs:
        # Coerce the (possibly list/array) split sizes to an immutable tuple.
        node.attrs['split'] = tuple(node.attrs['split'])
    # Number of outputs the node produces; downstream code needs it as an attr.
    node.attrs['n_out'] = node.len_outputs
def _rewrite(cls, node: OnnxNode):
    """Fill in missing pooling attributes with their spec defaults."""
    for key, default in (('auto_pad', 'NOTSET'),
                         ('ceil_mode', 0),
                         ('storage_order', 0)):
        node.attrs.setdefault(key, default)
def _rewrite(cls, node: OnnxNode):
    """Fill in missing ArgMax/ArgMin attributes with their spec defaults."""
    for key, default in (('axis', 0),
                         ('keepdims', 1),
                         ('select_last_index', 0)):
        node.attrs.setdefault(key, default)
def _rewrite(cls, node: OnnxNode):
    """Default 'axis' to None; coerce a sequence-valued axis to a tuple."""
    if 'axis' not in node.attrs:
        node.attrs['axis'] = None
        return
    axis = node.attrs.get('axis')
    # Scalars (ints) pass through untouched; only sequences are tuple-ized.
    if isinstance(axis, Sequence):
        node.attrs['axis'] = tuple(axis)
def _rewrite(cls, node: OnnxNode):
    """Normalize reduce-op attrs: tuple-ize axes, default keepdims/noop flags."""
    if 'axes' in node.attrs:
        node.attrs['axes'] = tuple(node.attrs['axes'])
    node.attrs.setdefault('keepdims', 1)
    node.attrs.setdefault('noop_with_empty_axes', 0)
def _rewrite(cls, node: OnnxNode):
    """Fill in missing Gemm attributes with their spec defaults."""
    for key, default in (('alpha', 1.0),
                         ('beta', 1.0),
                         ('transA', 0),
                         ('transB', 0)):
        node.attrs.setdefault(key, default)
def _rewrite(cls, node: OnnxNode):
    """Fill in missing AveragePool-style attributes with their defaults."""
    for key, default in (('auto_pad', 'NOTSET'),
                         ('ceil_mode', 0),
                         ('count_include_pad', 0),
                         ('pads', None),
                         ('strides', None)):
        node.attrs.setdefault(key, default)
def _rewrite(cls, node: OnnxNode):
    """Normalize Conv attrs: default group, convert ONNX pads to per-dim triples.

    ONNX stores spatial pads as ``[x1_begin, ..., xk_begin, x1_end, ..., xk_end]``.
    This is rewritten to one ``(begin, end, interior)`` triple per input
    dimension, with zero padding for the leading batch and channel dims.

    Fix/generalization: the previous code hard-coded k == 2 spatial dims
    (indices pads[0..3]), so a 1-D conv raised IndexError and a 3-D conv was
    silently truncated. The rewrite derives k from len(pads) and is
    byte-for-byte equivalent to the old behavior when len(pads) == 4.
    """
    node.attrs.setdefault('group', 1)
    if 'pads' in node.attrs:
        pads = node.attrs['pads']
        k = len(pads) // 2  # number of spatial dimensions
        node.attrs['pads'] = tuple(
            [(0, 0, 0), (0, 0, 0)]          # batch dim, channel dim: no padding
            + [(pads[i], pads[i + k], 0) for i in range(k)]
        )
def _rewrite(cls, node: OnnxNode):
    """Normalize Pad attributes across opset versions."""
    node.attrs.setdefault('mode', 'constant')
    node.attrs.setdefault('constant_value', 0.0)
    # opset-v1 spelled the attribute 'paddings'
    if 'paddings' in node.attrs:
        node.attrs['pads'] = node.attrs['paddings']
    # opset-v1 & v2 spelled the fill value 'value'
    if 'value' in node.attrs:
        node.attrs['constant_value'] = node.attrs['value']
def run_node(cls, node, inputs, device='CPU', **kwargs):
    """Compile a single ONNX node and run it on the given inputs.

    Returns a list of output arrays (single outputs are wrapped in a list).
    """
    onnx_node = OnnxNode(node)
    jit_func = cls._jit(onnx_node, **kwargs)
    args = [jnp.asarray(x) for x in inputs]
    # TODO support uncertain number inputs, like concat
    result = jit_func(*args, *onnx_node.attrs_list)
    if isinstance(result, Sequence):
        return result
    return [result]
def _rewrite(cls, node: OnnxNode):
    """Normalize Dropout attributes; training-mode dropout is unsupported.

    Forces inference behavior: a non-zero ratio is downgraded to 0.0 with a
    warning, and `return_mask` records whether the node declares a second
    (mask) output.

    Raises:
        NotImplementedError: if the node requests training-mode dropout.
    """
    training_mode = node.attrs.get('training_mode', False)
    if training_mode:
        # Bug fix: the original raised `NotImplemented(...)`, but
        # NotImplemented is a non-callable sentinel, which produced a
        # confusing TypeError instead of the intended exception.
        raise NotImplementedError('Dropout training mode')
    ratio = node.attrs.get('ratio', 0.0)
    if ratio != 0.0:
        logger.warning(f"Dropout, change ratio from {ratio:.4f} to 0.0")
    # True iff the node also outputs the dropout mask.
    node.attrs['return_mask'] = node.len_outputs == 2
def _rewrite(cls, node: OnnxNode):
    """Replace the Cast 'to' dtype enum with the matching JAX dtype."""
    node.attrs['to'] = TENSOR_TYPE_TO_JNP_TYPE[node.attrs['to']]
def _rewrite(cls, node: OnnxNode):
    """Default 'alpha' to 1.0 when absent."""
    node.attrs.setdefault('alpha', 1.0)
def _rewrite(cls, node: OnnxNode):
    """Default 'broadcast' to False when absent."""
    node.attrs.setdefault('broadcast', False)
def _rewrite(cls, node: OnnxNode):
    """Coerce an 'axes' attribute, if present, to an immutable tuple."""
    if 'axes' in node.attrs:
        node.attrs['axes'] = tuple(node.attrs['axes'])
def _rewrite(cls, node: OnnxNode):
    """Default 'epsilon' to 1e-5 when absent."""
    node.attrs.setdefault('epsilon', 1e-5)
def _rewrite(cls, node: OnnxNode):
    """Fill in missing HardSigmoid-style defaults (alpha=0.2, beta=0.5)."""
    for key, default in (('alpha', 0.2), ('beta', 0.5)):
        node.attrs.setdefault(key, default)
def _rewrite(cls, node: OnnxNode):
    """Default 'axis' to 0 when absent."""
    node.attrs.setdefault('axis', 0)
def _rewrite(cls, node: OnnxNode):
    """Default 'mode' to 'nearest' when absent."""
    node.attrs.setdefault('mode', 'nearest')
def _rewrite(cls, node: OnnxNode):
    """Fill in the standard SELU constants when absent."""
    # Canonical SELU alpha/gamma values from the ONNX operator spec.
    node.attrs.setdefault('alpha', 1.6732632423543772848170429916717)
    node.attrs.setdefault('gamma', 1.0507009873554804934193349852946)
def run_model(cls, model, inputs, device='CPU', **kwargs):
    """Execute an ONNX model graph eagerly, node by node.

    Each node is JIT-compiled up front, then evaluated in graph order with a
    running tensor dictionary. Intermediate tensors are evicted from the
    dictionary as soon as every consumer has used them (reference counting
    against ``build_ref_dict``) to bound peak memory.

    Args:
        model: an ONNX ModelProto.
        inputs: either a dict mapping input names to arrays, or a sequence of
            arrays matched positionally against ``graph.input``.
        device: accepted for backend-API compatibility; not consumed by the
            code below.
        **kwargs: forwarded to ``cls._jit`` when compiling each node.

    Returns:
        A list of arrays, one per ``graph.output``, in declaration order.
    """
    def _asarray(proto):
        # Convert a TensorProto initializer to a jnp array shaped per proto.dims.
        return jnp.asarray(
            numpy_helper.to_array(proto).reshape(tuple(proto.dims)))

    # Total number of consumers for each tensor name; used for eviction below.
    tensor_ref_dict = build_ref_dict(model)
    graph = model.graph
    # Models older than IR version 3 carry no opset_import; assume opset 1.
    if model.ir_version < 3:
        opset = [make_opsetid(defs.ONNX_DOMAIN, 1)]
    else:
        opset = model.opset_import
    # Seed the tensor dict with user inputs plus graph initializers
    # (initializers win on name collisions because they are merged last —
    # NOTE(review): confirm that override order is intended).
    if isinstance(inputs, dict):
        tensor_dict = dict(
            {k: v for k, v in inputs.items()},
            **{n.name: _asarray(n) for n in graph.initializer},
        )
    else:
        graph_inputs = [x.name for x in graph.input]
        tensor_dict = dict(
            {k: v for k, v in zip(graph_inputs, inputs)},
            **{n.name: _asarray(n) for n in graph.initializer},
        )
    # Pre-compile every node before execution.
    jit_funcs = {}
    onnx_nodes = {}
    handlers = cls._get_handlers(opset)
    for idx, node in enumerate(graph.node):
        onnx_node = OnnxNode(node)
        jit_func = cls._jit(onnx_node, handlers=handlers, **kwargs)
        # in some early onnx versions, node has no name
        if node.name == '':
            node.name = f"{node.output[0]}"
        jit_funcs[node.name] = jit_func
        onnx_nodes[node.name] = onnx_node
    # Execute in graph order, tracking how many times each input has been
    # consumed so fully-consumed intermediates can be dropped.
    ref_dict = {}
    for node in graph.node:
        onnx_node = onnx_nodes[node.name]
        logger.info(f"running: {node.op_type}, {node.name}")
        node_inputs = [tensor_dict[x] for x in node.input]
        jit_func = jit_funcs[node.name]
        outputs = jit_func(*node_inputs, *onnx_node.attrs_list)
        # Normalize single outputs to a list so zip() below always works.
        outputs = outputs if isinstance(outputs, Sequence) else [outputs]
        for name, output in zip(node.output, outputs):
            tensor_dict[name] = output
        node_input_shapes = [tensor_dict[x].shape for x in node.input]
        node_output_shapes = [
            tensor_dict[x].shape for x in node.output
        ]
        logger.info(f"\t{node_input_shapes} -> {node_output_shapes}")
        # Bump the consumption count for every input this node just used.
        for input_ in node.input:
            if input_ in ref_dict:
                ref_dict[input_] += 1
            else:
                ref_dict[input_] = 1
        # Evict tensors whose consumers have all run (count reached total).
        remove_keys = []
        for k, v in ref_dict.items():
            if tensor_ref_dict[k] == v:
                remove_keys.append(k)
        for rm_k in remove_keys:
            del ref_dict[rm_k]
            del tensor_dict[rm_k]
    return [tensor_dict[n.name] for n in graph.output]
def _rewrite(cls, node: OnnxNode):
    """Default 'fmod' to 0 when absent."""
    node.attrs.setdefault('fmod', 0)
def _rewrite(cls, node: OnnxNode):
    """Default Clip bounds to the full float32 range when absent."""
    f32 = jnp.finfo(jnp.float32)
    node.attrs.setdefault('min', f32.min)
    node.attrs.setdefault('max', f32.max)