def reverse_matmul(G: GraphView, params):
    """Swap the first two operands of ``params`` and wrap it in transposes.

    Applies the identity A.B = (BT.AT)T:
      * the edges feeding the first two inputs of ``params`` are removed and
        reconnected crosswise (the edge at position 0 goes to input 1 and the
        edge at position 1 goes to input 0),
      * if ``G.quantization`` holds a record for ``params``, its first two
        ``in_qs`` entries are swapped to follow the edges,
      * a ``(1, 0)`` transpose is inserted before input 0 and before input 1,
        and another ``(1, 0)`` transpose is inserted after ``params``.

    Args:
        G: graph containing ``params``; modified in place.
        params: the node whose first two operands are reversed.

    Returns:
        ``(in_nodes, tout_params)``: ``in_nodes`` is the list of the two input
        transpose nodes (for input 0, then input 1) and ``tout_params`` is the
        transpose node inserted after ``params``.
    """
    # Only the first two inputs are operands; any further input is untouched.
    operand_edges = G.indexed_in_edges(params.name)[0:2]
    for edge in operand_edges:
        G.remove_edge(edge)
    # Reconnect crosswise: position 0 -> input 1, position 1 -> input 0.
    for position, edge in enumerate(operand_edges):
        G.add_edge(
            NNEdge(
                from_node=edge.from_node,
                to_node=params,
                from_idx=edge.from_idx,
                to_idx=1 - position,
            )
        )

    # Keep the quantization record in step with the swapped edges.
    node_id = NodeId(params)
    if G.quantization and node_id in G.quantization:
        qrec = G.quantization[node_id]
        input_qs = qrec.in_qs
        input_qs[0], input_qs[1] = input_qs[1], input_qs[0]

    # Transpose both operands going in ...
    in_nodes = []
    for idx in (0, 1):
        transpose_in = TransposeParameters(
            G.unique_name(f"{params.name}_tin{idx}"), transpose=(1, 0)
        )
        in_nodes.append(transpose_in)
        G.insert_node_before(transpose_in, params, to_idx=idx, edge_class=NNEdge)

    # ... and the result coming out.
    tout_params = TransposeParameters(
        G.unique_name(f"{params.name}_tout"), transpose=(1, 0)
    )
    G.insert_node_after(params, tout_params)
    return in_nodes, tout_params
def move_constant(cls, G: GraphView, params, in_qs):
    """Make sure a constant operand of ``params`` sits on its second input.

    * If input 1 is fed by a ``ConstantInputParameters`` node, nothing is
      changed and that node is returned.
    * If input 0 is the constant instead, the operands are swapped using the
      identity A.B = (BT.AT)T:
      - the edges into inputs 0 and 1 are exchanged,
      - a ``(1, 0)`` transpose is inserted before each of those inputs and
        another after ``params``,
      - a warning is logged asking for the adjust pass to be rerun.
    * Otherwise nothing is changed.

    When ``params`` has more than two inputs, input 2 is treated as a bias.
    The move is then only performed if ``params.in_dims[2].size()`` equals
    ``params.in_dims[0].shape[1]``.

    Args:
        cls: owning class (unused here).
        G: graph containing ``params``; modified in place only when the
            constant is moved.
        params: the node to inspect. The producers of inputs 0 and 1 are
            read unconditionally, so both inputs are expected to be connected.
        in_qs: per-input quantization entries, kept in the same order as the
            node's inputs. Any sequence type is accepted.

    Returns:
        ``(constant_node, qs)``:
        - ``constant_node`` is the constant now on input 1, or ``None`` if no
          constant was (or could be) placed there.
        - ``qs`` is ``in_qs`` unchanged, or, after a move, a new list with
          entries 0 and 1 swapped.
    """
    in_edges = G.indexed_in_edges(params.name)
    in1_node = in_edges[0].from_node
    in2_node = in_edges[1].from_node

    if isinstance(in2_node, ConstantInputParameters):
        # Constant is already on the second input.
        return in2_node, in_qs
    if not isinstance(in1_node, ConstantInputParameters):
        # Neither operand is a constant.
        return None, in_qs

    if len(params.in_dims) > 2:
        # A bias is present. The constant may only be moved if the bias
        # length matches the second tensor's second dimension after the
        # transpose, which this code reads as in_dims[0].shape[1].
        bias_size = params.in_dims[2].size()
        in1_shape = params.in_dims[0].shape
        if in1_shape[1] != bias_size:
            return None, in_qs

    # Build the swapped quantization list BEFORE touching the graph. If
    # ``in_qs`` is unusable, this fails without leaving the graph half
    # rewritten. Unpacking accepts any sequence, whereas list + tuple
    # concatenation would raise TypeError.
    swapped_qs = [in_qs[1], in_qs[0], *in_qs[2:]]

    operand_edges = in_edges[:2]
    for edge in operand_edges:
        G.remove_edge(edge)
    # Swap edges to move the constant onto input 2 (index 1):
    # position 0 -> index 1, position 1 -> index 0.
    for position, edge in enumerate(operand_edges):
        G.add_edge(
            NNEdge(
                from_node=edge.from_node,
                to_node=edge.to_node,
                from_idx=edge.from_idx,
                to_idx=1 - position,
            )
        )

    # Use the A.B = (BT.AT)T identity.
    tin1 = TransposeParameters(G.unique_name(f'{params.name}_tin1'), transpose=(1, 0))
    tin2 = TransposeParameters(G.unique_name(f'{params.name}_tin2'), transpose=(1, 0))
    tout = TransposeParameters(G.unique_name(f'{params.name}_tout'), transpose=(1, 0))
    G.insert_node_before(tin1, params)
    G.insert_node_before(tin2, params, to_idx=1)
    G.insert_node_after(params, tout)
    LOG.warning('transposes inserted on %s - rerun adjust', params.name)
    return in1_node, swapped_qs