def reverse_matmul(G: GraphView, params):
    """Exchange the first two inputs of ``params`` and wrap it in transposes.

    The first two indexed input edges of the node are removed and re-added
    with their destination indexes swapped. When ``G.quantization`` holds a
    record for the node, its first two input quantization entries are swapped
    as well. A ``(1, 0)`` transpose is then inserted in front of inputs 0 and
    1, and another one directly after the node.

    Args:
        G: graph containing ``params``; it is modified in place.
        params: the node whose operands are reversed.

    Returns:
        A tuple ``(input_transposes, output_transpose)`` where
        ``input_transposes`` is the list of the two transposes inserted before
        the node and ``output_transpose`` is the one inserted after it.
    """
    operand_edges = G.indexed_in_edges(params.name)

    # Detach both operands first, then reattach each one to the other slot.
    for old_edge in operand_edges[:2]:
        G.remove_edge(old_edge)
    for swapped_idx, old_edge in zip((1, 0), operand_edges[:2]):
        G.add_edge(
            NNEdge(from_node=old_edge.from_node,
                   to_node=params,
                   from_idx=old_edge.from_idx,
                   to_idx=swapped_idx))

    # Keep the quantization record lined up with the exchanged operands.
    node_id = NodeId(params)
    if G.quantization and node_id in G.quantization:
        input_qs = G.quantization[node_id].in_qs
        input_qs[0], input_qs[1] = input_qs[1], input_qs[0]

    def new_transpose(suffix):
        # Every inserted node is a 2D transpose with a graph-unique name.
        return TransposeParameters(
            G.unique_name(f"{params.name}_{suffix}"), transpose=(1, 0))

    input_transposes = []
    for slot in (0, 1):
        transpose_node = new_transpose(f"tin{slot}")
        input_transposes.append(transpose_node)
        G.insert_node_before(
            transpose_node, params, to_idx=slot, edge_class=NNEdge)

    output_transpose = new_transpose("tout")
    G.insert_node_after(params, output_transpose)
    return input_transposes, output_transpose
示例#2
0
    def move_constant(cls, G: GraphView, params, in_qs):
        """Find a constant operand of ``params``, moving it to input 2 if needed.

        If the second input is already a ``ConstantInputParameters`` nothing is
        changed. If only the first input is constant, the first two input edges
        are swapped and transposes are inserted before both inputs and after
        the node (using the identity A.B = (BT.AT)T), so that the constant ends
        up on the second input.

        Args:
            G: graph containing ``params``; may be modified in place.
            params: the node whose inputs are inspected.
            in_qs: input quantization entries matching the node's inputs.

        Returns:
            A tuple ``(constant_node, in_qs)``. ``constant_node`` is ``None``
            when no constant operand is found, or when a third input is
            present whose size differs from the second dimension of the first
            input's shape. When the operands were swapped, the returned list
            has its first two entries exchanged.
        """
        operand_edges = G.indexed_in_edges(params.name)
        first_source = operand_edges[0].from_node
        second_source = operand_edges[1].from_node

        # Constant is already where we want it.
        if isinstance(second_source, ConstantInputParameters):
            return second_source, in_qs
        # No constant on either operand.
        if not isinstance(first_source, ConstantInputParameters):
            return None, in_qs

        if len(params.in_dims) > 2:
            # A third input (the bias) is present: the constant can only be
            # moved when the bias size equals the second dimension of the
            # first input.
            bias_len = params.in_dims[2].size()
            first_shape = params.in_dims[0].shape
            if first_shape[1] != bias_len:
                return None, in_qs

        # Swap the two operand edges so the constant lands on input 2.
        for old_edge in operand_edges[:2]:
            G.remove_edge(old_edge)
        for swapped_idx, old_edge in zip((1, 0), operand_edges[:2]):
            G.add_edge(
                NNEdge(from_node=old_edge.from_node,
                       to_node=old_edge.to_node,
                       from_idx=old_edge.from_idx,
                       to_idx=swapped_idx))

        # A.B = (BT.AT)T: transpose both operands and the result.
        tin1, tin2, tout = [
            TransposeParameters(G.unique_name(f'{params.name}_{suffix}'),
                                transpose=(1, 0))
            for suffix in ('tin1', 'tin2', 'tout')
        ]
        G.insert_node_before(tin1, params)
        G.insert_node_before(tin2, params, to_idx=1)
        G.insert_node_after(params, tout)
        LOG.warning('transposes inserted on %s - rerun adjust',
                    params.name)

        swapped_qs = [in_qs[1], in_qs[0]] + in_qs[2:]
        return first_source, swapped_qs
示例#3
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Collapse groups of concats into a single concat on flattened inputs.

        For every ``ConcatParameters`` node with ``axis == 0`` the subgraph
        returned by ``find_concats_up`` is examined. When it holds more than
        one concat, every subgraph node except the starting concat (and any
        ``DummyInput``) is removed from the graph together with its
        quantization record. The subgraph's inputs are then wired directly to
        the starting concat, going through a flattening reshape when an input
        has more than one dimension, and a reshape back to the concat's
        original output dimension is inserted after it.

        Args:
            G: graph to transform in place.
            set_identity: when true, ``self.set_identity(G)`` is called at the
                end.
            **kwargs: accepted but not used here.

        Returns:
            ``True`` if at least one group of concats was combined.
        """
        changed = False
        pending = set(G.nodes(node_classes=ConcatParameters))
        while pending:
            concat = pending.pop()
            if concat.axis != 0:
                continue
            subgraph = find_concats_up(G, concat)
            found = set(subgraph.nodes(node_classes=ConcatParameters))
            if len(found) < 2:
                continue
            joined_names = ','.join([node.name for node in found])
            LOG.info(f"Combining concats {joined_names}")
            changed = True
            # Concats handled in this group must not be visited again.
            pending -= found

            # Snapshot everything needed from the subgraph before mutating G.
            boundary_edges = [inp.edge for inp in subgraph.inputs()]
            boundary_dims = [
                edge.from_node.out_dims[edge.from_idx]
                for edge in boundary_edges
            ]
            stale_nodes = [
                node for node in subgraph.nodes()
                if node != concat and not isinstance(node, DummyInput)
            ]

            for edge in boundary_edges:
                G.remove_edge(edge)
            for node in stale_nodes:
                if node.name in G:
                    G.remove(node)
                node_id = NodeId(node)
                if G.quantization and node_id in G.quantization:
                    del G.quantization[node_id]

            out_dim = concat.out_dims[0]
            # Collected input quantization; becomes None as soon as any
            # producer has no quantization record.
            in_qs = []
            for idx, (edge, dim) in enumerate(
                    zip(boundary_edges, boundary_dims)):
                src_node, src_idx = edge.from_node, edge.from_idx
                if len(dim) > 1:
                    # Flatten multi-dimensional inputs before the concat.
                    flatten = ReshapeParameters(
                        G.unique_name(f'{concat.name}_flat{idx}'),
                        old_shape=dim,
                        shape=Dim.unnamed([dim.size()]))
                    G.add_edge(
                        NNEdge(from_node=src_node,
                               from_idx=src_idx,
                               to_node=flatten))
                    src_node, src_idx = flatten, 0
                G.add_edge(
                    NNEdge(from_node=src_node,
                           from_idx=src_idx,
                           to_node=concat,
                           to_idx=idx))

                if in_qs is None or not G.quantization:
                    in_qs = None
                    continue
                producer_id = NodeId(edge.from_node)
                if producer_id in G.quantization:
                    producer_qrec = G.quantization[producer_id]
                    in_qs.append(producer_qrec.out_qs[edge.from_idx])
                else:
                    in_qs = None

            if in_qs is not None and G.quantization:
                concat_id = NodeId(concat)
                if concat_id in G.quantization:
                    G.quantization[concat_id].in_qs = in_qs

            # Restore the concat's original output shape for its consumers.
            expand = ReshapeParameters(
                G.unique_name(f'{concat.name}_expand'),
                old_shape=Dim.unnamed([out_dim.size()]),
                shape=out_dim)
            G.insert_node_after(concat, expand, edge_class=NNEdge)

        if set_identity:
            self.set_identity(G)

        return changed