def test_combine2():
    dim1 = Dim.unnamed((1, 12800, 2))
    dim2 = Dim.unnamed((1, 3200, 2))
    dim3 = Dim.unnamed((1, 800, 2))
    dim4 = Dim.unnamed((1, 200, 2))
    res = Dim.combine((dim1, dim2, dim3, dim4), 1)
    assert res.shape == [1, 17000, 2]
def test_combine1():
    dim1 = Dim.named_ordered(a=1, c=3, b=2)
    dim2 = Dim.named_ordered(a=1, c=3, b=2)
    dim3 = Dim.combine((dim1, dim2), 'c')
    assert dim3.shape == [1, 6, 2]
    dim3.c = 4
    assert dim1.c == 3 and dim2.c == 3
def _common(cls, node, scales, sizes, nearest_mode='round_prefer_ceil', **kwargs):
    all_nodes = kwargs['all_nodes']
    G = kwargs['G']
    valid_name = kwargs['valid_name']
    inputs = [all_nodes[inp] if inp else None for inp in node.input]
    x = inputs[0]
    x_shape = x[2].shape
    x_rank = len(x_shape)
    spatial_size = x_rank - 2
    in_c = x_shape[1]
    in_w = x_shape[-1]
    if scales is not None:
        # derive integer output sizes from the scale factors, leaving
        # unknown (None) dimensions unknown
        sizes = [None if dim is None else int(dim * scale)
                 for dim, scale in zip(x_shape, scales)]
    if spatial_size == 1:
        sizes.insert(-1, 1)
    if nearest_mode != 'round_prefer_ceil':
        logger.warning('only round_prefer_ceil is supported for nearest mode')
    if spatial_size not in (1, 2):
        raise ValueError('resize only supports 4D tensor in NCHW mode or 3D tensor in NCF mode'
                         f' - input shape is {x_shape} sizes is {sizes}')
    if not all(x_dim == size_dim for x_dim, size_dim in zip(x_shape[:2:], sizes[:2:])):
        raise ValueError('resize only supports 4D tensor in NCHW mode or 3D tensor in NCF mode'
                         f' - input shape is {x_shape} sizes is {sizes}')
    mode = node.attrs.get('mode', 'nearest')
    if mode not in ('nearest', 'linear'):
        raise ValueError('resize only supports nearest and linear modes')
    params_class = (BilinearResizerParameters if mode == 'linear'
                    else NearestNeighborResizerParameters)
    params = params_class(valid_name,
                          new_shape=tuple(sizes[2::]),
                          align_corners=False,
                          halfpixel_centers=False,
                          in_dims_hint=[['c', 'h', 'w']],
                          out_dims_hint=[['c', 'h', 'w']])
    if spatial_size == 1:
        # a 1D resize is done as a 2D resize with a unit H dimension
        r1_params = ReshapeParameters(f'{valid_name}_reshape2d',
                                      old_shape=Dim.unnamed([in_c, in_w]),
                                      shape=Dim.unnamed([in_c, 1, in_w]))
        r2_params = ReshapeParameters(f'{valid_name}_reshape1d',
                                      old_shape=Dim.unnamed([in_c, 1, sizes[-1]]),
                                      shape=Dim.unnamed([in_c, sizes[-1]]))
        G.add_edge(NNEdge(from_node=x[0], to_node=r1_params, from_idx=x[1], to_idx=0))
        G.add_edge(NNEdge(from_node=r1_params, to_node=params, from_idx=0, to_idx=0))
        G.add_edge(NNEdge(from_node=params, to_node=r2_params, from_idx=0, to_idx=0))
        pout_dims = ProvisionalDim(sizes[:-2:] + sizes[-1::])
        params = r2_params
    else:
        pout_dims = ProvisionalDim(sizes)
        G.add_edge(NNEdge(from_node=x[0], to_node=params, from_idx=x[1], to_idx=0))
    all_nodes[node.output[0]] = (params, 0, pout_dims)
    return params
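# A small standalone sketch (assumed example values, not part of the handler)
# of the scale-to-size conversion above: ONNX Resize supplies one scale per
# axis, the output size of each known axis is the input size times its scale,
# and unknown (None) axes stay unknown.
def _sketch_resize_sizes():
    x_shape = [None, 3, 224, 224]          # N (unknown), C, H, W
    scales = [1.0, 1.0, 0.5, 0.5]
    sizes = [None if dim is None else int(dim * scale)
             for dim, scale in zip(x_shape, scales)]
    assert sizes == [None, 3, 112, 112]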
def test_operation2():
    dim1 = Dim.named_ordered(a=1, c=3, b=2)
    dim2 = Dim.named_ordered(a=1, c=3, b=2)
    dim3 = dim1 - dim2
    assert dim3.is_named
    assert dim3.is_ordered
    assert dim3.size() == 0
def _common(cls, node, **kwargs):
    all_nodes = kwargs['all_nodes']
    G = kwargs['G']
    valid_name = kwargs['valid_name']
    inputs = [all_nodes[inp] for inp in node.input]
    x = inputs[0]
    x_shape = cls._get_real_dim(x[2].shape)
    y = inputs[1]
    y_shape = cls._get_real_dim(y[2].shape)
    if cls.is_linear(y, x_shape, y_shape):
        # a matmul with a constant second operand is a fully connected
        # layer; the constant is stored transposed as the FC weights
        filt_dim = FcFilterDim(y_shape[1], x_shape[0])
        weights = np.transpose(cls.get_constant(y), [1, 0])
        params = FcParameters(valid_name, filt=filt_dim, has_bias=False,
                              in_dims_hint=SparseList([['c']]),
                              out_dims_hint=SparseList([['c']]),
                              constant_store=G.constant_store)
        params.weights = weights
        out_dims = params.get_output_size([Dim.unnamed(x_shape)])
    else:
        params = MatMulOpParameters(valid_name)
        out_dims = params.get_output_size(
            [Dim.unnamed(x_shape), Dim.unnamed(y_shape)])
        G.add_edge(
            NNEdge(from_node=y[0], to_node=params, from_idx=y[1], to_idx=1))
    G.add_edge(
        NNEdge(from_node=x[0], to_node=params, from_idx=x[1], to_idx=0))
    pout_dims = x[2].infer_mapping(out_dims[0].shape)
    all_nodes[node.output[0]] = (params, 0, pout_dims)
    return params
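# A minimal numpy sketch (standalone, not part of the importer) of why the
# constant operand is transposed above. Assuming the FC layer computes
# W @ x with W of shape [out_c, in_c], while MatMul computes x @ Y with Y of
# shape [in_c, out_c], setting W = Y.T gives identical results for a 1D x.
def _sketch_fc_weight_transpose():
    import numpy as np
    x = np.arange(3, dtype=np.float32)                  # input, in_c = 3
    y = np.arange(6, dtype=np.float32).reshape(3, 2)    # constant, [in_c, out_c]
    assert np.array_equal(x @ y, np.transpose(y, [1, 0]) @ x)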
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    rnn_nodes = [self.find_unpack(G, node)
                 for node in G.nodes()
                 if isinstance(node, RNNBaseParameters) and node.n_output_cells > 1]
    rnn_nodes_by_slice = self.validate_slices(G, rnn_nodes)
    rnn_nodes_by_slice = self.validate_multi_branch(G, rnn_nodes_by_slice)
    if not rnn_nodes_by_slice:
        return False

    for unpack_node, rnn_unpacks in rnn_nodes_by_slice.items():
        modified_nodes = set()
        for rnn_unpack in rnn_unpacks:
            self.process_path(G, rnn_unpack, modified_nodes)
        # since process_path will have removed all unnecessary nodes
        # the edges will be correct here
        out_edges = G.out_edges(unpack_node.name)
        in_edges = G.in_edges(unpack_node.name)
        assert len(in_edges) == 1, "expecting unpack node to have only one in edge"
        in_edge = in_edges[0]
        changes_shape = (unpack_node.changes_shape
                         if isinstance(unpack_node, StridedSliceParameters)
                         else False)
        LOG.info("Eliminating last cell unpack: %s", unpack_node.name)
        G.remove(unpack_node)

        # Here the strided slice can change the output shape of the RNN
        # so insert a reshape to do the shape change
        if changes_shape:
            reshape = ReshapeParameters(unpack_node.name + '_reshape',
                                        old_shape=Dim.unnamed(unpack_node.post_slice_shape),
                                        shape=Dim.unnamed(unpack_node.out_shape))
            G.add_edge(NNEdge(from_node=in_edge.from_node,
                              to_node=reshape,
                              from_idx=in_edge.from_idx))
            for out_edge in out_edges:
                G.add_edge(NNEdge(from_node=reshape,
                                  to_node=out_edge.to_node,
                                  to_idx=out_edge.to_idx))
            if G.quantization:
                G.quantization[NodeId(reshape)] = G.quantization[NodeId(unpack_node)]
        else:
            for out_edge in out_edges:
                G.add_edge(NNEdge(from_node=in_edge.from_node,
                                  to_node=out_edge.to_node,
                                  from_idx=in_edge.from_idx,
                                  to_idx=out_edge.to_idx))
            if G.quantization:
                del G.quantization[NodeId(unpack_node)]

    if set_identity:
        self.set_identity(G)
    return True
def _get_initializers(self, initializer):
    return {
        init.name: (ConstantInputParameters(self._validate_name(init.name),
                                            dims=Dim.unnamed(init.dims or [1]),
                                            value=self._get_numpy_array(init)),
                    0,
                    Dim.unnamed(init.dims))
        for init in initializer
    }
def get_output_size(self, in_dims):
    num_detected_boxes = (self._parameters['max_detections'] *
                          self._parameters['max_classes_per_detection'])
    return [
        Dim(shape=[num_detected_boxes, 4], is_ordered=True),
        Dim(shape=[num_detected_boxes], is_ordered=True),
        Dim(shape=[num_detected_boxes], is_ordered=True),
        Dim(shape=[num_detected_boxes], is_ordered=True),
    ]
def __init__(self, *args, old_shape=None, shape=None, **kwargs):
    super(ReshapeParameters, self).__init__(*args, **kwargs)
    if not isinstance(shape, Dim):
        shape = Dim.unnamed(shape)
    if old_shape is not None and not isinstance(old_shape, Dim):
        old_shape = Dim.unnamed(old_shape)
    assert shape.is_ordered and (old_shape is None or old_shape.is_ordered)
    self._shape = shape
    self._old_shape = old_shape
def _common(cls, node: TFLiteNode, **kwargs):
    node_opts = node.get_options(StridedSliceOptions)
    G = kwargs['G']
    opts = kwargs['opts']
    all_nodes = kwargs['all_nodes']
    inputs = [all_nodes[t] for t in node.input]
    x = inputs[0]
    x_shape = x[2].shape
    # begin end stride
    vec_begin = list(cls._verify_constant(inputs[1]))
    vec_end = list(cls._verify_constant(inputs[2]))
    vec_stride = list(cls._verify_constant(inputs[3]))
    for i in range(1, 4):
        node.input[i].used = True
    if any(vec is None for vec in [vec_begin, vec_end, vec_stride]):
        raise NotImplementedError(
            "strided slice with variable begin end or stride is not supported")
    spec = zip(vec_begin, vec_end, vec_stride)
    begin_mask = node_opts.BeginMask()
    ellipsis_mask = node_opts.EllipsisMask()
    end_mask = node_opts.EndMask()
    new_axis_mask = node_opts.NewAxisMask()
    shrink_axis_mask = node_opts.ShrinkAxisMask()
    act_slice, out_shape, can_reshape = StridedSliceParameters.get_slice(
        x_shape, spec, begin_mask, end_mask, ellipsis_mask, new_axis_mask,
        shrink_axis_mask)
    if cls.is_constant(x):
        LOG.info("reducing %s to a constant", node.name)
        x_val = cls.get_constant(x)
        params = StridedSliceParameters(node.name, act_slice=act_slice,
                                        out_shape=out_shape)
        x_val = params.numpy_slice(x_val)
        params = ConstantInputParameters(node.name, value=x_val)
    else:
        if can_reshape:
            if list(x_shape) == list(out_shape):
                LOG.info("converting strided slice %s to a noop", node.name)
                params = NoOPParameters(node.name)
            else:
                LOG.info("converting strided slice %s to a reshape", node.name)
                in_shape = Dim.unnamed(x[2].known_shape, is_ordered=True)
                out_shape = Dim.unnamed(out_shape, is_ordered=True)
                params = ReshapeParameters(node.name, old_shape=in_shape,
                                           shape=out_shape)
        else:
            params = StridedSliceParameters(node.name, act_slice=act_slice,
                                            out_shape=out_shape)
        G.add_edge(NNEdge(from_node=x[0], to_node=params, from_idx=x[1], to_idx=0))
    if opts.get('load_quantization'):
        G.quantization[NodeId(params)] = cls.load_tf_quantization(
            [node.input[0]], node.output)
    all_nodes[node.output[0]] = (params, 0,
                                 x[2].infer_mapping(out_shape, allow_bad_length=True))
    return params
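# A standalone numpy sketch (hypothetical helper, not part of the handler) of
# why a `can_reshape` slice reduces to a reshape: when every axis is either
# kept whole or shrunk away, the slice only drops unit axes, so the data is
# unchanged and a reshape produces the same tensor.
def _sketch_slice_as_reshape():
    import numpy as np
    x = np.arange(8).reshape(1, 2, 4)
    sliced = x[0, :, :]            # shrink axis 0: a strided slice
    reshaped = x.reshape(2, 4)     # same elements in the same order
    assert np.array_equal(sliced, reshaped)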
def test_operation1():
    dim1 = Dim.named_ordered(a=1, c=3, b=2)
    dim2 = Dim.named_ordered(a=1, c=3, b=2)
    dim3 = dim1 + dim2
    assert dim3.is_named
    assert dim3.is_ordered
    assert dim3.a == 2 and dim3.b == 4 and dim3.c == 6
    assert dim3.shape == [2, 6, 4]
    dim3.a = 2
    assert dim1.a == 1 and dim2.a == 1
def _common(cls, node, **kwargs):
    all_nodes = kwargs['all_nodes']
    G = kwargs['G']
    valid_name = kwargs['valid_name']
    inputs = [all_nodes[inp] for inp in node.input]
    if cls.SINCE_VERSION == 1:
        shape = np.array(node.attrs["shape"])
    else:  # since_version >= 5
        shape = cls.get_constant(inputs[1])

    input_shape = np.array(inputs[0][2].shape)
    # a 0 in the new shape means "copy the corresponding input dimension"
    shape = [dim if dim != 0 else input_shape[idx]
             for idx, dim in enumerate(shape)]

    # a -1 is a wildcard that absorbs whatever is left of the input size
    if -1 in shape:
        wild_index = shape.index(-1)
        in_size = prod([1 if dim is None else dim for dim in input_shape])
        shape_size = prod([1 if dim is None or dim <= 0 else dim for dim in shape])
        if in_size % shape_size != 0:
            raise ValueError('invalid reshape')
        shape[wild_index] = in_size // shape_size
    shape = np.array(shape)

    if cls.is_constant(inputs[0]):
        logger.info("reducing %s to a constant", valid_name)
        params = ConstantInputParameters(valid_name,
                                         value=cls.get_constant(inputs[0]).reshape(shape),
                                         dims=Dim.unnamed(shape),
                                         constant_store=G.constant_store)
        pshape = ProvisionalDim(shape)
        all_nodes[node.output[0]] = (params, 0, pshape)
        return params

    # TODO - There must be a better way of doing this
    # This hacks around the fact that the batch dimension will be in the reshape
    if input_shape[0] is None and shape[0] == 1:
        shape = np.array([None] + list(shape[1::]))

    pshape = ProvisionalDim(shape)
    # pylint: disable=singleton-comparison
    old_shape = Dim.unnamed(list(input_shape[input_shape != None]))
    shape = Dim.unnamed(list(shape[shape != None]))
    params = ReshapeParameters(valid_name, old_shape=old_shape, shape=shape)
    inp = inputs[0]
    G.add_edge(NNEdge(from_node=inp[0], to_node=params, from_idx=inp[1], to_idx=0))
    all_nodes[node.output[0]] = (params, 0, pshape)
    return params
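# A minimal sketch (standalone, hypothetical names) of the -1 wildcard
# arithmetic above: the wildcard dimension is the total element count divided
# by the product of the known target dimensions, and the division must be
# exact for the reshape to be valid.
def _sketch_reshape_wildcard():
    from functools import reduce
    from operator import mul
    input_shape = [4, 6, 2]          # 48 elements
    target = [8, -1]                 # -1 must become 48 // 8 == 6
    in_size = reduce(mul, input_shape, 1)
    known = reduce(mul, (d for d in target if d > 0), 1)
    assert in_size % known == 0
    resolved = [in_size // known if d == -1 else d for d in target]
    assert resolved == [8, 6]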
def test_concat():
    inputs = [np.full([1, 2, 2], 1.0), np.full([2, 2, 2], 2.0)]
    in_dims = [
        Dim.named(c=1, h=2, w=2).impose_order(['c', 'h', 'w']),
        Dim.named(c=2, h=2, w=2).impose_order(['c', 'h', 'w'])
    ]
    params = ConcatParameters("test", axis=0)
    out_dims = params.get_output_size(in_dims)
    output_ = concat(params, in_dims, out_dims[0], inputs)
    assert isinstance(output_, np.ndarray) and np.array_equal(
        output_, np.concatenate(inputs, 0))
def rgb565_rgb888(input_tensor: np.ndarray, in_dim: Dim, out_dim: Dim):
    assert in_dim.is_named and in_dim.c == 1 and out_dim.is_named and out_dim.c == 3
    # expand the single packed 16 bit channel to three channels in HWC order
    input_tensor = np.repeat(input_tensor.transpose(
        in_dim.transpose_to_order(("h", "w", "c"))), 3, axis=2)
    # RGB565 packs red in bits 15-11, green in bits 10-5 and blue in bits 4-0;
    # each field is shifted up to fill the top of an 8 bit channel. Green and
    # blue must be extracted first since channel 0 still holds the packed
    # value until it is overwritten with red.
    input_tensor[:, :, 1] = (input_tensor[:, :, 0] & (63 << 5)) >> 3
    input_tensor[:, :, 2] = (input_tensor[:, :, 0] & 31) << 3
    input_tensor[:, :, 0] = (input_tensor[:, :, 0] & (31 << 11)) >> 8
    return input_tensor.astype(np.uint8).transpose(
        out_dim.transpose_from_order(("h", "w", "c")))
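# A self-contained check (assumed example values, not part of the library) of
# the bit arithmetic above: pure red in RGB565 is 0xF800 and should unpack to
# (248, 0, 0), since only the top 5 of the 8 red bits can be populated.
def _sketch_rgb565_fields():
    import numpy as np
    pixel = np.uint16(0xF800)                  # red=31, green=0, blue=0
    red = (pixel & (31 << 11)) >> 8            # 5 bits -> top of 8 bits
    green = (pixel & (63 << 5)) >> 3           # 6 bits -> top of 8 bits
    blue = (pixel & 31) << 3                   # 5 bits -> top of 8 bits
    assert (red, green, blue) == (248, 0, 0)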
def __init__(self, node, in_shape=None, out_shape=None, **kwargs) -> None:
    super(InsertReshapeAction, self).__init__(node, **kwargs)
    assert in_shape is not None and out_shape is not None, 'find test'
    if isinstance(in_shape, (list, tuple)):
        self.in_shape = Dim.unnamed(in_shape)
    else:
        self.in_shape = in_shape.clone() if in_shape is not None else None
    if isinstance(out_shape, (list, tuple)):
        self.out_shape = Dim.unnamed(out_shape)
    else:
        self.out_shape = out_shape.clone() if out_shape is not None else None
def test_paddim():
    dim1 = PadDim(1)
    assert not dim1.is_same
    assert dim1.h == 2 and dim1.w == 2
    assert dim1.l == 1 and dim1.r == 1 and dim1.t == 1 and dim1.b == 1
    assert dim1.numpy_pad_shape(Dim.named_ordered(w=10, h=10)) == [(1, 1), (1, 1)]
    stride_dim = StrideDim(1)
    filt_dim = Conv2DFilterDim(5, 5, 1, 1)
    in_dim = Dim.named_ordered(c=1, h=20, w=20)
    dim1 = PadDim.same()
    dim1.calculate_same(in_dim, filt_dim, stride_dim)
    assert dim1.shape == [2, 2, 2, 2]
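# A worked example of the SAME-padding arithmetic checked above (standalone
# sketch using the standard formula, not the library's internals): total
# padding per axis is (out - 1) * stride + filter - in, and SAME padding with
# stride 1 keeps out == in, so a 5x5 filter on a 20x20 input needs 4 pixels
# per axis, split evenly as top/bottom/left/right = 2/2/2/2.
def _sketch_same_padding():
    in_size, filt, stride = 20, 5, 1
    out_size = in_size  # SAME padding with stride 1 preserves the size
    total = (out_size - 1) * stride + filt - in_size
    assert total == 4 and (total // 2, total - total // 2) == (2, 2)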
def test_concat_q():
    in_q = QType(16, 1, True)
    inputs = [
        in_q.quantize(np.full([1, 2, 2], 1.0)),
        in_q.quantize(np.full([2, 2, 2], 2.0))
    ]
    in_dims = [
        Dim.named(c=1, h=2, w=2).impose_order(['c', 'h', 'w']),
        Dim.named(c=2, h=2, w=2).impose_order(['c', 'h', 'w'])
    ]
    params = ConcatParameters("test", axis=0)
    out_dims = params.get_output_size(in_dims)
    output_ = concat(params, in_dims, out_dims[0], inputs)
    assert np.array_equal(output_, np.concatenate(inputs, 0))
def _execute(self, node, G):
    info(f"{self}")
    direction = self.direction
    if self.reshape_from is not None:
        params = ReshapeParameters(G.unique_name(f'{node.name}_reshape'),
                                   old_shape=Dim.unnamed(self.reshape_from),
                                   shape=Dim.unnamed(self.reshape_to))
        self.do_insert(node, G, params, direction=direction)
        node = params
        direction = "out"
    params = TransposeParameters(G.unique_name(f'{node.name}_trans'),
                                 transpose=self.transpose)
    self.do_insert(node, G, params, direction=direction)
def _common(cls, node, **kwargs):
    all_nodes = kwargs['all_nodes']
    G = kwargs['G']
    valid_name = kwargs['valid_name']
    inputs = [all_nodes[inp] for inp in node.input]
    x = inputs[0]
    x_shape = x[2].shape
    to_dtype = node.attrs['to']
    if cls.is_constant(x):
        x_val = cls.get_constant(x)
        x_val = x_val.astype(to_dtype)
        if x_val.size < 10:
            logger.info("reducing %s to a constant %s", valid_name, x_val)
        else:
            logger.info("reducing %s to a constant", valid_name)
        params = ConstantInputParameters(valid_name,
                                         dims=Dim.unnamed(x_val.shape),
                                         value=x_val)
    else:
        params = QuantizeParameters(valid_name, to_qtype=QType(dtype=to_dtype))
        G.add_edge(
            NNEdge(from_node=x[0], to_node=params, from_idx=x[1], to_idx=0))
    all_nodes[node.output[0]] = (params, 0, ProvisionalDim(x_shape), None)
    return params
def get_output_size(self, in_dims):
    if self.indicated_outputs:
        return self.indicated_outputs
    self.in_dims = self.clone_dim_with_hints(in_dims)
    if len(self.in_dims) == 1:
        return [self.in_dims[0]]
    return [Dim.unknown()]
def __init__(self, *args, old_shape=None, shape=None, **kwargs):
    super(ReshapeParameters, self).__init__(
        *args,
        eliminate_transposes_pass_down=True,
        eliminate_transposes_pass_up=True,
        **kwargs)
    if not isinstance(shape, Dim):
        shape = Dim.unnamed(shape)
    self._shape = shape
    self._old_shape = old_shape
def get_all_output_dims(subgraph, elem, order=None):
    outputs = []
    for idx in range(elem.OutputsLength()):
        tf_idx = elem.Outputs(idx)
        outputs.append(
            Dim.unnamed(remove_batch_dim(get_shape(subgraph, tf_idx, order))))
    return outputs
def _common(cls, node, **kwargs):
    all_nodes = kwargs['all_nodes']
    G = kwargs['G']
    opts = kwargs['opts']
    qrec_class = kwargs.get('qrec_class')
    params_args = kwargs.get('params_args', {})
    constant_operation = kwargs.get('constant_operation')
    inputs = [all_nodes[inp] for inp in node.input]
    assert len(inputs) == 2
    if all(cls.is_constant(inp) for inp in inputs) and constant_operation:
        LOG.info("reducing %s to a constant", node.name)
        values = [cls.get_constant(inp) for inp in inputs]
        output_shapes = cls.implied_broadcast(inputs)
        params = ConstantInputParameters(node.name,
                                         value=constant_operation(*values),
                                         dims=Dim.unnamed(output_shapes[0].known_shape),
                                         constant_store=G.constant_store)
    else:
        params = kwargs['params_class'](node.name, **params_args)
        output_shapes = cls.implied_broadcast(inputs)
        shapes = []
        for idx, inp in enumerate(inputs):
            G.add_edge(NNEdge(from_node=inp[0], to_node=params,
                              from_idx=inp[1], to_idx=idx))
            shapes.append(inp[2].known_shape)
        if isinstance(params, Broadcastable):
            params.set_broadcast(shapes)
    if opts.get('load_quantization'):
        G.quantization[NodeId(params)] = cls.load_tf_quantization(
            node.input, node.output, qrec_class=qrec_class)
    all_nodes[node.output[0]] = (params, 0, output_shapes[0])
    return params
def _fix_constant_inputs(cls, inputs, shape):
    # TODO - This should be checked again
    # This fixes constant inputs to the broadcasted shape. This may not be
    # a good thing to do if the input is connected to more than one node
    # since the shape change could cause problems.
    # Two possible solutions:
    # 1) insert a reshape in between the constant and the broadcasted node
    # 2) make the broadcast node adjust more completely
    none_axes = tuple(idx for idx, dim in enumerate(shape) if dim is None)
    const_inputs = [inp for inp in inputs
                    if isinstance(inp[0], ConstantInputParameters)]
    if not const_inputs:
        return
    for inp in const_inputs:
        node = inp[0]
        # left-pad the value's shape with 1s up to the broadcast rank
        node.value = np.reshape(node.value,
                                [1] * (len(shape) - len(node.value.shape)) +
                                list(node.value.shape))
        if none_axes:
            node.value = np.squeeze(node.value, axis=none_axes)
        # setting the provisional shape here is a little dangerous:
        # if the unknown axis is first then it works, but if it is in the
        # middle and this value is connected to another node then it could
        # be problematic (but it is problematic anyway since the other node
        # won't expect something broadcasted)
        inp[2].shape = list(node.value.shape)
        node.dims = Dim.unnamed(node.value.shape)
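# A standalone numpy sketch (hypothetical example data) of the alignment
# performed above: a constant is left-padded with unit axes up to the
# broadcast rank, then any axis matching an unknown (None) dim is squeezed
# back out so the constant stays compatible with the known shape.
def _sketch_broadcast_alignment():
    import numpy as np
    shape = [None, 4, 3]                 # broadcast shape with unknown batch
    value = np.ones((3,))                # rank-1 constant
    value = value.reshape([1] * (len(shape) - value.ndim) + list(value.shape))
    assert value.shape == (1, 1, 3)
    none_axes = tuple(i for i, d in enumerate(shape) if d is None)
    value = np.squeeze(value, axis=none_axes)
    assert value.shape == (1, 3)         # the unknown batch axis is dropped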
def two_conv_graph():
    G = NNGraph(name='two_conv_graph')
    ti = G.add_input(Dim.unnamed([10, 10, 2]))
    c1filt = Conv2DFilterDim(3, 3, 2, in_c=2)
    c1filt.impose_order(['out_c', 'h', 'w', 'in_c'])
    n1 = Conv2DParameters("node1",
                          filt=c1filt,
                          stride=StrideDim(1, 1),
                          padding=PadDim(0),
                          in_dims_hint=SparseList([['h', 'w', 'c']]),
                          out_dims_hint=SparseList([['h', 'w', 'c']]))
    G.add_node(n1)
    w1 = [[0.25, 0.25], [0.25, 0.25], [0.25, 0.25]]
    w1 = [w1, w1, w1]
    w2 = [[0.75, 0.75], [0.75, 0.75], [0.75, 0.75]]
    w2 = [w2, w2, w2]
    n1.weights = np.array([w1, w2])
    c2filt = Conv2DFilterDim(3, 3, 2, in_c=2)
    c2filt.impose_order(['out_c', 'h', 'w', 'in_c'])
    n2 = Conv2DParameters("node2",
                          filt=c2filt,
                          stride=StrideDim(1, 1),
                          padding=PadDim(0),
                          in_dims_hint=SparseList([['h', 'w', 'c']]),
                          out_dims_hint=SparseList([['h', 'w', 'c']]))
    G.add_node(n2)
    w3 = [[0.75, 0.25], [0.75, 0.25], [0.75, 0.25]]
    w3 = [w3, w3, w3]
    n2.weights = np.array([w3, w3])
    to = G.add_output()
    G.add_edge(NNEdge(ti, n1))
    G.add_edge(NNEdge(n1, n2))
    G.add_edge(NNEdge(n2, to))
    G.add_dimensions()
    yield G
def __init__(self, *args,
             adjust_transpose=None,
             is_mutated=False,
             is_intermediate=False,
             always_copy=False,
             value: np.ndarray = None,
             qtype: QType = None,
             dims: Dim = None,
             **kwargs):
    if dims is None:
        dims = Dim.unnamed(value.shape)
    super(ConstantInputParameters, self).__init__(*args, dims=dims, **kwargs)
    self._value = value
    del self.at_options.valid_options['FIXED_ORDER']
    self.at_options.valid_options['RESET_NAME'] = str
    self._adjust_transpose = adjust_transpose
    self._is_mutated = is_mutated
    self._is_intermediate = is_intermediate
    self._is_constant = True
    self._is_global = True
    self._always_copy = always_copy
    self._use_fake = False
    self._use_compressed = False
    self._compressed_value = None
    self._qtype = qtype
def _common(cls, node: TFLiteNode, **kwargs):
    all_nodes = kwargs['all_nodes']
    G = kwargs['G']
    inputs = [all_nodes[t] for t in node.input]
    x = inputs[0]
    x_shape = x[2].shape
    if len(x_shape) != 1:
        raise ValueError(f'FILL {node.name} expecting 1D tensor for shape')
    shape = list(cls._verify_constant(inputs[0]))
    if cls._is_constant(inputs[1]):
        val = cls._get_constant(inputs[1])
        params = ConstantInputParameters(node.name,
                                         dims=Dim.unnamed(shape),
                                         value=np.full(shape, val),
                                         constant_store=G.constant_store)
        all_nodes[node.output[0]] = (params, 0, ProvisionalDim(shape))
        return params
    raise ValueError(
        f'FILL {node.name} non constant fill values are not currently supported')
def _common(cls, node, **kwargs):
    all_nodes = kwargs['all_nodes']
    G = kwargs['G']
    valid_name = kwargs['valid_name']
    if node.attrs.get('value'):
        value = numpy_helper.to_array(node.attrs['value'])
    elif node.attrs.get('value_float'):
        value = np.array([node.attrs['value_float']], dtype=np.float32)
    elif node.attrs.get('value_floats'):
        value = np.array(node.attrs['value_floats'], dtype=np.float32)
    elif node.attrs.get('value_int'):
        value = np.array([node.attrs['value_int']], dtype=np.int32)
    elif node.attrs.get('value_ints'):
        value = np.array(node.attrs['value_ints'], dtype=np.int32)
    elif node.attrs.get('value_string') or node.attrs.get('value_strings'):
        raise NotImplementedError('NNTOOL does not support string constants')
    elif node.attrs.get('sparse_value'):
        raise NotImplementedError('NNTOOL does not support sparse constants')
    else:
        raise ValueError('ONNX constant has no value')
    params = ConstantInputParameters(valid_name,
                                     dims=Dim.unnamed(value.shape),
                                     value=value)
    all_nodes[node.output[0]] = (params, 0, ProvisionalDim(value.shape), None)
    return params
def test_creation5():
    dim1 = Dim.named_ordered(a=1, c=3, b=2)
    assert not dim1.is_unknown
    assert dim1.is_named
    assert dim1.is_ordered
    assert dim1.a == 1 and dim1.b == 2 and dim1.c == 3
    assert dim1.shape == [1, 3, 2]
def match(self, G: GraphView, set_identity: bool = True):
    has_modified = False
    for node in G.nodes(node_classes=ConstantInputParameters):
        out_edges = G.out_edges(node.name)
        if len(out_edges) <= 1:
            continue
        has_modified = True
        LOG.info('node %s has more than one out edge and will be duplicated',
                 node.name)
        idx = 1
        for out_edge in out_edges[1::]:
            new_constant = ConstantInputParameters(f'{node.name}_{idx}',
                                                   dims=Dim.unnamed(node.dims.shape),
                                                   value=node.value.copy())
            G.remove_edge(out_edge)
            G.add_edge(NNEdge(from_node=new_constant,
                              to_node=out_edge.to_node,
                              to_idx=out_edge.to_idx))
            idx += 1
    if set_identity:
        self.set_identity(G)
    return has_modified