def match(self, G: GraphView, set_identity: bool = True):
    """Run each sub-matcher in ``self.matches`` once, in order.

    Dimensions are recomputed lazily. After a sub-matcher reports that it
    modified the graph, they are marked stale and recomputed only before
    the next sub-matcher that declares ``NEEDS_VALID_DIMENSION``. They are
    recomputed once more at the end if still stale, so the next matcher
    can rely on valid dimensions.

    Args:
        G: graph to match against; modified in place.
        set_identity: if True, call ``self.set_identity(G)`` at the end.

    Returns:
        True if any sub-matcher reported modifying the graph, else False.
    """
    # Note: assumption is that dimensions are valid when a match is called
    dimensions_set = True
    modified_any = False
    for match_instance in self.matches:
        if match_instance.NEEDS_VALID_DIMENSION and not dimensions_set:
            G.add_dimensions()
            dimensions_set = True
        if match_instance.match(G, set_identity=False):
            modified_any = True
            # the graph changed, so the current dimensions may be wrong
            dimensions_set = False
    if not dimensions_set:
        # restore the "dimensions are valid" invariant for the next matcher
        G.add_dimensions()
    if set_identity:
        self.set_identity(G)
    return modified_any
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Repeatedly run every sub-matcher until a full pass changes nothing.

    Each sub-matcher is called with ``set_identity=False`` and this group's
    identity. Whenever one reports a modification, the graph dimensions are
    recomputed immediately. Every subsequent matcher, including those that
    declare ``NEEDS_VALID_DIMENSION``, therefore always sees valid
    dimensions.

    Args:
        G: graph to match against; modified in place.
        set_identity: if True, call ``self.set_identity(G)`` at the end.
        **kwargs: accepted for interface compatibility; unused here.

    Returns:
        True if any sub-matcher modified the graph, else False.
    """
    # Note: assumption is that dimensions are valid when a match is called
    modified_any = False
    found_match = True
    while found_match:
        found_match = False
        for match_instance in self.matches:
            LOG.debug("fusions - start %s", match_instance.name)
            has_modified_graph = match_instance.match(
                G, set_identity=False, group_identity=self._identity)
            if has_modified_graph:
                LOG.info("++ fusion %s modified graph", match_instance.name)
                found_match = True
                modified_any = True
                # keep dimensions valid for every following matcher
                G.add_dimensions(quiet=True)
    if set_identity:
        self.set_identity(G)
    return modified_any
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Fold a constant add/multiply that follows a MatMul into the MatMul.

    Each ``MatMulOpParameters`` node is considered in turn. Its single
    consumer must be a ``MatrixAddParameters`` or ``MatrixMulParameters``
    whose other input is a ``ConstantInputParameters``. When that holds:

    * add: the flattened constant is added to the existing constant bias
      (input 2), or becomes a new bias input. If its length matches the
      second output dimension, the matmul is reversed first with
      ``reverse_matmul``.
    * mul: the constant operand (input 1 preferred, else input 0) and any
      bias are multiplied by the constant.

    The op node is removed and its consumers are rewired to the matmul. The
    constant is removed when the op was its only consumer. Folding repeats
    on the same matmul until no further fold applies.

    Args:
        G: graph to modify in place.
        set_identity: if True, call ``self.set_identity(G)`` at the end.
        **kwargs: accepted for interface compatibility; unused here.

    Returns:
        True if at least one fold was applied.

    Raises:
        ValueError: if a candidate matmul output is not rank 2.
    """

    def is_exclusive_const(node):
        # A constant read by exactly one node can have its value rewritten
        # in place without changing what any other node sees.
        return (isinstance(node, ConstantInputParameters)
                and len(G.out_edges(node.name)) == 1)

    has_modified_graph = False
    has_transposed = False
    for params in G.nodes(node_classes=MatMulOpParameters):
        while True:
            out_edges = G.out_edges(params.name)
            # can't fuse if there is a branch, and there is nothing to fuse
            # if the output is not consumed at all
            if len(out_edges) != 1:
                break
            out_edge = out_edges[0]
            op_node = out_edge.to_node
            # must be a valid matrix op
            if not isinstance(op_node, (MatrixAddParameters, MatrixMulParameters)):
                break
            # other edge to the op must be a constant
            other_idx = 1 if out_edge.to_idx == 0 else 0
            other_in_edge = G.indexed_in_edges(op_node.name)[other_idx]
            if not isinstance(other_in_edge.from_node, ConstantInputParameters):
                break
            const_node = other_in_edge.from_node
            # the constant may only be deleted if the op being folded is its
            # sole consumer; otherwise other nodes still read it
            remove_constant = len(G.out_edges(const_node.name)) == 1
            flat_value = const_node.dqvalue.flatten()
            out_shape = params.out_dims[0].shape
            if len(out_shape) != 2:
                raise ValueError(
                    f'strange outputs shape of {out_shape} for matmul {params.name}'
                )
            if len(flat_value) not in (out_shape[0], out_shape[1]):
                LOG.info(
                    "can't fuse %s into %s - value shape is not correct for bias",
                    const_node.name, params.name)
                break
            has_bias = len(params.in_dims) == 3
            if isinstance(op_node, MatrixAddParameters):
                if has_bias:
                    bias_node = G.indexed_in_edges(params.name)[2].from_node
                    if not is_exclusive_const(bias_node):
                        LOG.info(
                            "can't fuse %s into %s - existing bias is not an unshared constant",
                            const_node.name, params.name)
                        break
                    bias_value = bias_node.dqvalue
                    if bias_value.flatten().shape != flat_value.shape:
                        LOG.info(
                            "can't fuse %s into %s - bias shape is not the same",
                            const_node.name, params.name)
                        break
                    LOG.info(
                        "folding additive bias from %s into existing bias on %s",
                        op_node.name, params.name)
                    # reshape so the sum keeps the bias's own shape instead
                    # of broadcasting to something larger
                    bias_node.value = bias_value + flat_value.reshape(bias_value.shape)
                else:
                    if len(flat_value) == out_shape[1]:
                        # matmul needs to be transposed to fuse this
                        reverse_matmul(G, params)
                        has_transposed = True
                    bias_node = ConstantInputParameters(
                        G.unique_name(f'{params.name}_bias'),
                        value=flat_value,
                        dims=Dim.unnamed(flat_value.shape))
                    G.add_edge(
                        NNEdge(from_node=bias_node, to_node=params, to_idx=2))
                    # extend the inward transpose
                    if params.transpose_in:
                        params.transpose_in = params.transpose_in + [None]
                    LOG.info(
                        "folding additive bias from %s into new bias on %s",
                        op_node.name, params.name)
            else:
                params_in = G.indexed_in_edges(params.name)
                consts = [
                    isinstance(edge.from_node, ConstantInputParameters)
                    for edge in params_in
                ]
                # only the two matmul operands can absorb the scale; a
                # constant bias (index 2) alone is not enough
                if not any(consts[:2]):
                    break
                mult_const_node = (params_in[1].from_node if consts[1]
                                   else params_in[0].from_node)
                bias_node = params_in[2].from_node if has_bias else None
                if (not is_exclusive_const(mult_const_node)
                        or (bias_node is not None
                            and not is_exclusive_const(bias_node))):
                    LOG.info(
                        "can't fuse %s into %s - operand or bias is not an unshared constant",
                        const_node.name, params.name)
                    break
                # NOTE(review): relies on numpy broadcasting of the
                # unflattened constant against the chosen operand; confirm
                # the scale orientation (row vs column) always matches it.
                mult_const_node.value = mult_const_node.dqvalue * const_node.dqvalue
                if bias_node is not None:
                    bias_node.value = bias_node.dqvalue * const_node.dqvalue
                LOG.info(
                    "folding multaplicative bias from %s into new bias on %s",
                    op_node.name, params.name)
            out_edges = G.out_edges(op_node.name)
            G.remove(op_node)
            if remove_constant:
                G.remove(const_node)
            for edge in out_edges:
                G.add_edge(
                    NNEdge(from_node=params,
                           to_node=edge.to_node,
                           to_idx=edge.to_idx))
            has_modified_graph = True
            # out_dims are read again on the next iteration of this loop
            G.add_dimensions()
            if G.quantization:
                quantizer = UnifiedQuantizer.from_quantized_graph(G)
                quantizer.quantize(G, start_nodes=[params])
                RemoveUnnecessaryQuantizeOperators().match(G)

    if has_transposed:
        G.adjust_order()
    if set_identity:
        self.set_identity(G)
    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs) -> bool:
    """Replace groups of StridedSlice nodes reading one tensor with a Split.

    Slice nodes are grouped by the (node, output index) they read from. A
    group of a single slice is handed to ``self.slice_to_split``. A larger
    group is rewritten only if:

    * every stride is 1, and
    * every axis on which the slices agree is taken whole.

    The rewrite then proceeds in steps:

    * the differing axes are moved to the front with a transpose, when they
      are not already leading;
    * they are merged into one axis with a reshape, when there is more than
      one;
    * the result is split along axis 0;
    * each split output is reshaped and transposed back before being wired
      to the consumers of the slice it replaces.

    Args:
        G: graph to modify in place.
        set_identity: if True, call ``self.set_identity(G)`` at the end.
        **kwargs: accepted for interface compatibility; unused here.

    Returns:
        True if at least one group of slices was replaced by a split.
        Changes made by ``slice_to_split`` are not reflected in this value.
    """
    has_modified_graph = False
    slices_by_origin = {}
    for slice_node in [
            node for node in G.nodes()
            if isinstance(node, StridedSliceParameters)
    ]:
        in_edge = G.in_edges(slice_node.name)[0]
        slices_by_origin.setdefault(
            (in_edge.from_node, in_edge.from_idx), []).append(slice_node)

    for (origin_node, origin_idx), slice_nodes in slices_by_origin.items():
        # slices[axis] holds every node's (start, stop, step) on that axis
        slices = list(zip(*[node.act_slice for node in slice_nodes]))
        if len(slice_nodes) == 1:
            self.slice_to_split(G, slice_nodes, slices)
            continue
        # strides must be one
        if any(sl[2] != 1 for sl_axis in slices for sl in sl_axis):
            continue
        diff_axes = [
            idx for idx, elems in enumerate(slices)
            if not all(elems[0] == elem for elem in elems[1:])
        ]
        if not diff_axes:
            # all slices are identical - there is nothing to split
            continue
        not_diff_axes = [
            idx for idx in range(len(slices)) if idx not in diff_axes
        ]
        axis_lengths = list(origin_node.out_dims[origin_idx].shape)
        # the split and the reshapes around it keep every non-differing
        # axis at full length, so the slices must not trim those axes
        if any(slices[idx][0][0] != 0 or slices[idx][0][1] < axis_lengths[idx]
               for idx in not_diff_axes):
            continue
        diff_slices = [slices[idx] for idx in diff_axes]
        if not_diff_axes and min(not_diff_axes) < max(diff_axes):
            # move the differing axes to the front
            transpose_from = tuple(range(len(slices)))
            transpose_to = tuple(diff_axes + not_diff_axes)
            axis_lengths = [axis_lengths[idx] for idx in transpose_to]
        else:
            transpose_from = transpose_to = None
        diff_axis_lengths = axis_lengths[:len(diff_axes)]
        diff_slices = combine_slices(diff_axis_lengths, diff_slices,
                                     slice_nodes)
        if diff_slices is None:
            continue
        if len(diff_axes) > 1:
            # merge the leading differing axes into one
            reshape_from = axis_lengths
            reshape_to = [np.prod(diff_axis_lengths)] + \
                axis_lengths[len(diff_axes):]
        else:
            reshape_from = None
            reshape_to = slice_nodes[0].in_dims[0].shape
            if transpose_from:
                reshape_to = [reshape_to[idx] for idx in transpose_to]
        sizes, shapes, sorted_nodes = slices_to_sizes(
            diff_slices, axis_lengths[len(diff_axes):])

        name_prefix = sorted_nodes[0].name
        first_in_edge = G.in_edges(sorted_nodes[0].name)[0]
        in_node = first_in_edge.from_node
        in_idx = first_in_edge.from_idx
        if transpose_from:
            params = TransposeParameters(G.unique_name(name_prefix + '_tin'),
                                         transpose=transpose_to)
            G.add_edge(
                NNEdge(from_node=in_node, to_node=params, from_idx=in_idx))
            in_node = params
            in_idx = 0
        if reshape_from:
            params = ReshapeParameters(G.unique_name(name_prefix + '_reshape'),
                                       old_shape=Dim.unnamed(reshape_from),
                                       shape=Dim.unnamed(reshape_to))
            G.add_edge(
                NNEdge(from_node=in_node, to_node=params, from_idx=in_idx))
            in_node = params
            in_idx = 0
        act_slices, out_shapes, axis = SplitParameters.get_splits(
            reshape_to, 0, splits=sizes)
        split_node = SplitParameters(G.unique_name(name_prefix + '_split'),
                                     act_slices=act_slices,
                                     out_shapes=out_shapes,
                                     axis=axis)
        G.add_edge(
            NNEdge(from_node=in_node, from_idx=in_idx, to_node=split_node))

        sub_names = []
        for idx, node in enumerate(sorted_nodes):
            sub_names.append(node.name)
            out_edges = G.out_edges(node.name)
            G.remove(node)
            if not out_edges:
                continue
            # build the reshape/transpose chain once per split output and
            # fan it out to every consumer of the replaced slice
            out_node, out_idx = split_node, idx
            if reshape_from:
                reshape_node = ReshapeParameters(
                    G.unique_name(name_prefix + f'_reshape{idx}'),
                    shape=Dim.unnamed(shapes[idx]))
                G.add_edge(
                    NNEdge(from_node=out_node, to_node=reshape_node,
                           from_idx=out_idx))
                out_node, out_idx = reshape_node, 0
            if transpose_from:
                transpose_node = TransposeParameters(
                    G.unique_name(name_prefix + f'_tout{idx}'),
                    transpose=reverse_transpose(transpose_to))
                G.add_edge(
                    NNEdge(from_node=out_node, to_node=transpose_node,
                           from_idx=out_idx))
                out_node, out_idx = transpose_node, 0
            for out_edge in out_edges:
                G.add_edge(
                    NNEdge(from_node=out_node,
                           to_node=out_edge.to_node,
                           from_idx=out_idx,
                           to_idx=out_edge.to_idx))
        if G.quantization:
            G.add_dimensions()
            quantizer = NewQuantizer.from_quantized_graph(G)
            quantizer.quantize()
            RemoveUnnecessaryQuantizeOperators().match(G)

        LOG.info('replaced slice nodes %s with split node %s',
                 ",".join(sub_names), split_node.name)

        has_modified_graph = True

    if set_identity:
        self.set_identity(G)

    return has_modified_graph
def _match(self, G: GraphView, set_identity: bool = True, **kwargs) -> bool:
    """Replace StridedSlice nodes that tile one axis of a tensor with a Split.

    Slice nodes are grouped by the (node, output index) they read from. A
    group of a single slice is handed to ``self.slice_to_split``. A larger
    group is replaced by one ``SplitParameters`` node only if all of the
    following hold:

    * the slices differ on exactly one axis;
    * they use stride 1 on that axis;
    * sorted by start, they tile that axis exactly: the first starts at 0,
      each stop equals the next start, and the last stops at the axis
      length;
    * every other axis is taken whole with stride 1.

    Args:
        G: graph to modify in place.
        set_identity: if True, call ``self.set_identity(G)`` at the end.
        **kwargs: accepted for interface compatibility; unused here.

    Returns:
        True if at least one group of slices was replaced by a split.
        Changes made by ``slice_to_split`` are not reflected in this value.
    """
    has_modified_graph = False
    slices_by_origin = {}
    for slice_node in [
            node for node in G.nodes()
            if isinstance(node, StridedSliceParameters)
    ]:
        in_edge = G.in_edges(slice_node.name)[0]
        slices_by_origin.setdefault(
            (in_edge.from_node, in_edge.from_idx), []).append(slice_node)

    for slice_nodes in slices_by_origin.values():
        # slices[axis] holds every node's (start, stop, step) on that axis
        slices = list(zip(*[node.act_slice for node in slice_nodes]))
        if len(slice_nodes) == 1:
            self.slice_to_split(G, slice_nodes, slices)
            continue
        diff_slices = [(idx, elems) for idx, elems in enumerate(slices)
                       if not all(elems[0] == elem for elem in elems[1:])]
        if len(diff_slices) != 1:
            continue
        axis, axis_slices = diff_slices[0]
        # strides must be one
        if any(sl[2] != 1 for sl in axis_slices):
            continue
        in_shape = slice_nodes[0].in_dims[0].shape
        # a split along `axis` keeps every other axis whole, so the slices
        # must not trim any other axis
        if any(elems[0][0] != 0 or elems[0][1] < in_shape[idx]
               or elems[0][2] != 1
               for idx, elems in enumerate(slices) if idx != axis):
            continue
        # slices must be consecutive, non overlapping and cover the axis:
        # a split always starts at 0 and its sizes come from stop - start
        ordered = sorted(axis_slices, key=lambda sl: sl[0])
        if ordered[0][0] != 0 or ordered[-1][1] != in_shape[axis]:
            continue
        if any(cur[1] != nxt[0] for cur, nxt in zip(ordered, ordered[1:])):
            continue
        sizes = [sl[1] - sl[0] for sl in ordered]
        # order the nodes the same way so node i maps to split output i
        slice_nodes = sorted(slice_nodes, key=lambda node: node.act_slice[axis][0])
        act_slices, out_shapes, axis = SplitParameters.get_splits(
            in_shape, axis, splits=sizes)
        split_node = SplitParameters(G.unique_name(slice_nodes[0].name + '_split'),
                                     act_slices=act_slices,
                                     out_shapes=out_shapes,
                                     axis=axis)
        in_edge = G.in_edges(slice_nodes[0].name)[0]
        G.add_edge(
            NNEdge(from_node=in_edge.from_node,
                   to_node=split_node,
                   from_idx=in_edge.from_idx))
        sub_names = []
        for idx, node in enumerate(slice_nodes):
            sub_names.append(node.name)
            out_edges = G.out_edges(node.name)
            G.remove(node)
            for out_edge in out_edges:
                G.add_edge(
                    NNEdge(from_node=split_node,
                           to_node=out_edge.to_node,
                           from_idx=idx,
                           to_idx=out_edge.to_idx))
        if G.quantization:
            G.add_dimensions()
            quantizer = UnifiedQuantizer.from_quantized_graph(G)
            quantizer.quantize(G, start_nodes=[split_node])
            RemoveUnnecessaryQuantizeOperators().match(G)

        LOG.info('replaced slice nodes %s with split node %s',
                 ",".join(sub_names), split_node.name)

        has_modified_graph = True

    if set_identity:
        self.set_identity(G)

    return has_modified_graph