Ejemplo n.º 1
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Replace matched chains of nodes in ``G`` with single fusion nodes.

        Every node whose class is a key of ``VALID_FUSIONS`` is offered to
        ``get_node_list``. Candidates that are ``None`` or contain fewer than
        two nodes are skipped. An accepted chain is copied into a linear
        subgraph, wrapped in ``candidate.fusions_class`` and spliced into
        ``G`` in place of the original nodes.

        Args:
            G: graph to modify in place.
            set_identity: when true, ``self.set_identity(G)`` is called once
                all candidates have been processed.
            **kwargs: accepted for interface compatibility; not used here.

        Returns:
            True if at least one fusion was applied, otherwise False.
        """
        fused_any = False

        for start in G.nodes(node_classes=tuple(VALID_FUSIONS.keys())):
            candidate = self.get_node_list(G, start,
                                           FusionMatch(self._default_ktype))
            if candidate is None or len(candidate.order) < 2:
                continue

            chain = list(candidate.order)
            LOG.info("fusing nodes %s",
                     ",".join(member.name for member in chain))
            fused_any = True

            # Copy the chain into a linear subgraph: each node feeds the next.
            subgraph = GraphView()
            for src, dst in zip(chain, chain[1:]):
                subgraph.add_edge(NNEdge(from_node=src, to_node=dst))
            tail = chain[-1]
            head = candidate.node

            # The head may have several inputs; every one of them is mapped
            # straight through. The chain is assumed to have exactly one
            # output, taken from output 0 of the tail node.
            input_mapping = [[(head, in_idx)]
                             for in_idx in range(G.num_in_edges(head.name))]
            fusion = candidate.fusions_class(head.name + '_fusion',
                                             fusion_type=candidate.fusion_type,
                                             subgraph=subgraph,
                                             input_mapping=input_mapping,
                                             output_mapping=[(tail, 0)])

            if G.quantization:
                # TODO - stats
                # Build the fused record from the first contained record's
                # input quantization and the last one's output quantization.
                # This presumes get_all returns records in chain order;
                # confirm against the quantization set implementation.
                qrecs = G.quantization.get_all(fusion.contained_nodes())
                if qrecs:
                    first_qrec = qrecs[0]
                    fused_qrec = QRec.copy_ktype(first_qrec,
                                                 in_qs=first_qrec.in_qs,
                                                 out_qs=qrecs[-1].out_qs)
                    for inner in fusion.contained_nodes():
                        G.quantization.move_to_fusion(inner, fusion)
                    G.quantization[NodeId(fusion)] = fused_qrec

            # Capture the chain's external connections before removing it.
            incoming = G.in_edges(head.name)
            outgoing = G.out_edges(tail.name)
            for member in chain:
                G.remove(member)

            # Reconnect the fusion node where the chain used to be.
            for edge in incoming:
                G.add_edge(NNEdge(edge.from_node, fusion,
                                  from_idx=edge.from_idx, to_idx=edge.to_idx))
            for edge in outgoing:
                G.add_edge(NNEdge(fusion, edge.to_node,
                                  from_idx=edge.from_idx, to_idx=edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return fused_any
Ejemplo n.º 2
0
def walk_down(subg: GraphView, node: Parameters,
              node_slices: Mapping[Parameters, Sequence[SlicedTensor]]):
    """Propagate sliced-tensor descriptions forward through ``subg`` from ``node``.

    ``node_slices`` maps a node to the list of slices of its inputs, indexed
    by input edge index. The mapping is mutated in place despite its
    ``Mapping`` annotation:

    * A ``FusionInputParameters`` node seeds its own entry with a single
      slice element at offset 0 spanning ``node.dims.shape``.
    * Any other node is processed only once every input slot is filled.
      Earlier visits return without doing anything.
    * A processed node writes its output slices into the input slots of its
      successors and then walks each distinct successor once.

    Args:
        subg: graph being walked.
        node: node to process.
        node_slices: per-node list of input slices, updated in place.
    """
    # edges not created
    if isinstance(node, FusionInputParameters):
        inp_slice = InputSlice.from_shape(node, node.dims.shape)
        dim_slices = node_slices[node] = [
            SlicedTensor([
                SliceElement(tuple([0] * inp_slice.rank), inp_slice.shape,
                             inp_slice)
            ])
        ]
    else:
        dim_slices = node_slices.get(node)
        if dim_slices is None:
            return
        # all edges not created
        if len(dim_slices) < subg.num_in_edges(node.name):
            return
        # all edges not created
        if any(val is None for val in dim_slices):
            return
        # NOTE: transposes are applied to the stored list in place, so
        # node_slices[node] ends up holding the transposed values. A node
        # must therefore be processed only once (see the explore step).
        if isinstance(node, Transposable) and node.transpose_in:
            for idx, transpose in enumerate(node.transpose_in):
                if transpose:
                    dim_slices[idx] = dim_slices[idx].transpose(transpose)
        if isinstance(node, ConcatParameters):
            dim_slices = [SlicedTensor.concat(*dim_slices, axis=node.axis)]
        elif isinstance(node, SplitParameters):
            dim_slices = dim_slices[0].split(node.act_slices)
        elif isinstance(node, StridedSliceParameters):
            dim_slices = [dim_slices[0].slice(node.act_slice)]
        # Any other node passes its input slices through unchanged, so
        # output index i reuses input index i.

    if isinstance(node, Transposable) and node.transpose_out:
        for idx, transpose in enumerate(node.transpose_out):
            if transpose:
                dim_slices[idx] = dim_slices[idx].transpose(transpose)

    # set output edges: store each output's slices in the successor's
    # input slot, growing the successor's slot list when needed
    for edge_set in subg.indexed_out_edges(node.name):
        for edge in edge_set:
            dest_slices = node_slices.setdefault(edge.to_node,
                                                 [None] * (edge.to_idx + 1))
            if len(dest_slices) < edge.to_idx + 1:
                dest_slices = dest_slices + \
                    ([None] * ((edge.to_idx + 1) - len(dest_slices)))
                node_slices[edge.to_node] = dest_slices
            dest_slices[edge.to_idx] = dim_slices[edge.from_idx]

    # explore graph: visit each distinct successor once, in first-seen edge
    # order. All of its slots fed by this node are already filled above, so
    # one visit suffices. Re-visiting a complete node would re-apply its
    # in-place transposes and walk its entire downstream again.
    successors = dict.fromkeys(
        edge.to_node
        for edge_set in subg.indexed_out_edges(node.name)
        for edge in edge_set)
    for successor in successors:
        walk_down(subg, successor, node_slices)
Ejemplo n.º 3
0
 def _match(self, G: GraphView, node: Node, edge: Edge):
     """Return True for a ``ConcatParameters`` node that has exactly one input edge in ``G``.

     ``edge`` is accepted for interface compatibility and is not inspected.
     """
     if not isinstance(node, ConcatParameters):
         return False
     return G.num_in_edges(node.name) == 1