def find_direct_connects(self,
                         G,
                         node,
                         has_modified_graph,
                         find_output=True):
    """Insert a copy node on every output of `node` flagged by `find_split_concat`.

    `self.find_split_concat` returns either None or a sequence holding one
    "bundle" per output index of `node`; each bundle is a sequence of edge
    sets whose first edge is the one that gets rerouted (NOTE(review): this
    shape is inferred from how the result is consumed here - confirm against
    find_split_concat).

    For each non-empty bundle one CopyParameters node is created, the first
    edge of every edge set is removed and re-created from the copy node to
    the same destination, and output `idx` of `node` is connected to the
    copy. When the graph has quantization records the copy node receives a
    record whose input and output types are deep copies of `node`'s output
    type at `idx`.

    Args:
        G: graph that is modified in place.
        node: node whose outputs are inspected.
        has_modified_graph: modification flag accumulated by the caller.
        find_output: forwarded unchanged to `find_split_concat`.

    Returns:
        True if at least one copy node was inserted, otherwise the value of
        `has_modified_graph` that was passed in.
    """
    # traverse reshapes or transposes that do nothing - check gen
    # find edges connected to concats
    res = self.find_split_concat(G, node, find_output=find_output)
    if res is None:
        return has_modified_graph
    qrec = G.quantization[NodeId(node)] if G.quantization else None
    for idx, bundle in enumerate(res):
        if not bundle:
            continue
        has_modified_graph = True
        copy_node = CopyParameters("%s_copy_%s" % (node.name, idx))
        for edge_set in bundle:
            # Only the first edge of each set is rerouted through the copy.
            first_edge = edge_set[0]
            G.remove_edge(first_edge)
            LOG.info('inserting copy between %s/%s and %s/%s', node.name,
                     idx, first_edge.to_node.name, first_edge.to_idx)
            G.add_edge(
                NNEdge(copy_node,
                       first_edge.to_node,
                       to_idx=first_edge.to_idx))
        G.add_edge(NNEdge(node, copy_node, from_idx=idx))
        if G.quantization:
            G.quantization[NodeId(copy_node)] = QRec.copy_ktype(
                qrec,
                in_qs=[deepcopy(qrec.out_qs[idx])],
                out_qs=[deepcopy(qrec.out_qs[idx])])
    # Report the accumulated flag instead of an unconditional True so that a
    # result made only of empty bundles is not reported as a change.
    return has_modified_graph
示例#2
0
    def replace_function(self, G: GraphView, subgraph: GraphView):
        """Build the hsigmoid activation that replaces a matched subgraph.

        The nodes of `subgraph` are sorted into three roles by type: relu
        activation, constant input and matrix multiply. A
        HSigmoidActivationParameters node named after the multiply node is
        created with offset 0. If `G` has quantization records, the record of
        the constant is deleted and the new node gets a record built from the
        relu's input types and the multiply's output types.

        Returns:
            The tuple `(activation, None, None)`.
        """
        # Tested in this order for every node; the first matching class wins.
        roles = (
            (ReluActivationParameters, 'relu'),
            (ConstantInputParameters, 'constant'),
            (MatrixMulParameters, 'mul'),
        )
        matched = {'relu': None, 'constant': None, 'mul': None}
        for candidate in subgraph.nodes():
            for node_class, role in roles:
                if isinstance(candidate, node_class):
                    matched[role] = candidate
                    break

        hsigmoid = HSigmoidActivationParameters(
            matched['mul'].name + "_fused_close_hsigmoid", offset=0)

        if G.quantization:
            relu_qrec = G.quantization[NodeId(matched['relu'])]
            mul_qrec = G.quantization[NodeId(matched['mul'])]
            # The constant disappears with the fusion, so does its record.
            del G.quantization[NodeId(matched['constant'])]
            G.quantization[NodeId(hsigmoid)] = QRec.copy_ktype(
                relu_qrec, in_qs=relu_qrec.in_qs, out_qs=mul_qrec.out_qs)
        return hsigmoid, None, None
 def insert_copy_on_common_concat_in(self, G, concat_nodes):
     """Insert a copy node wherever two concat inputs share the same real source.

     For every input edge of every node in `concat_nodes` the real source is
     looked up with `find_real_in_edge`. The first time a source is seen it
     is only recorded; every later edge that resolves to an already recorded
     source is removed and replaced by `from_node -> copy -> concat`, keeping
     the original `from_idx` and `to_idx`. When the graph has quantization
     records, the copy node gets a record whose input and output types are
     deep copies of the concat's input type at that position.

     Args:
         G: graph that is modified in place.
         concat_nodes: iterable of concat nodes to inspect.

     Returns:
         True if at least one copy node was inserted, else False.
     """
     # in every concat nodes collect all the in edges (from_node, from_idx)
     # if there are repetition of tuples, insert a copy in every repetition
     # different concats cannot have the same in edge (from_node, from_idx)
     seen_sources = []
     has_modified_graph = False
     for concat_node in concat_nodes:
         for idx, in_edge in enumerate(G.indexed_in_edges(
                 concat_node.name)):
             real_in_edge = find_real_in_edge(G, in_edge)
             if real_in_edge not in seen_sources:
                 seen_sources.append(real_in_edge)
                 continue
             has_modified_graph = True
             copy_node = CopyParameters("%s_copy_%s" %
                                        (concat_node.name, idx))
             G.remove_edge(in_edge)
             # Log the source's output index (from_idx); idx is the position
             # of the edge on the concat side and is already shown as to_idx.
             LOG.info(
                 'common_concat: inserting copy between %s/%s and %s/%s',
                 in_edge.from_node.name, in_edge.from_idx, concat_node.name,
                 in_edge.to_idx)
             G.add_edge(
                 NNEdge(in_edge.from_node,
                        copy_node,
                        from_idx=in_edge.from_idx))
             G.add_edge(
                 NNEdge(copy_node, concat_node, to_idx=in_edge.to_idx))
             if G.quantization:
                 qrec = G.quantization[NodeId(concat_node)]
                 G.quantization[NodeId(copy_node)] = QRec.copy_ktype(
                     qrec,
                     in_qs=[deepcopy(qrec.in_qs[idx])],
                     out_qs=[deepcopy(qrec.in_qs[idx])])
     return has_modified_graph
示例#4
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Fuse node chains found by `get_node_list` into single fusion nodes.

        Every node of a class listed in VALID_FUSIONS is offered to
        `self.get_node_list`; results that are None or contain fewer than two
        nodes are skipped. The matched nodes are chained in order inside a
        new GraphView, wrapped in a node built by the match's
        `fusions_class`, removed from `G`, and the edges that entered the
        first node / left the last node are re-attached to the fusion node.
        Quantization records of the fused nodes, if any, are moved under the
        fusion node, which gets a record combining the first node's input
        types with the last node's output types.

        Args:
            G: graph that is modified in place.
            set_identity: when true, `self.set_identity(G)` is called at the end.
            **kwargs: not used by this matcher.

        Returns:
            True if at least one fusion was created, else False.
        """
        graph_changed = False
        fusable_classes = tuple(VALID_FUSIONS.keys())

        for start_node in G.nodes(node_classes=fusable_classes):
            match = self.get_node_list(G, start_node,
                                       FusionMatch(self._default_ktype))
            if match is None or len(match.order) < 2:
                continue
            LOG.info("fusing nodes %s",
                     ",".join(member.name for member in match.order))
            graph_changed = True

            # Re-create the matched nodes as a linear chain in a fresh graph.
            chain = list(match.order)
            fused_graph = GraphView()
            for upstream, downstream in zip(chain, chain[1:]):
                fused_graph.add_edge(
                    NNEdge(from_node=upstream, to_node=downstream))
            head = match.node
            tail = chain[-1]

            # assumption here is that the first node could have multiple inputs
            # but definitely has only one output
            fusion_inputs = [[(head, in_idx)]
                             for in_idx in range(G.num_in_edges(head.name))]
            fusion_outputs = [(tail, 0)]
            fusion = match.fusions_class(head.name + '_fusion',
                                         fusion_type=match.fusion_type,
                                         subgraph=fused_graph,
                                         input_mapping=fusion_inputs,
                                         output_mapping=fusion_outputs)

            if G.quantization:
                # TODO - stats
                member_qrecs = G.quantization.get_all(fusion.contained_nodes())
                if member_qrecs:
                    first_qrec = member_qrecs[0]
                    final_qrec = member_qrecs[-1]
                    fusion_qrec = QRec.copy_ktype(first_qrec,
                                                  in_qs=first_qrec.in_qs,
                                                  out_qs=final_qrec.out_qs)
                    for inner in fusion.contained_nodes():
                        G.quantization.move_to_fusion(inner, fusion)
                    G.quantization[NodeId(fusion)] = fusion_qrec

            # Capture the boundary edges before the nodes are removed.
            feeding_edges = G.in_edges(head.name)
            consuming_edges = G.out_edges(tail.name)
            for member in chain:
                G.remove(member)
            for edge in feeding_edges:
                G.add_edge(
                    NNEdge(edge.from_node,
                           fusion,
                           from_idx=edge.from_idx,
                           to_idx=edge.to_idx))
            for edge in consuming_edges:
                G.add_edge(
                    NNEdge(fusion,
                           edge.to_node,
                           from_idx=edge.from_idx,
                           to_idx=edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return graph_changed
示例#5
0
 def copy_qrec(self, from_node, from_dir, from_idx, to_node):
     """Give `to_node` a quantization record derived from one type of `from_node`.

     The record of `from_node` is read from `self.qset`. The quantization
     type at position `from_idx` of its `<from_dir>_qs` attribute is deep
     copied and used as both the single input and the single output type of
     the record stored for `to_node`.

     Raises:
         ValueError: if `from_node` has no record in `self.qset`.
     """
     source_qrec = self.qset.get(NodeId(from_node))
     if source_qrec is None:
         raise ValueError(
             f'trying to copy qrec from {from_node.name} to {to_node.name} - node has no qrec'
         )
     source_qtypes = getattr(source_qrec, f'{from_dir}_qs')
     qtype = deepcopy(source_qtypes[from_idx])
     new_qrec = QRec.copy_ktype(source_qrec, in_qs=[qtype], out_qs=[qtype])
     self.qset[NodeId(to_node)] = new_qrec
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Fuse FcParameters nodes with the nodes that `get_node_list` chains to them.

        The set of valid activations is VALID_ACTIVATIONS_POW2 when the
        `group_identity` keyword equals 'pow2_match_group', otherwise
        VALID_ACTIVATIONS_SQ8. For every match with at least two nodes the
        nodes are chained inside a new GraphView, wrapped in a
        LinearFusionParameters node with three inputs mapped onto the linear
        node, removed from `G`, and the surrounding edges are re-attached to
        the fusion node. Quantization records, if present, are moved under
        the fusion node.

        Args:
            G: graph that is modified in place.
            set_identity: when true, `self.set_identity(G)` is called at the end.
            **kwargs: only `group_identity` is read.

        Returns:
            True if at least one fusion was created, else False.
        """
        use_pow2 = kwargs.get('group_identity') == 'pow2_match_group'
        valid_activations = (VALID_ACTIVATIONS_POW2
                             if use_pow2 else VALID_ACTIVATIONS_SQ8)

        # Snapshot the candidates first: the loop body removes nodes from G.
        fc_nodes = [candidate for candidate in G.nodes()
                    if isinstance(candidate, FcParameters)]

        graph_changed = False
        for fc_node in fc_nodes:
            match = self.get_node_list(G, fc_node, valid_activations)
            if match is None or len(match.order) < 2:
                continue
            LOG.info("fusing nodes %s",
                     ",".join(member.name for member in match.order))
            graph_changed = True

            # Chain the matched nodes one after the other in a fresh graph.
            fused_graph = GraphView()
            members = iter(match.order)
            tail = next(members)
            for member in members:
                fused_graph.add_edge(NNEdge(from_node=tail, to_node=member))
                tail = member

            linear = match.linear
            fusion_inputs = [[(linear, in_idx)] for in_idx in range(3)]
            fusion_outputs = [(tail, 0)]
            fusion = LinearFusionParameters(
                linear.name + '_fusion',
                fusion_type=match.fusion_type,
                subgraph=fused_graph,
                input_mapping=fusion_inputs,
                output_mapping=fusion_outputs)

            if G.quantization:
                # TODO - stats
                member_qrecs = G.quantization.get_all(fusion.contained_nodes())
                if member_qrecs:
                    fusion_qrec = QRec.copy_ktype(
                        member_qrecs[0],
                        in_qs=member_qrecs[0].in_qs,
                        out_qs=member_qrecs[-1].out_qs)
                    for inner in fusion.contained_nodes():
                        G.quantization.move_to_fusion(inner, fusion)
                    G.quantization[NodeId(fusion)] = fusion_qrec

            # Capture the boundary edges before the nodes are removed.
            feeding_edges = G.in_edges(linear.name)
            consuming_edges = G.out_edges(tail.name)
            for member in match.order:
                G.remove(member)
            for edge in feeding_edges:
                G.add_edge(NNEdge(edge.from_node, fusion,
                                  from_idx=edge.from_idx, to_idx=edge.to_idx))
            for edge in consuming_edges:
                G.add_edge(NNEdge(fusion, edge.to_node,
                                  from_idx=edge.from_idx, to_idx=edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return graph_changed
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Replace `relu6(x) * (1/6)` patterns with a single hsigmoid activation.

        A pattern is accepted when a relu with `upper_bound == 6` has exactly
        one out edge, that edge goes to a MatrixMulParameters node with
        exactly two in edges, and the other input is a
        ConstantInputParameters node with a single out edge for which
        `check_equals(G, constant, 1.0/6.0)` holds. The three nodes are
        removed and a HSigmoidActivationParameters node (offset 0) is wired
        between the relu's former inputs and the multiply's former outputs.
        Quantization records, if present, are updated accordingly.

        `set_identity` and `**kwargs` are accepted but not used here.

        Returns:
            True if at least one pattern was replaced, else False.
        """
        something_changed = False
        relu6_nodes = [
            candidate
            for candidate in G.nodes(node_classes=ReluActivationParameters)
            if candidate.upper_bound == 6
        ]
        for relu_node in relu6_nodes:
            relu_out = G.out_edges(relu_node)
            if len(relu_out) != 1:
                continue
            relu_to_mul = relu_out[0]
            mul_node = relu_to_mul.to_node
            if not isinstance(mul_node, MatrixMulParameters):
                continue
            mul_in = G.in_edges(mul_node)
            if len(mul_in) != 2:
                continue
            # The input of the multiply that does not come from the relu.
            other_edge = set(mul_in).difference((relu_to_mul,)).pop()
            constant_node = other_edge.from_node
            if len(G.out_edges(constant_node)) != 1:
                continue
            if not (isinstance(constant_node, ConstantInputParameters)
                    and check_equals(G, constant_node, 1.0/6.0)):
                continue

            something_changed = True
            hsigmoid = HSigmoidActivationParameters(
                G.unique_name(f'{mul_node.name}_hsigmoid'), offset=0)

            # Capture the boundary edges before the nodes are removed.
            feeding_edges = G.in_edges(relu_node)
            consuming_edges = G.out_edges(mul_node)

            nodes_to_replace = [relu_node, mul_node, constant_node]
            replaced_names = ", ".join(node.name for node in nodes_to_replace)
            LOG.info(f'fusing {replaced_names} into HSIGMOID {hsigmoid.name}')
            G.remove_all(nodes_to_replace)

            for edge in feeding_edges:
                G.add_edge(NNEdge.clone(edge, to_node=hsigmoid, to_idx=0))
            for edge in consuming_edges:
                G.add_edge(NNEdge.clone(edge, from_node=hsigmoid, from_idx=0))

            if G.quantization:
                relu_qrec = G.quantization[NodeId(relu_node)]
                mul_qrec = G.quantization[NodeId(mul_node)]
                del G.quantization[NodeId(constant_node)]
                G.quantization[NodeId(hsigmoid)] = QRec.copy_ktype(
                    relu_qrec, in_qs=relu_qrec.in_qs, out_qs=mul_qrec.out_qs)

        return something_changed
示例#8
0
 def remove_known_batch_dimension(cls, G, x, node, batch_axis=0):
     """Reshape away a known batch dimension of size 1 from the input `x`.

     `x` is indexed as a triple: `x[0]` is the producing node, `x[1]` its
     output index and `x[2]` an object whose `.shape` is a list of
     dimensions in which None marks an unknown dimension.

     If the dimension at `batch_axis` is None, `x` is returned unchanged.
     Otherwise a ReshapeParameters node named `<node.name>_batch` that drops
     that dimension is connected to `x[0]`/`x[1]`, and a new triple pointing
     at output 0 of the reshape is returned whose shape has None at
     `batch_axis`. When the graph has quantization records the reshape gets
     a record copied from the producer's first output type.

     Args:
         G: graph that the reshape node and edge are added to.
         x: `(node, output index, shape holder)` triple described above.
         node: node on whose behalf the reshape is inserted (used for naming
             and for the error message).
         batch_axis: index of the batch dimension in the shape.

     Returns:
         `x` itself, or the `(reshape node, 0, ProvisionalDim)` triple.

     Raises:
         ValueError: if the batch dimension is known and greater than 1.
     """
     x_shape = x[2].shape
     batch_size = x_shape[batch_axis]
     if batch_size is None:
         return x
     # Test the dimension at batch_axis (not dimension 0) so that the guard
     # and the error message agree when batch_axis is not 0.
     if batch_size > 1:
         raise ValueError(
             f'multi batch (n={batch_size}) operations are not supported by {node.name}')
     rparams = ReshapeParameters(
         f'{node.name}_batch',
         old_shape=Dim.unnamed(x_shape),
         shape=Dim.unnamed(x_shape[0:batch_axis:]+x_shape[batch_axis+1::]))
     if G.quantization:
         # NOTE(review): always uses output 0 of the producer, even when
         # x[1] is a different output index - confirm this is intended.
         qrec = G.quantization[NodeId(x[0])]
         G.quantization[NodeId(rparams)] = QRec.copy_ktype(
             qrec,
             in_qs=[qrec.out_qs[0]],
             out_qs=[qrec.out_qs[0]])
     G.add_edge(
         NNEdge(from_node=x[0], to_node=rparams, from_idx=x[1], to_idx=0))
     return (rparams, 0, ProvisionalDim(x_shape[0:batch_axis:]+[None]+x_shape[batch_axis+1::]))
示例#9
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Fuse pooling nodes with the nodes that `get_node_list` chains to them.

        The two activation sets passed to `get_node_list` are the POW2
        variants when the `group_identity` keyword equals 'pow2_match_group'
        and the SQ8 variants otherwise. For every match with at least two
        nodes the nodes are chained inside a new GraphView, wrapped in an
        ActivationFusion node with one input mapped onto the pool node,
        removed from `G`, and the surrounding edges are re-attached to the
        fusion node. Quantization records, if present, are moved under the
        fusion node; for a fused GlobalPoolingParameters node the first
        output type is replaced by a deep copy of its first input type with
        dtype set to np.int32.

        Args:
            G: graph that is modified in place.
            set_identity: when true, `self.set_identity(G)` is called at the end.
            **kwargs: only `group_identity` is read.

        Returns:
            True if at least one fusion was created, else False.
        """
        if kwargs.get('group_identity') == 'pow2_match_group':
            activation_sets = (VALID_ACTIVATIONS_POW2,
                               VALID_ACTIVATIONS_POW2_WO_POOL)
        else:
            activation_sets = (VALID_ACTIVATIONS_SQ8,
                               VALID_ACTIVATIONS_SQ8_WO_POOL)
        valid_activations, valid_activations_wo_pool = activation_sets

        graph_changed = False
        pool_classes = (PoolingParameters, GlobalPoolingParameters)
        for pool_node in G.nodes(node_classes=pool_classes):
            match = self.get_node_list(G, pool_node, valid_activations,
                                       valid_activations_wo_pool)
            if match is None or len(match.order) < 2:
                continue
            LOG.info("fusing nodes %s",
                     ",".join(member.name for member in match.order))
            graph_changed = True

            # Chain the matched nodes one after the other in a fresh graph.
            fused_graph = GraphView()
            tail = None
            for member in match.order:
                if tail is not None:
                    fused_graph.add_edge(
                        NNEdge(from_node=tail, to_node=member))
                tail = member

            fusion = ActivationFusion(match.pool.name + '_fusion',
                                      fusion_type=match.fusion_type,
                                      subgraph=fused_graph,
                                      input_mapping=[[(match.pool, 0)]],
                                      output_mapping=[(tail, 0)])

            if G.quantization:
                # TODO - stats
                member_qrecs = G.quantization.get_all(fusion.contained_nodes())
                if member_qrecs:
                    fusion_qrec = QRec.copy_ktype(
                        member_qrecs[0],
                        in_qs=member_qrecs[0].in_qs,
                        out_qs=member_qrecs[-1].out_qs)
                    for inner in fusion.contained_nodes():
                        G.quantization.move_to_fusion(inner, fusion)
                        if not isinstance(inner, GlobalPoolingParameters):
                            continue
                        # Global pooling fused with activations need to have
                        # only the activation scale
                        inner_qrec = G.quantization[NodeId(fusion, inner)]
                        inner_qrec.out_qs[0] = deepcopy(inner_qrec.in_qs[0])
                        inner_qrec.out_qs[0].dtype = np.int32
                    G.quantization[NodeId(fusion)] = fusion_qrec

            # Capture the boundary edges before the nodes are removed.
            feeding_edges = G.in_edges(match.pool.name)
            consuming_edges = G.out_edges(tail.name)
            for member in match.order:
                G.remove(member)
            for edge in feeding_edges:
                G.add_edge(
                    NNEdge(edge.from_node,
                           fusion,
                           from_idx=edge.from_idx,
                           to_idx=edge.to_idx))
            for edge in consuming_edges:
                G.add_edge(
                    NNEdge(fusion,
                           edge.to_node,
                           from_idx=edge.from_idx,
                           to_idx=edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return graph_changed
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Fuse Conv2DParameters nodes with the pool/activation nodes chained to them.

        The set of valid activations is VALID_ACTIVATIONS_POW2 when the
        `group_identity` keyword equals 'pow2_match_group', otherwise
        VALID_ACTIVATIONS_SQ8. Matches with fewer than two nodes are skipped.
        Two fusion types get special treatment:

        * 'conv_active_pool' with an average pool: the match is truncated to
          its first two nodes and its pool is cleared.
        * 'conv_pool_active' with an average pool and an activation other
          than relu: the match is skipped.

        Remaining matches are chained inside a new GraphView, wrapped in a
        ConvFusionParameters node with three inputs mapped onto the conv
        node, removed from `G`, and the surrounding edges are re-attached to
        the fusion node. Quantization records, if present, are moved under
        the fusion node.

        Args:
            G: graph that is modified in place.
            set_identity: when true, `self.set_identity(G)` is called at the end.
            **kwargs: only `group_identity` is read.

        Returns:
            True if at least one fusion was created, else False.
        """
        use_pow2 = kwargs.get('group_identity') == 'pow2_match_group'
        valid_activations = (VALID_ACTIVATIONS_POW2
                             if use_pow2 else VALID_ACTIVATIONS_SQ8)

        # Snapshot the candidates first: the loop body removes nodes from G.
        conv_nodes = [candidate for candidate in G.nodes()
                      if isinstance(candidate, Conv2DParameters)]

        graph_changed = False
        for conv_node in conv_nodes:
            match = self.get_node_list(G, conv_node, valid_activations)
            if match is None or len(match.order) < 2:
                continue
            if match.fusion_type == 'conv_active_pool':
                if match.pool.pool_type == "average":
                    match.order = match.order[:2:]
                    match.pool = None
            elif (match.fusion_type == 'conv_pool_active'
                  and match.pool.pool_type == "average"
                  and match.active.activation != "relu"):
                # NOTE: This is only for old POW2 kernels - SQ8 can handle this
                continue
            LOG.info("fusing nodes %s",
                     ",".join(member.name for member in match.order))
            graph_changed = True

            # Chain the (possibly truncated) match in a fresh graph.
            fused_graph = GraphView()
            tail = None
            for member in match.order:
                if tail is not None:
                    fused_graph.add_edge(
                        NNEdge(from_node=tail, to_node=member))
                tail = member

            fusion_inputs = [[(match.conv, in_idx)] for in_idx in range(3)]
            fusion_outputs = [(tail, 0)]
            fusion = ConvFusionParameters(
                match.conv.name + '_fusion',
                fusion_type=match.fusion_type,
                subgraph=fused_graph,
                in_dims_hint=match.conv.in_dims_hint,
                out_dims_hint=match.conv.out_dims_hint,
                in_dims=deepcopy(match.conv.in_dims),
                out_dims=deepcopy(match.order[-1].out_dims),
                input_mapping=fusion_inputs,
                output_mapping=fusion_outputs)

            if G.quantization:
                member_qrecs = G.quantization.get_all(fusion.contained_nodes())
                if member_qrecs:
                    # TODO - stats
                    fusion_qrec = QRec.copy_ktype(
                        member_qrecs[0],
                        in_qs=deepcopy(member_qrecs[0].in_qs),
                        out_qs=deepcopy(member_qrecs[-1].out_qs))
                    for inner in fusion.contained_nodes():
                        G.quantization.move_to_fusion(inner, fusion)
                    G.quantization[NodeId(fusion)] = fusion_qrec

            # Capture the boundary edges before the nodes are removed.
            feeding_edges = G.in_edges(match.conv.name)
            consuming_edges = G.out_edges(tail.name)
            for member in match.order:
                G.remove(member)
            for edge in feeding_edges:
                G.add_edge(
                    NNEdge(edge.from_node,
                           fusion,
                           from_idx=edge.from_idx,
                           to_idx=edge.to_idx))
            for edge in consuming_edges:
                G.add_edge(
                    NNEdge(fusion,
                           edge.to_node,
                           from_idx=edge.from_idx,
                           to_idx=edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return graph_changed
示例#11
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Fuse a pad node, the add it feeds and an optional activation.

        Every PadParameters node is offered to `self.get_node_list`; results
        that are None or contain fewer than two nodes are skipped. Inside a
        new GraphView the pad is connected to the add on the same input index
        it used in `G`, and the add is connected to the activation when the
        match has one. The nodes are wrapped in a PaddedAddFusionParameters
        node named "PADDED_" + the add's name, removed from `G`, and the
        surrounding edges are re-attached to the fusion node. The add node's
        `force_quantized_index` is set to 0. Quantization records, if
        present, are moved under the fusion node, whose own record is built
        from the second and the last record returned for the fused nodes.

        Args:
            G: graph that is modified in place.
            set_identity: when true, `self.set_identity(G)` is called at the end.
            **kwargs: not used by this matcher.

        Returns:
            True if at least one fusion was created, else False.
        """
        # Snapshot the candidates first: the loop body removes nodes from G.
        pad_nodes = [candidate for candidate in G.nodes()
                     if isinstance(candidate, PadParameters)]

        graph_changed = False
        for pad_node in pad_nodes:
            match = self.get_node_list(G, pad_node)
            if match is None or len(match.order) < 2:
                continue
            LOG.info("fusing nodes %s",
                     ",".join(member.name for member in match.order))
            graph_changed = True

            fused_graph = GraphView()
            # Input of the add that receives the padded tensor.
            padded_input_idx = G.out_edges(match.pad.name)[0].to_idx
            pad_is_first_input = padded_input_idx == 0
            fused_graph.add_edge(
                NNEdge(from_node=match.pad,
                       to_node=match.add,
                       to_idx=padded_input_idx))
            tail = match.add
            match.add.force_quantized_index = 0
            if match.active:
                fused_graph.add_edge(
                    NNEdge(from_node=match.add, to_node=match.active))
                tail = match.active

            if pad_is_first_input:
                fusion_inputs = [[(match.pad, 0)], [(match.add, 1)]]
            else:
                fusion_inputs = [[(match.add, 0)], [(match.pad, 1)]]
            fusion_outputs = [(tail, 0)]

            fusion = PaddedAddFusionParameters(
                "PADDED_" + match.add.name,
                fusion_type=match.fusion_type,
                subgraph=fused_graph,
                input_mapping=fusion_inputs,
                output_mapping=fusion_outputs)

            if G.quantization:
                member_qrecs = G.quantization.get_all(fusion.contained_nodes())
                # TODO - stats
                if member_qrecs:
                    fusion_qrec = QRec.copy_ktype(
                        member_qrecs[1],
                        in_qs=member_qrecs[1].in_qs,
                        out_qs=member_qrecs[-1].out_qs)
                    for inner in fusion.contained_nodes():
                        G.quantization.move_to_fusion(inner, fusion)
                    G.quantization[NodeId(fusion)] = fusion_qrec

            # Capture the boundary edges before the nodes are removed: the
            # pad's inputs stand in for the add input that the pad was feeding.
            pad_inputs = G.in_edges(match.pad.name)
            if pad_is_first_input:
                feeding_edges = pad_inputs + \
                    G.indexed_in_edges(match.add.name)[1:]
            else:
                feeding_edges = G.indexed_in_edges(match.add.name)[:1] + \
                    pad_inputs
            consuming_edges = G.out_edges(tail.name)
            for member in match.order:
                G.remove(member)
            for edge in feeding_edges:
                G.add_edge(
                    NNEdge(edge.from_node,
                           fusion,
                           from_idx=edge.from_idx,
                           to_idx=edge.to_idx))
            for edge in consuming_edges:
                G.add_edge(
                    NNEdge(fusion,
                           edge.to_node,
                           from_idx=edge.from_idx,
                           to_idx=edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return graph_changed