def reverse_matmul(G: GraphView, params):
    """Exchange the first two inputs of ``params`` and wrap it in transposes.

    The edges feeding inputs 0 and 1 of ``params`` are re-attached the other
    way round, the first two input entries of the node's quantization record
    (when the graph holds one for it) are swapped, and a ``(1, 0)`` transpose
    is inserted in front of each of the two inputs and after the node.

    Args:
        G: Graph containing ``params``; it is modified in place.
        params: Node whose first two inputs are exchanged.

    Returns:
        A tuple ``(in_nodes, tout_params)``: the list holding the two input
        transpose nodes (input 0 first) and the output transpose node.
    """
    def new_transpose(name):
        # every transpose created here uses the same (1, 0) permutation
        return TransposeParameters(G.unique_name(name), transpose=(1, 0))

    in_edges = G.indexed_in_edges(params.name)

    # detach both operand edges before re-attaching them swapped
    for old_edge in in_edges[:2]:
        G.remove_edge(old_edge)
    for position, old_edge in enumerate(in_edges[:2]):
        # the edge that fed input 0 now feeds input 1 and vice versa
        G.add_edge(
            NNEdge(from_node=old_edge.from_node,
                   to_node=params,
                   from_idx=old_edge.from_idx,
                   to_idx=1 - position))

    nid = NodeId(params)
    if G.quantization and nid in G.quantization:
        # keep the quantization record in step with the swapped inputs
        qrec = G.quantization[nid]
        qrec.in_qs[0], qrec.in_qs[1] = qrec.in_qs[1], qrec.in_qs[0]

    # one transpose per input, created and inserted in input order
    in_nodes = []
    for idx in (0, 1):
        tin_params = new_transpose(f"{params.name}_tin{idx}")
        in_nodes.append(tin_params)
        G.insert_node_before(tin_params, params, to_idx=idx, edge_class=NNEdge)

    # and one transpose on the output
    tout_params = new_transpose(f"{params.name}_tout")
    G.insert_node_after(params, tout_params)
    return in_nodes, tout_params
def move_constant(cls, G: GraphView, params, in_qs):
    """Make sure a constant operand, if there is one, sits on input 1.

    Outcomes:
      * input 1 is already a ``ConstantInputParameters`` node: it is returned
        together with ``in_qs`` unchanged and the graph is not touched;
      * input 0 is a constant: the two input edges are swapped and ``(1, 0)``
        transposes are inserted on both inputs and on the output, relying on
        the identity ``A.B = (BT.AT)T``. The constant node is returned with
        the first two entries of ``in_qs`` exchanged;
      * neither input is a constant, or a third input is present whose size
        does not match ``params.in_dims[0].shape[1]``: ``(None, in_qs)`` is
        returned and the graph is not touched.

    Args:
        cls: Unused.
        G: Graph containing ``params``; may be modified in place.
        params: Node whose inputs are inspected.
        in_qs: Input quantization entries, one per input of ``params``.

    Returns:
        A tuple ``(constant_node_or_None, in_qs)``.
    """
    in_edges = G.indexed_in_edges(params.name)
    first_source = in_edges[0].from_node
    second_source = in_edges[1].from_node

    # constant already where it is wanted
    if isinstance(second_source, ConstantInputParameters):
        return second_source, in_qs

    # no constant operand at all
    if not isinstance(first_source, ConstantInputParameters):
        return None, in_qs

    if len(params.in_dims) > 2:
        # check if the bias has the correct length to move constant
        # it must have a length equal to the second tensors second dimension
        # after transpose
        bias_size = params.in_dims[2].size()
        first_shape = params.in_dims[0].shape
        if first_shape[1] != bias_size:
            return None, in_qs

    # swap edges to move constant onto input 2
    for old_edge in in_edges[:2]:
        G.remove_edge(old_edge)
    for position, old_edge in enumerate(in_edges[:2]):
        # the edge that fed input 0 now feeds input 1 and vice versa
        G.add_edge(
            NNEdge(from_node=old_edge.from_node,
                   to_node=old_edge.to_node,
                   from_idx=old_edge.from_idx,
                   to_idx=1 - position))

    # use A.B = (BT.AT)T identity
    # all three transposes are created before any of them is inserted
    tin1 = TransposeParameters(
        G.unique_name(f'{params.name}_tin1'), transpose=(1, 0))
    tin2 = TransposeParameters(
        G.unique_name(f'{params.name}_tin2'), transpose=(1, 0))
    tout = TransposeParameters(
        G.unique_name(f'{params.name}_tout'), transpose=(1, 0))
    G.insert_node_before(tin1, params)
    G.insert_node_before(tin2, params, to_idx=1)
    G.insert_node_after(params, tout)
    LOG.warning('transposes inserted on %s - rerun adjust', params.name)

    # quantization entries follow their operands: swap the first two
    swapped_qs = [in_qs[1], in_qs[0]] + in_qs[2:]
    return first_source, swapped_qs
def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
    """Merge groups of connected axis-0 concats into a single concat.

    For every ``ConcatParameters`` node with ``axis == 0`` the subgraph
    returned by ``find_concats_up`` is examined. When it holds more than one
    concat, all of its nodes except the starting concat (and any
    ``DummyInput``) are removed from the graph, each subgraph input is fed
    straight into the remaining concat (through a flattening reshape when the
    input has more than one dimension) and a reshape restoring the concat's
    original output dimension is inserted after it. When quantization records
    are available for every source node, the concat's ``in_qs`` are rebuilt
    from the sources' ``out_qs``.

    Args:
        G: Graph to transform in place.
        set_identity: When true, ``self.set_identity(G)`` is called at the end.
        **kwargs: Ignored.

    Returns:
        True if at least one group of concats was combined, else False.
    """
    modified_graph = False
    pending = set(G.nodes(node_classes=ConcatParameters))
    while pending:
        concat = pending.pop()
        if concat.axis != 0:
            continue
        subgraph = find_concats_up(G, concat)
        found = set(subgraph.nodes(node_classes=ConcatParameters))
        if len(found) < 2:
            continue
        LOG.info(
            f"Combining concats {','.join([node.name for node in found])}")
        modified_graph = True
        # the other concats of this group must not be visited again
        pending.difference_update(found)

        # snapshot the inputs and their dimensions before touching the graph
        in_edges = [inp.edge for inp in subgraph.inputs()]
        in_dims = [
            edge.from_node.out_dims[edge.from_idx] for edge in in_edges
        ]
        nodes_to_remove = [
            node for node in subgraph.nodes()
            if node != concat and not isinstance(node, DummyInput)
        ]

        for edge in in_edges:
            G.remove_edge(edge)
        for node in nodes_to_remove:
            if node.name in G:
                G.remove(node)
            # drop any quantization record held for the removed node
            nid = NodeId(node)
            if G.quantization and nid in G.quantization:
                del G.quantization[nid]

        # remove_internal_graph(G, subgraph)

        out_dim = concat.out_dims[0]
        # becomes None as soon as one source has no quantization record
        in_qs = []
        for idx, (edge, in_dim) in enumerate(zip(in_edges, in_dims)):
            source = edge.from_node
            source_idx = edge.from_idx
            if len(in_dim) > 1:
                # flatten multi-dimensional inputs before the concat
                flatten = ReshapeParameters(
                    G.unique_name(f'{concat.name}_flat{idx}'),
                    old_shape=in_dim,
                    shape=Dim.unnamed([in_dim.size()]))
                G.add_edge(
                    NNEdge(from_node=source,
                           from_idx=source_idx,
                           to_node=flatten))
                source = flatten
                source_idx = 0
            G.add_edge(
                NNEdge(from_node=source,
                       from_idx=source_idx,
                       to_node=concat,
                       to_idx=idx))

            # gather the quantization of the original source output
            if in_qs is None or not G.quantization:
                in_qs = None
                continue
            nid = NodeId(edge.from_node)
            if nid not in G.quantization:
                in_qs = None
                continue
            qrec = G.quantization[nid]
            in_qs.append(qrec.out_qs[edge.from_idx])

        if in_qs is not None and G.quantization:
            nid = NodeId(concat)
            if nid in G.quantization:
                G.quantization[nid].in_qs = in_qs

        # restore the concat's original output dimension
        expand = ReshapeParameters(
            G.unique_name(f'{concat.name}_expand'),
            old_shape=Dim.unnamed([out_dim.size()]),
            shape=out_dim)
        G.insert_node_after(concat, expand, edge_class=NNEdge)

    if set_identity:
        self.set_identity(G)

    return modified_graph