Exemplo n.º 1
0
 def replace_function(self, G: NNGraph, subgraph: GraphView):
     """Fuse a two-node match into a single ActivationFusion node.

     The fused node is named after the first node returned by
     ``subgraph.nodes()``. When the graph is quantized, the contained
     records are moved into the fusion and the fusion gets a record of the
     same family as the first contained record, taking its input qtypes from
     the first record and its output qtypes from the last one.

     Args:
         G: graph being rewritten; its quantization table is updated.
         subgraph: the matched nodes.

     Returns:
         The new ActivationFusion node.

     Raises:
         NotImplementedError: if the first contained quantization record is
             not a symmetric, mult or float32 record. This is detected before
             any record is moved, so the quantization table is left intact.
     """
     nodes = list(subgraph.nodes())
     pnode = ActivationFusion(nodes[0].name + "fusion",
                              nodes[0].op_name + "_active", subgraph)
     # NOTE(review): assumes the match always holds at least two nodes.
     nodes[0].step_idx = 0
     nodes[1].step_idx = 1
     LOG.debug("fused nodes %s", ",".join(node.name for node in nodes))
     if G.quantization:
         qrecs = G.quantization.get_all(subgraph.nodes())
         if qrecs:
             first_qrec = qrecs[0]
             if isinstance(first_qrec,
                           (SymmetricQuantizationRecord,
                            SymmetricScalableFilterQuantizationRecord)):
                 prec_class = SymmetricQuantizationRecord
             elif isinstance(first_qrec,
                             (MultQuantizationRecord,
                              MultScalableFilterQuantizationRecord)):
                 prec_class = MultQuantizationRecord
             elif isinstance(first_qrec,
                             (Float32QuantizationRecord,
                              Float32ScalableFilterQuantizationRecord)):
                 prec_class = Float32QuantizationRecord
             else:
                 # Previously ``prec`` was left unbound here and the function
                 # crashed with UnboundLocalError after already moving the
                 # contained records; fail explicitly and early instead.
                 raise NotImplementedError(
                     "cannot fuse quantization record of type %s"
                     % type(first_qrec).__name__)
             prec = prec_class(in_qs=first_qrec.in_qs,
                               out_qs=qrecs[-1].out_qs)
             for node in subgraph.nodes():
                 G.quantization.move_to_fusion(node, pnode)
             G.quantization[NodeId(pnode)] = prec
     return pnode
Exemplo n.º 2
0
 def replace_function(self, G: GraphView, subgraph: GraphView):
     """Fuse a validated convolution match into a ConvFusionParameters node.

     Nodes are numbered (``step_idx``) in ``subgraph.nodes()`` order up to and
     including the first Conv2DParameters node, whose name plus ``"_fusion"``
     names the fused node. When the graph is quantized, the contained
     records are moved into the fusion and the fusion gets a record of the
     same family as the first contained one, spanning the first record's
     input qtypes and the last record's output qtypes.

     Returns:
         ``(pnode, None, None)`` where ``pnode`` is the new fusion node.

     Raises:
         DontReplaceError: if ``validate_match`` rejects the subgraph or it
             contains no Conv2DParameters node.
         NotImplementedError: if the first contained quantization record is
             not a symmetric, mult or float32 record.
     """
     if not self.validate_match(subgraph):
         raise DontReplaceError()
     for step, node in enumerate(subgraph.nodes()):
         node.step_idx = step
         if isinstance(node, Conv2DParameters):
             conv_name = node.name + "_fusion"
             break
     else:
         # Without a convolution there is nothing to name the fusion after;
         # decline the replacement rather than hit an UnboundLocalError.
         raise DontReplaceError()
     LOG.debug("fused nodes %s", ",".join(node.name for node in subgraph.nodes()))
     pnode = ConvFusionParameters(conv_name, fusion_type=self.fusion_type, subgraph=subgraph)
     if G.quantization:
         qrecs = G.quantization.get_all(pnode.contained_nodes())
         if qrecs:
             first_qrec = qrecs[0]
             if isinstance(first_qrec, (SymmetricQuantizationRecord, SymmetricScalableFilterQuantizationRecord)):
                 prec_class = SymmetricQuantizationRecord
             elif isinstance(first_qrec, (MultQuantizationRecord, MultScalableFilterQuantizationRecord)):
                 prec_class = MultQuantizationRecord
             elif isinstance(first_qrec, (Float32QuantizationRecord, Float32ScalableFilterQuantizationRecord)):
                 prec_class = Float32QuantizationRecord
             else:
                 # Fail before moving any record instead of crashing later on
                 # an unbound ``prec``.
                 raise NotImplementedError(
                     "cannot fuse quantization record of type %s"
                     % type(first_qrec).__name__)
             prec = prec_class(in_qs=first_qrec.in_qs, out_qs=qrecs[-1].out_qs)
             for node in pnode.contained_nodes():
                 G.quantization.move_to_fusion(node, pnode)
             G.quantization[NodeId(pnode)] = prec
     return pnode, None, None
Exemplo n.º 3
0
 def replace_function(self, G: NNGraph, subgraph: GraphView):
     """Fuse an FcParameters-rooted match into a ``"linear_active"`` fusion.

     Nodes are numbered (``step_idx``) in ``subgraph.nodes()`` order up to and
     including the first FcParameters node, whose name plus ``"_fusion"``
     names the ConvFusionParameters node that is built. When the graph is
     quantized, the contained records are moved into the fusion and the
     fusion gets a record of the same family as the first contained one,
     spanning the first record's inputs and the last record's outputs.

     Returns:
         ``(pnode, None, None)`` where ``pnode`` is the new fusion node.

     Raises:
         ValueError: if the subgraph contains no FcParameters node.
         NotImplementedError: if the first contained quantization record is
             not a symmetric, mult or float32 record.
     """
     for step, node in enumerate(subgraph.nodes()):
         node.step_idx = step
         if isinstance(node, FcParameters):
             linear_name = node.name + "_fusion"
             break
     else:
         # Previously this fell through to an UnboundLocalError on
         # ``linear_name``; report the real problem instead.
         raise ValueError("fusion match contains no FcParameters node")
     LOG.info("fusing nodes %s", ",".join(
         node.name for node in subgraph.nodes()))
     pnode = ConvFusionParameters(linear_name, fusion_type="linear_active", subgraph=subgraph)
     if G.quantization:
         qrecs = G.quantization.get_all(pnode.contained_nodes())
         if qrecs:
             first_qrec = qrecs[0]
             if isinstance(first_qrec, (SymmetricQuantizationRecord, SymmetricScalableFilterQuantizationRecord)):
                 prec_class = SymmetricQuantizationRecord
             elif isinstance(first_qrec, (MultQuantizationRecord, MultScalableFilterQuantizationRecord)):
                 prec_class = MultQuantizationRecord
             elif isinstance(first_qrec, (Float32QuantizationRecord, Float32ScalableFilterQuantizationRecord)):
                 prec_class = Float32QuantizationRecord
             else:
                 # Fail before moving any record instead of crashing later on
                 # an unbound ``prec``.
                 raise NotImplementedError(
                     "cannot fuse quantization record of type %s"
                     % type(first_qrec).__name__)
             prec = prec_class(in_qs=first_qrec.in_qs, out_qs=qrecs[-1].out_qs)
             for node in pnode.contained_nodes():
                 G.quantization.move_to_fusion(node, pnode)
             G.quantization[NodeId(pnode)] = prec
     return pnode, None, None
Exemplo n.º 4
0
 def replace_function(self, G: GraphView, subgraph: GraphView):
     """Build a ``self.fusion_type`` FusionParameters node for a conv match.

     The fused node is named after the first Conv2DParameters node found in
     ``subgraph.nodes()`` (plus ``"_fusion"``). Its contained nodes are taken
     from ``subgraph.dfs()``.

     Returns:
         The new FusionParameters node.

     Raises:
         ValueError: if the subgraph contains no Conv2DParameters node.
     """
     conv_node = next((node for node in subgraph.nodes()
                       if isinstance(node, Conv2DParameters)), None)
     if conv_node is None:
         # Previously this surfaced as an UnboundLocalError on conv_name.
         raise ValueError("fusion match contains no Conv2DParameters node")
     conv_name = conv_node.name + "_fusion"
     LOG.debug("fused nodes %s", ",".join(
         node.name for node in subgraph.nodes()))
     # simple node order is necessary because nodes() will not necessarily
     # be in order, hence the depth-first traversal for the contained list
     return FusionParameters(conv_name, self.fusion_type,
                             list(subgraph.dfs()))
Exemplo n.º 5
0
 def replace_function(self, G: GraphView, subgraph: GraphView):
     """Build a ``"linear_active"`` FusionParameters node for an FC match.

     Nodes are numbered (``step_idx``) in ``subgraph.nodes()`` order up to and
     including the first FcParameters node, whose name plus ``"_fusion"``
     names the fused node.

     Returns:
         The new FusionParameters node.

     Raises:
         ValueError: if the subgraph contains no FcParameters node.
     """
     for step, node in enumerate(subgraph.nodes()):
         node.step_idx = step
         if isinstance(node, FcParameters):
             linear_name = node.name + "_fusion"
             break
     else:
         # Previously this fell through to an UnboundLocalError on
         # ``linear_name``; report the real problem instead.
         raise ValueError("fusion match contains no FcParameters node")
     LOG.debug("fused nodes %s", ",".join(
         node.name for node in subgraph.nodes()))
     return FusionParameters(linear_name, "linear_active", subgraph)
Exemplo n.º 6
0
 def replace_function(self, G: GraphView, subgraph: GraphView):
     """Build a ``self.fusion_type`` FusionParameters node for a conv match.

     Nodes are numbered (``step_idx``) in ``subgraph.nodes()`` order up to and
     including the first Conv2DParameters node, whose name plus
     ``"_fusion"`` names the fused node.

     Returns:
         The new FusionParameters node.

     Raises:
         DontReplaceError: if ``validate_match`` rejects the subgraph or it
             contains no Conv2DParameters node.
     """
     if not self.validate_match(subgraph):
         raise DontReplaceError()
     for step, node in enumerate(subgraph.nodes()):
         node.step_idx = step
         if isinstance(node, Conv2DParameters):
             conv_name = node.name + "_fusion"
             break
     else:
         # Without a convolution there is nothing to name the fusion after;
         # decline the replacement rather than hit an UnboundLocalError.
         raise DontReplaceError()
     LOG.debug("fused nodes %s", ",".join(
         node.name for node in subgraph.nodes()))
     return FusionParameters(conv_name, self.fusion_type, subgraph)
Exemplo n.º 7
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Fuse chains of nodes rooted at nodes of a VALID_FUSIONS class.

        For every node whose class is a key of ``VALID_FUSIONS``,
        ``get_node_list`` is asked for a fusable chain (``node_list.order``);
        chains shorter than two nodes are skipped. Otherwise the chain is
        copied into a linear GraphView and wrapped in a
        ``node_list.fusions_class`` node named ``<node_list.node>_fusion``.
        The contained quantization records are moved into the fusion, and the
        chain is replaced in G by the fusion node, re-creating the input
        edges of ``node_list.node`` and the output edges of the chain's last
        node.

        Args:
            G: graph to rewrite in place.
            set_identity: when true, ``self.set_identity(G)`` is called at
                the end.
            **kwargs: accepted and ignored.

        Returns:
            True if at least one chain was fused.
        """
        has_modified_graph = False

        # NOTE(review): nodes are removed from G inside this loop; this
        # presumably relies on G.nodes(...) returning a snapshot - confirm.
        for node in G.nodes(node_classes=tuple(VALID_FUSIONS.keys())):
            node_list = self.get_node_list(G, node,
                                           FusionMatch(self._default_ktype))
            if node_list is None or len(node_list.order) < 2:
                continue
            LOG.info("fusing nodes %s", ",".join(
                (node.name for node in node_list.order)))
            has_modified_graph = True
            # Build the fusion's internal graph as a simple chain in match
            # order.
            subgraph = GraphView()
            last_node = None
            for snode in node_list.order:
                if last_node is not None:
                    subgraph.add_edge(
                        NNEdge(from_node=last_node, to_node=snode))
                last_node = snode
            # assumption here is that the first node could have multiple inputs but definitely has only
            # one output
            # Fusion input i maps to input i of node_list.node.
            input_mapping = [[
                (node_list.node, idx)
            ] for idx in range(G.num_in_edges(node_list.node.name))]
            output_mapping = [(last_node, 0)]
            pnode = node_list.fusions_class(node_list.node.name + '_fusion',
                                            fusion_type=node_list.fusion_type,
                                            subgraph=subgraph,
                                            input_mapping=input_mapping,
                                            output_mapping=output_mapping)
            if G.quantization:
                # TODO - stats
                # The fusion's record is built with QRec.copy_ktype from the
                # first contained record, spanning its input qtypes and the
                # last contained record's output qtypes.
                qrecs = G.quantization.get_all(pnode.contained_nodes())
                if qrecs:
                    prec = QRec.copy_ktype(qrecs[0],
                                           in_qs=qrecs[0].in_qs,
                                           out_qs=qrecs[-1].out_qs)
                    for fnode in pnode.contained_nodes():
                        G.quantization.move_to_fusion(fnode, pnode)
                    G.quantization[NodeId(pnode)] = prec
            # Capture the boundary edges before the chain is removed from G.
            in_edges = G.in_edges(node_list.node.name)
            out_edges = G.out_edges(last_node.name)
            for snode in node_list.order:
                G.remove(snode)
            # Re-attach the chain's external inputs and outputs to the fusion,
            # keeping the original edge indices.
            for edge in in_edges:
                G.add_edge(
                    NNEdge(edge.from_node,
                           pnode,
                           from_idx=edge.from_idx,
                           to_idx=edge.to_idx))
            for edge in out_edges:
                G.add_edge(
                    NNEdge(pnode,
                           edge.to_node,
                           from_idx=edge.from_idx,
                           to_idx=edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
    def match(self, G: GraphView, set_identity: bool = True):
        """Change the output of nodes whose consumers can all adapt.

        Does nothing (returns None) when the graph is not quantized. A node
        is a candidate when ``can_change_output`` accepts it and
        ``can_change_input`` returns a non-None list of matches for every one
        of its successors. The collected matches, keyed by node, go through
        ``validate_multi_input``, and ``do_change`` is then applied to each
        remaining node.
        """
        if not G.quantization:
            return
        candidates = {}
        for node in G.nodes():
            if not self.can_change_output(node):
                continue
            # Flatten the per-output successor lists first.
            successors = []
            for succs in G.successors(node.name):
                successors.extend(succs)
            collected = []
            for succ in successors:
                succ_matches = self.can_change_input(G, succ)
                if succ_matches is None:
                    # One consumer cannot adapt: reject this node.
                    break
                collected += succ_matches
            else:
                candidates[node] = collected

        candidates = self.validate_multi_input(G, candidates)
        for node in candidates:
            # all nodes that can currently change output have one output
            self.do_change(G, node)

        if set_identity:
            self.set_identity(G)
Exemplo n.º 9
0
 def match(self, G: GraphView, set_identity: bool = True):
     """Insert CopyParameters nodes between splits and the nodes they feed.

     For each SplitParameters node, ``find_split_concat`` returns one bundle
     per split output, or None to skip the node. For every non-empty bundle,
     a copy node ``<split>_copy_<idx>`` is inserted on that output. The first
     edge of each edge set in the bundle is rerouted to come from the copy,
     and the copy is fed from output ``idx`` of the split. When the graph is
     quantized, the copy gets a record of the split record's class whose
     single input and output qtype are copies of the split's output qtype
     ``idx``.

     ``set_identity`` is accepted for interface compatibility but not used.

     Returns:
         True if any copy node was inserted.
     """
     split_nodes = [
         node for node in G.nodes() if isinstance(node, SplitParameters)
     ]
     has_modified_graph = False
     for node in split_nodes:
         # traverse reshapes or transposes that do nothing - check gen
         # find edges connected to concats
         res = self.find_split_concat(G, node)
         if res is None:
             continue
         # TODO(martin) - group edges that have adjacent inputs and outputs
         qrec = G.quantization[NodeId(node)] if G.quantization else None
         for idx, bundle in enumerate(res):
             if not bundle:
                 continue
             has_modified_graph = True
             copy_node = CopyParameters("%s_copy_%s" % (node.name, idx))
             for edge_set in bundle:
                 # NOTE(review): only the first edge of each set is rerouted.
                 first_edge = edge_set[0]
                 G.remove_edge(first_edge)
                 G.add_edge(
                     NNEdge(copy_node,
                            first_edge.to_node,
                            to_idx=first_edge.to_idx))
             G.add_edge(NNEdge(node, copy_node, from_idx=idx))
             if qrec is not None:
                 # Records hold per-edge qtype lists (``qrec.out_qs[idx]``
                 # above), so the copy's single qtype must be wrapped in a
                 # list rather than passed bare.
                 G.quantization[NodeId(copy_node)] = qrec.__class__(
                     in_qs=[deepcopy(qrec.out_qs[idx])],
                     out_qs=[deepcopy(qrec.out_qs[idx])])
     return has_modified_graph
Exemplo n.º 10
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Absorb a transpose feeding a matmul's second input into the matmul.

        Every MatMulOpParameters node whose input 1 comes from a
        TransposeParameters node is replaced, under the same name, by the
        other matmul flavour: MatMulTransposedParameters becomes
        MatMulOpParameters and vice versa. The transpose is deleted and its
        input is wired straight to input 1 of the new node.

        Returns:
            True if at least one matmul was rewritten.
        """
        has_modified_graph = False

        # Snapshot the candidates: matmuls are replaced inside the loop.
        for node in list(G.nodes(node_classes=MatMulOpParameters)):
            in_edges = list(G.indexed_in_edges(node.name))
            if len(in_edges) < 2 or in_edges[1] is None:
                continue
            trans_node = in_edges[1].from_node
            if not isinstance(trans_node, TransposeParameters):
                continue
            # The transpose is deleted below; if it also feeds other nodes
            # they would be left disconnected, so leave this matmul alone.
            if len(G.out_edges(trans_node.name)) > 1:
                continue
            # NOTE(review): assumes the transpose only swaps the two innermost
            # axes of input 1 - confirm for higher-rank inputs.
            if isinstance(node, MatMulTransposedParameters):
                new_node = MatMulOpParameters(node.name)
            else:
                new_node = MatMulTransposedParameters(node.name)

            in_trans_edge = list(G.indexed_in_edges(trans_node.name))[0]
            G.replace_node(node.name, new_node)
            G.remove(trans_node)
            G.add_edge(
                NNEdge(in_trans_edge.from_node,
                       new_node,
                       from_idx=in_trans_edge.from_idx,
                       to_idx=1))
            has_modified_graph = True

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
Exemplo n.º 11
0
    def match(self, G: GraphView, set_identity: bool = True):
        """Requantize the inputs of sigmoid-like activations to a fixed range.

        For every sigmoid, hard-sigmoid or hard-swish node, a new input qtype
        is built with ``QType.from_min_max_sq``, using the current input
        qtype's dtype, and propagated up the node's input-0 edge. The range
        is [0, 6] for sigmoid / hard sigmoid and [0, 6 * current max] for
        hard swish.

        Returns:
            None when the graph is not quantized, otherwise False.
        """
        if not G.quantization:
            return
        act_classes = (SigmoidActivationParameters,
                       HSigmoidActivationParameters,
                       HSwishActivationParameters)
        acts = [n for n in G.nodes() if isinstance(n, act_classes)]
        # All records are fetched before any propagation happens.
        act_qrecs = [G.quantization[NodeId(n)] for n in acts]
        for act, act_qrec in zip(acts, act_qrecs):
            first_in_edge = [
                e for e in G.in_edges(act.name) if e.to_idx == 0
            ][0]
            cur_in_q = act_qrec.in_qs[0]
            lo, hi = cur_in_q.min_val, cur_in_q.max_val
            if isinstance(act, (HSigmoidActivationParameters,
                                SigmoidActivationParameters)):
                # Hard sigmoid implements a RELU, be sure 6 can be representable
                lo, hi = 0, 6
            elif isinstance(act, HSwishActivationParameters):
                lo, hi = 0, cur_in_q.max_val * 6
            propagate_qtype_up(
                G,
                QType.from_min_max_sq(min_val=lo,
                                      max_val=hi,
                                      dtype=cur_in_q.dtype),
                first_in_edge)

        if set_identity:
            self.set_identity(G)

        return False
Exemplo n.º 12
0
 def match(self, G: GraphView, set_identity: bool = True):
     """Turn the pending transposes of Transposable nodes into real nodes.

     For each Transposable node that is not itself a TransposeParameters:
     if ``transpose_in`` is set, a TransposeParameters node named
     ``<node>_TIN_<i>`` is inserted on every input edge and ``transpose_in``
     is cleared. Likewise, ``transpose_out`` yields ``<node>_TOUT_<i>`` nodes
     on every output edge. When the node carries dims hints, the inserted
     transpose takes the node's current hint on its outer side. The node's
     hint is replaced by ``apply_reverse_transpose_to_hint`` of it.
     """
     # Collect the candidates up front so the TransposeParameters nodes
     # inserted below are never revisited.
     candidates = [
         tnode for tnode in G.nodes()
         if isinstance(tnode, Transposable)
         and not isinstance(tnode, TransposeParameters)
     ]
     for tnode in candidates:
         if tnode.transpose_in:
             for pos, in_edge in enumerate(G.in_edges(tnode.name)):
                 tin = TransposeParameters("%s_TIN_%s" % (tnode.name, pos),
                                           transpose=tnode.transpose_in)
                 if tnode.in_dims_hint:
                     slot = in_edge.to_idx
                     node_hint = tnode.in_dims_hint[slot]
                     reversed_hint = apply_reverse_transpose_to_hint(
                         node_hint, tnode.transpose_in)
                     tin.in_dims_hint = [node_hint.copy()]
                     tin.out_dims_hint = [reversed_hint.copy()]
                     tnode.in_dims_hint[slot] = reversed_hint
                 G.insert_node(tin, in_edge.from_node.name,
                               in_edge.to_node.name,
                               from_idx=in_edge.from_idx,
                               to_idx=in_edge.to_idx)
             tnode.transpose_in = None
         if tnode.transpose_out:
             for pos, out_edge in enumerate(G.out_edges(tnode.name)):
                 tout = TransposeParameters("%s_TOUT_%s" % (tnode.name, pos),
                                            transpose=tnode.transpose_out)
                 if tnode.out_dims_hint:
                     slot = out_edge.from_idx
                     node_hint = tnode.out_dims_hint[slot]
                     reversed_hint = apply_reverse_transpose_to_hint(
                         node_hint, tnode.transpose_out)
                     tout.in_dims_hint = [reversed_hint.copy()]
                     tout.out_dims_hint = [node_hint.copy()]
                     tnode.out_dims_hint[slot] = reversed_hint
                 G.insert_node(tout, out_edge.from_node.name,
                               out_edge.to_node.name,
                               from_idx=out_edge.from_idx,
                               to_idx=out_edge.to_idx)
             tnode.transpose_out = None
     if set_identity:
         self.set_identity(G)
Exemplo n.º 13
0
    def _match(self,
               G: GraphView,
               set_identity: bool = True,
               **kwargs) -> bool:
        """Bypass and delete nodes that ``node_does_nothing`` flags as no-ops.

        For each such node, its first in edge is removed, and its source is
        connected directly to every consumer before the node is deleted.

        Returns:
            True if any node was removed.
        """
        noops = [n for n in G.nodes() if self.node_does_nothing(G, n)]
        for noop in noops:
            src_edge = G.in_edges(noop.name)[0]
            G.remove_edge(src_edge)
            for dst_edge in G.out_edges(noop.name):
                G.remove_edge(dst_edge)
                G.add_edge(NNEdge(src_edge.from_node,
                                  dst_edge.to_node,
                                  from_idx=src_edge.from_idx,
                                  to_idx=dst_edge.to_idx))
            LOG.info(f'removing {noop.name} that does nothing')
            G.remove(noop)

        if set_identity:
            self.set_identity(G)

        return bool(noops)
Exemplo n.º 14
0
    def replace_function(self, G: GraphView, subgraph: GraphView):
        """Replace a relu / constant / multiply match with a hard sigmoid.

        The match is expected to hold a ReluActivationParameters, a
        ConstantInputParameters and a MatrixMulParameters node; if several of
        a kind are present, the last one seen wins. They are replaced by a
        HSigmoidActivationParameters node named
        ``<mul>_fused_close_hsigmoid`` with ``offset=0``. When the graph is
        quantized, the constant's record is deleted and the new node gets a
        record built with ``QRec.copy_ktype`` from the relu record. That
        record takes the relu's input qtypes and the multiply's output qtypes.

        Returns:
            ``(activation, None, None)``.
        """
        by_kind = {}
        for node in subgraph.nodes():
            for kind in (ReluActivationParameters,
                         ConstantInputParameters,
                         MatrixMulParameters):
                if isinstance(node, kind):
                    by_kind[kind] = node
                    break
        relu_node = by_kind.get(ReluActivationParameters)
        constant_node = by_kind.get(ConstantInputParameters)
        mul_node = by_kind.get(MatrixMulParameters)

        activation = HSigmoidActivationParameters(
            mul_node.name + "_fused_close_hsigmoid", offset=0)

        if G.quantization:
            relu_qrec = G.quantization[NodeId(relu_node)]
            mul_qrec = G.quantization[NodeId(mul_node)]
            del G.quantization[NodeId(constant_node)]
            G.quantization[NodeId(activation)] = QRec.copy_ktype(
                relu_qrec, in_qs=relu_qrec.in_qs, out_qs=mul_qrec.out_qs)
        return activation, None, None
Exemplo n.º 15
0
    def replace_function(self, G: GraphView, subgraph: GraphView):
        """Replace a relu / constant / multiply match with a hard sigmoid.

        The new HSigmoidActivationParameters node is named
        ``<mul>_fused_close_hsigmoid`` with ``offset=0``. When the graph is
        quantized, the constant's record is deleted and the new node gets a
        record of the relu record's class. It uses the relu's input qtypes
        and the multiply's output qtypes.

        Returns:
            ``(activation, None, None)``.

        Raises:
            NotImplementedError: if the relu record is not a symmetric, mult
                or float32 record.
        """
        relu_node = constant_node = mul_node = None
        for node in subgraph.nodes():
            if isinstance(node, ReluActivationParameters):
                relu_node = node
            elif isinstance(node, ConstantInputParameters):
                constant_node = node
            elif isinstance(node, MatrixMulParameters):
                mul_node = node

        activation = HSigmoidActivationParameters(
            mul_node.name + "_fused_close_hsigmoid", offset=0)

        if G.quantization:
            relu_qrec = G.quantization[NodeId(relu_node)]
            mul_qrec = G.quantization[NodeId(mul_node)]
            del G.quantization[NodeId(constant_node)]
            # Checked in this order; the first matching class is used.
            for rec_class in (SymmetricQuantizationRecord,
                              MultQuantizationRecord,
                              Float32QuantizationRecord):
                if isinstance(relu_qrec, rec_class):
                    fused_qrec = rec_class(in_qs=relu_qrec.in_qs,
                                           out_qs=mul_qrec.out_qs)
                    break
            else:
                raise NotImplementedError()
            G.quantization[NodeId(activation)] = fused_qrec
        return activation, None, None
Exemplo n.º 16
0
    def match(self, G: GraphView, set_identity: bool = True):
        """Force power-of-two input scaling on softmax nodes.

        Only acts when the graph is quantized and every SoftMaxParameters
        record is a MultQuantizationRecord. For each softmax:
        ``scale_to_pow2()`` is called on its input qtype, presumably in place,
        since the result is unused. That qtype is then propagated up every
        input edge. Each consumer, which must be an OutputParameters or
        QuantizeParameters node, gets the softmax's output qtype as both its
        input and output qtype.

        Returns:
            None when the preconditions are not met, otherwise False.
        """
        if not G.quantization:
            return
        softmax_nodes = [n for n in G.nodes()
                         if isinstance(n, SoftMaxParameters)]
        softmax_qrecs = [G.quantization[NodeId(n)] for n in softmax_nodes]
        if any(not isinstance(q, MultQuantizationRecord)
               for q in softmax_qrecs):
            return
        for softmax, qrec in zip(softmax_nodes, softmax_qrecs):
            pow2_in_q = qrec.in_qs[0]
            pow2_in_q.scale_to_pow2()
            for in_edge in G.in_edges(softmax.name):
                propagate_qtype_up(G, pow2_in_q, in_edge)
            for out_edge in G.out_edges(softmax.name):
                consumer = out_edge.to_node
                assert isinstance(
                    consumer,
                    (OutputParameters, QuantizeParameters
                     )), "Softmax is supported only at the end of the graph"
                consumer_qrec = G.quantization[NodeId(consumer)]
                consumer_qrec.in_qs[0] = qrec.out_qs[0]
                consumer_qrec.out_qs[0] = qrec.out_qs[0]

        if set_identity:
            self.set_identity(G)

        return False
Exemplo n.º 17
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Run ``find_direct_connects`` on every split, then on every concat.

        Split nodes are handled first with the default direction. Concat
        nodes are collected only afterwards and handled with
        ``find_output=False``. ``set_identity`` is accepted but not used.

        Returns:
            The modified-graph flag threaded through ``find_direct_connects``.
        """
        changed = False
        for split in list(G.nodes(node_classes=SplitParameters)):
            changed = self.find_direct_connects(G, split, changed)
        for concat in list(G.nodes(node_classes=ConcatParameters)):
            changed = self.find_direct_connects(G, concat, changed,
                                                find_output=False)
        return changed
Exemplo n.º 18
0
    def _match(self,
               G: GraphView,
               set_identity: bool = True,
               **kwargs) -> bool:
        """Eliminate strided slices whose slice covers the whole input.

        A StridedSliceParameters node whose ``slice_shape`` equals its input
        shape is removed outright when ``slice_shape`` also equals its
        ``out_shape``. Otherwise it is replaced by a ReshapeParameters node
        from ``slice_shape`` to ``out_shape``. Its quantization record, if
        any, is deleted or moved to the reshape accordingly.

        Returns:
            True if any slice was removed or replaced.
        """
        changed = False
        for slice_node in list(G.nodes(node_classes=StridedSliceParameters)):
            if slice_node.slice_shape != tuple(slice_node.in_dims[0].shape):
                continue
            changed = True
            slice_nid = NodeId(slice_node)
            if slice_node.slice_shape == slice_node.out_shape:
                LOG.info(
                    f'removing strided slice {slice_node.name} that does nothing')
                G.remove_and_reconnect(slice_node, edge_class=NNEdge)
                if G.quantization and slice_nid in G.quantization:
                    del G.quantization[slice_nid]
                continue
            reshape = ReshapeParameters(
                G.unique_name(f'{slice_node.name}_reshape'),
                old_shape=slice_node.slice_shape,
                shape=slice_node.out_shape)
            LOG.info(
                f'replacing strided slice {slice_node.name} with reshape {reshape.name}'
            )
            G.replace_node(slice_node, reshape)
            if G.quantization and slice_nid in G.quantization:
                G.quantization[NodeId(reshape)] = G.quantization[slice_nid]
                del G.quantization[slice_nid]

        if set_identity:
            self.set_identity(G)

        return changed
Exemplo n.º 19
0
 def replace_function(self, G: GraphView, subgraph: GraphView):
     """Absorb the constant of a filter/constant match into the filter's biases.

     The flattened constant must hold one value per filter output channel
     (``filter_node.filter.out_c``). It is added to the existing biases when
     the filter has them; otherwise it becomes the biases. When both nodes
     have quantization records, the constant's output qtype becomes the
     filter's ``biases_q``.

     Returns:
         ``(filter_node, None, None)``.

     Raises:
         DontReplaceError: if the match lacks a filter or a constant node, or
             the constant's size does not match the output channel count.
     """
     filter_node = None
     constant_node = None
     for node in subgraph.nodes():
         if isinstance(node, FilterParameters):
             filter_node = node
         elif isinstance(node, ConstantInputParameters):
             constant_node = node
     if filter_node is None or constant_node is None:
         # Previously this crashed with AttributeError on None.
         raise DontReplaceError()
     flattened_constant = constant_node.value.flatten()
     # shape needs to match
     if flattened_constant.shape[0] != filter_node.filter.out_c:
         raise DontReplaceError()
     # Log only once the fusion is certain to go ahead; previously this was
     # also logged for matches that were then rejected.
     LOG.info("fusing bias in %s into %s", constant_node.name,
              filter_node.name)
     if filter_node.has_bias:
         assert filter_node.biases is not None, "can't absorb bias into filter. maybe weights are not loaded"
         filter_node.biases += flattened_constant
     else:
         filter_node.biases = flattened_constant
     if G.quantization:
         fnid = NodeId(filter_node)
         cnid = NodeId(constant_node)
         if fnid in G.quantization and cnid in G.quantization:
             G.quantization[fnid].biases_q = G.quantization[cnid].out_qs[0]
     return filter_node, None, None
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Remove QuantizeParameters nodes whose conversion can be absorbed.

        A quantize whose source dtype is floating (np.floating or bfloat16)
        is removed when ``propagate_up`` succeeds with its ``to_qtype``;
        float-to-float quantizes are kept, with a warning. Any other quantize
        is removed when ``propagate_down`` succeeds with its ``from_qtype``.
        Removed nodes are bypassed with ``remove_and_reconnect`` and their
        quantization record, if present, is dropped.

        Returns:
            True if at least one quantize node was removed.
        """
        modified_graph = False
        # Snapshot the candidates: matching nodes are removed inside the loop.
        for node in list(G.nodes(node_classes=QuantizeParameters)):
            if issubclass(node.from_qtype.dtype, (np.floating, bfloat16)):
                if issubclass(node.to_qtype.dtype, (np.floating, bfloat16)):
                    LOG.warning(
                        'node %s quantizes from floating type to floating type and cannot directly be removed',
                        node.name)
                    continue
                absorbed = self.propagate_up(G, node, node.to_qtype)
            else:
                absorbed = self.propagate_down(G, node, node.from_qtype)
            if not absorbed:
                LOG.warning('unable to remove quantize node %s', node.name)
                continue
            modified_graph = True
            G.remove_and_reconnect(node, edge_class=NNEdge)
            nid = NodeId(node)
            # The record may be absent; don't fail with a KeyError.
            if G.quantization and nid in G.quantization:
                del G.quantization[nid]

        if set_identity:
            self.set_identity(G)

        return modified_graph
Exemplo n.º 21
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Merge duplicate node chains that hang off the same output.

        Candidates are nodes with a single indexed output that has more than
        one out edge. For each candidate, ``self.explore`` returns lists of
        node chains ("strings"), presumably equivalent ones — see explore.
        The first chain is kept as the primary. Every other chain is deleted,
        along with its quantization records, and its consumers are
        reconnected to the end of the primary chain.

        Returns:
            True if any duplicate chain was removed.
        """
        modified_graph = False
        candidates = [node for node in G.nodes()
                      if len(G.indexed_out_edges(node.name)) == 1 and len(G.out_edges(node.name)) > 1]
        while candidates:
            node = candidates.pop(0)
            strings = self.explore(G, [node])
            if not strings:
                continue
            modified_graph = True
            primary = strings.pop(0)
            for pnode in primary:
                if pnode in candidates:
                    candidates.remove(pnode)
            out_edges = []
            for other in strings:
                # Capture the duplicate chain's consumers before deleting it.
                out_edges.extend(G.out_edges(other[-1].name))
                for other_node in other:
                    if other_node in candidates:
                        candidates.remove(other_node)
                    G.remove(other_node)
                    nid = NodeId(other_node)
                    if G.quantization and nid in G.quantization:
                        del G.quantization[nid]
                LOG.info(
                    f'removed duplicates from {primary[0].name} {",".join(node.name for node in other)}')
            pend = primary[-1]
            for edge in out_edges:
                # Keep the source output index: the edge used to leave the
                # matching output of the duplicate chain's last node.
                G.add_edge(
                    NNEdge(from_node=pend, to_node=edge.to_node,
                           from_idx=edge.from_idx, to_idx=edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return modified_graph
Exemplo n.º 22
0
    def match(self, G: GraphView, set_identity: bool = True):
        """Give each extra consumer of a shared constant its own copy.

        Every ConstantInputParameters node with more than one out edge keeps
        its first out edge. Each further edge is detached and fed instead
        from a new ConstantInputParameters node named ``<name>_<n>``
        (n = 1, 2, ...), holding a copy of the value.

        Returns:
            True if any constant was duplicated.
        """
        has_modified = False
        # Snapshot the constants first: the loop adds new
        # ConstantInputParameters nodes to G.
        for node in list(G.nodes(node_classes=ConstantInputParameters)):
            out_edges = G.out_edges(node.name)
            if len(out_edges) <= 1:
                continue
            has_modified = True
            LOG.info(
                'node %s has more than one out edge and will be duplicated',
                node.name)
            for idx, out_edge in enumerate(out_edges[1:], start=1):
                new_constant = ConstantInputParameters(
                    f'{node.name}_{idx}',
                    dims=Dim.unnamed(node.dims.shape),
                    value=node.value.copy())
                G.remove_edge(out_edge)
                G.add_edge(
                    NNEdge(from_node=new_constant,
                           to_node=out_edge.to_node,
                           to_idx=out_edge.to_idx))
                # NOTE(review): no quantization record is created for
                # new_constant; confirm G.quantization is rebuilt afterwards.

        if set_identity:
            self.set_identity(G)

        return has_modified
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Process multiplies by a constant 1/6 found in the graph.

        Each MatrixMulParameters node is selected when one of its inputs
        comes from a ConstantInputParameters node for which
        ``check_equals(G, ..., 1/6)`` holds. Each selected node is passed to
        ``look_back``. Every non-None record returned is handed to
        ``process_rec``. Before that, if the multiply's single consumer is a
        ReluActivationParameters node, the record gets that relu and its
        quantization record under the ``'relu3'`` key.

        Returns:
            True if any record was processed.
        """
        def has_sixth_constant(mul_node):
            # A list (not a generator) inside any(): every in edge is checked.
            return any([isinstance(edge.from_node, ConstantInputParameters)
                        and check_equals(G, edge.from_node, 1.0/6.0)
                        for edge in G.in_edges(mul_node.name)])

        const_ops = [op for op in G.nodes()
                     if isinstance(op, MatrixMulParameters)
                     and has_sixth_constant(op)]

        # All look_back calls happen before any record is processed.
        oprecs = []
        for op in const_ops:
            rec = look_back(G, op)
            if rec is not None:
                oprecs.append(rec)

        for oprec in oprecs:
            mul_out_edges = G.out_edges(oprec['mul'][0].name)
            if len(mul_out_edges) == 1:
                consumer = mul_out_edges[0].to_node
                if isinstance(consumer, ReluActivationParameters):
                    oprec['relu3'] = (consumer,
                                      G.quantization[NodeId(consumer)])
            process_rec(G, oprec)

        if set_identity:
            self.set_identity(G)

        return bool(oprecs)
Exemplo n.º 24
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Eliminate the unpack node after RNNs that output several cells.

        Candidates come from ``find_unpack`` on every RNNBaseParameters node
        with ``n_output_cells > 1``. They are filtered and grouped by unpack
        node via ``validate_slices`` and ``validate_multi_branch``. For each
        group, ``process_path`` runs on every RNN/unpack pair. The unpack node
        is then removed and its single input is wired straight to its
        consumers. If the unpack is a StridedSliceParameters that changes
        shape, a ReshapeParameters node is inserted instead and takes over
        the unpack's quantization record.

        Returns:
            False if no candidate survived validation, otherwise True.
        """
        rnn_nodes = [
            self.find_unpack(G, node) for node in G.nodes()
            if isinstance(node, RNNBaseParameters) and node.n_output_cells > 1
        ]
        rnn_nodes_by_slice = self.validate_slices(G, rnn_nodes)
        rnn_nodes_by_slice = self.validate_multi_branch(G, rnn_nodes_by_slice)
        if not rnn_nodes_by_slice:
            return False

        for unpack_node, rnn_unpacks in rnn_nodes_by_slice.items():
            modified_nodes = set()
            for rnn_unpack in rnn_unpacks:
                self.process_path(G, rnn_unpack, modified_nodes)
            # since process path will have removed all unnecessary nodes the edges will be correct here
            out_edges = G.out_edges(unpack_node.name)
            in_edges = G.in_edges(unpack_node.name)
            assert len(in_edges
                       ) == 1, "expecting unpack node to have only one in edge"
            in_edge = in_edges[0]
            changes_shape = unpack_node.changes_shape if isinstance(
                unpack_node, StridedSliceParameters) else False

            LOG.info("Eliminating last cell unpack: %s", unpack_node.name)
            G.remove(unpack_node)

            # Here the strided slice can change the output shape of the RNN
            # so insert a reshape to do the shape change
            if changes_shape:
                reshape = ReshapeParameters(
                    unpack_node.name + '_reshape',
                    old_shape=Dim.unnamed(unpack_node.post_slice_shape),
                    shape=Dim.unnamed(unpack_node.out_shape))
                G.add_edge(
                    NNEdge(from_node=in_edge.from_node,
                           to_node=reshape,
                           from_idx=in_edge.from_idx))
                for out_edge in out_edges:
                    G.add_edge(
                        NNEdge(from_node=reshape,
                               to_node=out_edge.to_node,
                               to_idx=out_edge.to_idx))
                if G.quantization:
                    # The reshape takes over the unpack's record (this used to
                    # reference the undefined name ``unpack`` -> NameError).
                    G.quantization[NodeId(reshape)] = G.quantization[NodeId(
                        unpack_node)]
            else:
                for out_edge in out_edges:
                    G.add_edge(
                        NNEdge(from_node=in_edge.from_node,
                               to_node=out_edge.to_node,
                               from_idx=in_edge.from_idx,
                               to_idx=out_edge.to_idx))
            if G.quantization:
                del G.quantization[NodeId(unpack_node)]

        if set_identity:
            self.set_identity(G)

        return True
Exemplo n.º 25
0
    def match(self, G: GraphView, set_identity: bool = True):
        """Fuse each Conv2D node with the activation/pool nodes that follow it.

        ``self.get_node_list`` supplies, for each Conv2DParameters node, the
        chain of nodes it may be fused with (``order``) plus its
        ``fusion_type``, ``conv``, ``pool`` and ``active`` members. Chains of
        fewer than two nodes are left alone. For a 'conv_active_pool' chain
        with an average pool, the pool is dropped from the chain so only the
        conv and activation are fused. A 'conv_pool_active' chain with an
        average pool and a non-"relu" activation is not fused at all.

        Each fused chain is replaced in ``G`` by one ConvFusionParameters node.
        That node takes over the conv's in edges and the last chain node's out
        edges. When ``G.quantization`` is set, the chain's records are moved
        into the fusion and a combined record is stored for it.

        Args:
            G: graph rewritten in place.
            set_identity: when true, ``self.set_identity(G)`` is called at the end.

        Returns:
            True if at least one chain was fused.
        """
        fused_any = False
        # Collect the convs up front: the loop below rewires G.
        conv_nodes = [node for node in G.nodes() if isinstance(node, Conv2DParameters)]
        for conv_node in conv_nodes:
            chain = self.get_node_list(G, conv_node)
            if chain is None or len(chain.order) < 2:
                continue
            if chain.fusion_type == 'conv_active_pool':
                if chain.pool.pool_type == "average":
                    # Fuse conv + activation only; the average pool is not in
                    # the chain any more, so it stays in G after the fusion.
                    chain.order = chain.order[:2:]
                    chain.pool = None
            elif chain.fusion_type == 'conv_pool_active':
                if chain.pool.pool_type == "average" and chain.active.activation != "relu":
                    continue

            LOG.info("fusing nodes %s", ",".join((node.name for node in chain.order)))
            fused_any = True

            # Link the chain members one after another inside the fusion.
            subgraph = GraphView()
            for upstream, downstream in zip(chain.order, chain.order[1:]):
                subgraph.add_edge(NNEdge(from_node=upstream, to_node=downstream))
            tail = chain.order[-1]

            conv = chain.conv
            pnode = ConvFusionParameters(
                conv.name + '_fusion',
                fusion_type=chain.fusion_type,
                subgraph=subgraph,
                in_dims_hint=conv.in_dims_hint,
                out_dims_hint=conv.out_dims_hint,
                input_mapping=[[(conv, idx)] for idx in range(3)],
                output_mapping=[(tail, 0)])

            if G.quantization:
                qrecs = G.quantization.get_all(pnode.contained_nodes())
                if qrecs:
                    head_rec = qrecs[0]
                    fused_rec = None
                    # The first matching record family decides the type of the
                    # combined record (same precedence as an if/elif chain).
                    for rec_kinds, rec_class in (
                            ((SymmetricQuantizationRecord, SymmetricScalableFilterQuantizationRecord),
                             SymmetricQuantizationRecord),
                            ((MultQuantizationRecord, MultScalableFilterQuantizationRecord),
                             MultQuantizationRecord),
                            ((Float32QuantizationRecord, Float32ScalableFilterQuantizationRecord),
                             Float32QuantizationRecord)):
                        if isinstance(head_rec, rec_kinds):
                            fused_rec = rec_class(in_qs=head_rec.in_qs, out_qs=qrecs[-1].out_qs)
                            break
                    for node in pnode.contained_nodes():
                        G.quantization.move_to_fusion(node, pnode)
                    # NOTE(review): stays None when the first record matches
                    # none of the families above.
                    G.quantization[NodeId(pnode)] = fused_rec

            # Capture the boundary edges before the chain nodes leave the graph.
            in_edges = G.in_edges(conv.name)
            out_edges = G.out_edges(tail.name)
            for node in chain.order:
                G.remove(node)
            for edge in in_edges:
                G.add_edge(NNEdge(edge.from_node, pnode, from_idx=edge.from_idx, to_idx=edge.to_idx))
            for edge in out_edges:
                G.add_edge(NNEdge(pnode, edge.to_node, from_idx=edge.from_idx, to_idx=edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return fused_any
Exemplo n.º 26
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Cut convolution filters that are larger than their input down to size.

        Every Conv2DParameters node is examined, including the first contained
        node of a ConvFusionParameters. Filters that fit inside the first
        input are ignored. When the filter is larger in some direction:

        * If both ``min(filter.h, input.h)`` and ``min(filter.w, input.w)``
          exceed 1, a warning is logged and the filter is left unchanged.
        * Otherwise each direction whose minimum is 1 gets a kernel size of 1.
          The weights constant on input 1 of the graph node is sliced to the
          single row/column at offset ``padding.t`` / ``padding.l``. Other
          directions keep their full extent.

        Returns:
            True if any filter was changed.
        """
        changed = False
        graph_nodes = [node for node in G.nodes()
                       if isinstance(node, (Conv2DParameters, ConvFusionParameters))]
        for graph_node in graph_nodes:
            # A fusion carries its conv as the first contained node, but the
            # weights are wired to the fusion node itself (see below).
            if isinstance(graph_node, ConvFusionParameters):
                conv = graph_node.contained_nodes()[0]
            else:
                conv = graph_node
            if not isinstance(conv, Conv2DParameters):
                continue

            first_in = conv.in_dims[0]
            filt = conv.filter
            if filt.h <= first_in.h and filt.w <= first_in.w:
                continue

            min_h = min(filt.h, first_in.h)
            min_w = min(filt.w, first_in.w)
            if min_h > 1 and min_w > 1:
                LOG.warning("Filter of %s [%dx%d] bigger than input [%dx%d] not optimal but will work on AT",
                            conv.name, filt.h, filt.w, first_in.h, first_in.w)
                continue

            ker_h = filt.h if min_h != 1 else 1
            ker_w = filt.w if min_w != 1 else 1
            if ker_h == filt.h and ker_w == filt.w:
                continue

            new_filt = Conv2DFilterDim(ker_h, ker_w, filt.out_c, in_c=filt.in_c)
            LOG.warning("Converting filter of %s from [%dx%d] -> [%dx%d]",
                        conv.name, filt.h, filt.w, new_filt.h, new_filt.w)
            conv.filter = new_filt

            # Build one index per axis of the *old* filter layout. A collapsed
            # axis keeps the weight slice that lines up with the padding offset.
            selection = []
            for axis in filt.order:
                if axis == 'h':
                    selection.append(slice(conv.padding.t, conv.padding.t + 1)
                                     if new_filt.h == 1 else slice(0, new_filt.h))
                elif axis == 'w':
                    selection.append(slice(conv.padding.l, conv.padding.l + 1)
                                     if new_filt.w == 1 else slice(0, new_filt.w))
                elif axis in ('out_c', 'in_c'):
                    selection.append(slice(None))

            weights = G.indexed_in_edges(graph_node.name)[1].from_node
            weights.value = weights.value[tuple(selection)]
            weights.dims = Dim.unnamed(weights.value.shape)
            changed = True

        if set_identity:
            self.set_identity(G)

        return changed
Exemplo n.º 27
0
 def match(self, G: GraphView, set_identity: bool = True):
     """Turn pending ``transpose_in``/``transpose_out`` settings into nodes.

     For every Transposable node that is not itself a TransposeParameters:

     * Each non-None ``transpose_in`` entry, indexed by the in edge's
       ``to_idx``, becomes a TransposeParameters inserted on that in edge.
     * Each non-None ``transpose_out`` entry, indexed by ``from_idx``,
       becomes a TransposeParameters inserted on every out edge leaving that
       output.

     Where a dims hint exists, the new transpose takes the node's hint and the
     node's hint is replaced with the reverse-transposed one. Quantization, if
     any, is copied from the node to each new transpose. The node's
     transpose_in / transpose_out is then cleared.

     Returns:
         True if at least one transpose node was inserted.
     """
     # get a list of all the nodes that are transposable but not transposes
     # Need to do this first to avoid mutating it when doing the modifications
     tnodes = list(filter(lambda n: isinstance(n, Transposable) and
                          not isinstance(n, TransposeParameters),
                          G.nodes()))
     has_modified_graph = False
     for node in tnodes:
         if node.transpose_in:
             for idx, edge in enumerate(G.in_edges(node.name)):
                 if edge.to_idx >= len(node.transpose_in):
                     continue
                 trans = node.transpose_in[edge.to_idx]
                 if trans is None:
                     continue
                 has_modified_graph = True
                 in_params = TransposeParameters("%s_TIN_%s" % (node.name, idx),
                                                 transpose=trans)
                 if node.in_dims_hint and node.in_dims_hint[edge.to_idx]:
                     in_hint = node.in_dims_hint[edge.to_idx]
                     out_hint = apply_reverse_transpose_to_hint(in_hint, trans)
                     in_params.in_dims_hint = [in_hint.copy()]
                     in_params.out_dims_hint = [out_hint.copy()]
                     node.in_dims_hint[edge.to_idx] = out_hint
                 if G.quantization:
                     G.quantization.copy_to_node(node, in_params)
                 G.insert_node(in_params, edge.from_node.name, edge.to_node.name,
                               from_idx=edge.from_idx, to_idx=edge.to_idx,
                               edge_class=NNEdge)
             node.transpose_in = None
         if node.transpose_out:
             # Read hints from a snapshot. The loop overwrites
             # node.out_dims_hint[from_idx]. Without the snapshot, an output
             # feeding several out edges would have the reverse transpose
             # applied again for every edge after the first.
             orig_out_hints = list(node.out_dims_hint) if node.out_dims_hint else None
             for idx, edge in enumerate(G.out_edges(node.name)):
                 if edge.from_idx >= len(node.transpose_out):
                     continue
                 trans = node.transpose_out[edge.from_idx]
                 if trans is None:
                     continue
                 has_modified_graph = True
                 out_params = TransposeParameters("%s_TOUT_%s" % (node.name, idx),
                                                  transpose=trans)
                 # Like the transpose_in branch, skip hint handling when this
                 # output carries no hint instead of calling .copy() on None.
                 if orig_out_hints and orig_out_hints[edge.from_idx] is not None:
                     out_hint = orig_out_hints[edge.from_idx]
                     in_hint = apply_reverse_transpose_to_hint(out_hint, trans)
                     out_params.in_dims_hint = [in_hint.copy()]
                     out_params.out_dims_hint = [out_hint.copy()]
                     node.out_dims_hint[edge.from_idx] = in_hint
                 if G.quantization:
                     G.quantization.copy_to_node(node, out_params)
                 G.insert_node(out_params, edge.from_node.name, edge.to_node.name,
                               from_idx=edge.from_idx, to_idx=edge.to_idx,
                               edge_class=NNEdge)
             node.transpose_out = None
     if set_identity:
         self.set_identity(G)
     return has_modified_graph
Exemplo n.º 28
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Remove reshapes that do nothing or that a later reshape makes redundant.

        Two passes are repeated until an iteration changes nothing:

        1. Every ReshapeParameters with no transpose whose ``shape`` equals
           its ``old_shape`` is removed, and its neighbours are reconnected.
        2. ``self.validate_reshape`` is consulted for each reshape. On the
           first truthy result ``(reshape, candidates, out_shape)``:

           * the candidate nodes are removed;
           * their consumers are connected straight to ``reshape``;
           * ``reshape.shape`` becomes ``out_shape``.

           Only one such rewrite is done per iteration because the graph has
           just changed.

        Quantization records of removed nodes are deleted.

        Returns:
            True if the graph was modified.
        """
        has_modified_graph = False
        modified_graph = True
        while modified_graph:
            modified_graph = False
            # Iterate over a snapshot: reshapes are removed inside this loop.
            for reshape in list(G.nodes(node_classes=(ReshapeParameters, ))):
                if not reshape.has_transpose and reshape.shape.shape == reshape.old_shape.shape:
                    modified_graph = True
                    LOG.info('removing reshape that does nothing %s',
                             reshape.name)
                    G.remove_and_reconnect(reshape, edge_class=NNEdge)
                    nid = NodeId(reshape)
                    if G.quantization and nid in G.quantization:
                        del G.quantization[nid]
            for reshape in G.nodes(node_classes=(ReshapeParameters, )):
                res = self.validate_reshape(G, reshape)
                if res:
                    LOG.info('unnecessary reshape found after %s',
                             reshape.name)
                    modified_graph = True
                    (reshape, candidates, out_shape) = res
                    for candidate in candidates:
                        LOG.info(
                            'removing unnecessary reshape or transpose %s',
                            candidate.name)
                        edges = G.out_edges(candidate.name)
                        G.remove(candidate)
                        nid = NodeId(candidate)
                        if G.quantization and nid in G.quantization:
                            del G.quantization[nid]
                        for edge in edges:
                            G.add_edge(
                                NNEdge(from_node=reshape,
                                       to_node=edge.to_node,
                                       to_idx=edge.to_idx))
                    reshape.shape = Dim.unnamed(out_shape)
                    # The graph changed under the iterator; start a new pass.
                    break
            # modified_graph is always False once the while loop exits, so
            # record separately whether any iteration changed the graph.
            has_modified_graph = has_modified_graph or modified_graph

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
Exemplo n.º 29
0
def split_down_from(cur_g, node, res_g=None):
    """Split ``cur_g`` into 2 graphs: everything from ``node`` down, and the rest.

    ``node`` and every node reachable from it along out edges are removed
    from ``cur_g``. ``node`` and clones of all out edges of the removed nodes
    are added to ``res_g``, which is returned. When ``res_g`` is not given, a
    new GraphView is created. Edges entering the removed nodes from nodes
    that stay in ``cur_g`` are not copied to ``res_g``.

    The whole sub-graph is walked before anything is removed. This means:

    * an edge into a node already reached through another path (e.g. the
      two arms of a diamond rejoining) is still copied;
    * a node reached twice is visited and removed only once.
    """
    if res_g is None:
        res_g = GraphView()
    reached = [node]
    seen = {node.name}
    edges = []
    # Explicit stack of edge iterators: depth-first in the same order as a
    # recursive walk, without hitting Python's recursion limit on deep graphs.
    pending = [iter(cur_g.out_edges(node.name))]
    while pending:
        edge = next(pending[-1], None)
        if edge is None:
            pending.pop()
            continue
        edges.append(edge)
        child = edge.to_node
        if child.name not in seen:
            seen.add(child.name)
            reached.append(child)
            pending.append(iter(cur_g.out_edges(child.name)))
    # Mutate only after the walk so no out edge is lost to an earlier removal.
    for reached_node in reached:
        cur_g.remove(reached_node)
    if node not in res_g.nodes():
        res_g.add_node(node)
    for edge in edges:
        res_g.add_edge(edge.clone())
    return res_g
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Fuse each fully connected layer with the nodes that follow it.

        ``kwargs['group_identity']`` chooses the activations allowed in a
        fusion. 'pow2_match_group' selects VALID_ACTIVATIONS_POW2; anything
        else selects VALID_ACTIVATIONS_SQ8.

        For every FcParameters node, ``self.get_node_list`` supplies the
        chain to fuse (``order``, ``linear``, ``fusion_type``); chains of
        fewer than two nodes are ignored. Each chain is replaced in ``G`` by
        one LinearFusionParameters node that takes over the linear node's in
        edges and the last chain node's out edges. When ``G.quantization`` is
        set, the chain's records are moved into the fusion. A record of the
        first record's kind, spanning first-in to last-out, is stored for it.

        Returns:
            True if at least one chain was fused.
        """
        if kwargs.get('group_identity') == 'pow2_match_group':
            allowed = VALID_ACTIVATIONS_POW2
        else:
            allowed = VALID_ACTIVATIONS_SQ8

        fused_any = False
        # Snapshot the FC nodes first because the graph is rewired below.
        linear_nodes = [node for node in G.nodes() if isinstance(node, FcParameters)]
        for linear_node in linear_nodes:
            chain = self.get_node_list(G, linear_node, allowed)
            if chain is None or len(chain.order) < 2:
                continue
            LOG.info("fusing nodes %s", ",".join(
                (node.name for node in chain.order)))
            fused_any = True

            # Chain the members one after another inside the fusion.
            subgraph = GraphView()
            for upstream, downstream in zip(chain.order, chain.order[1:]):
                subgraph.add_edge(
                    NNEdge(from_node=upstream, to_node=downstream))
            tail = chain.order[-1]

            linear = chain.linear
            pnode = LinearFusionParameters(
                linear.name + '_fusion',
                fusion_type=chain.fusion_type,
                subgraph=subgraph,
                input_mapping=[[(linear, idx)] for idx in range(3)],
                output_mapping=[(tail, 0)])

            if G.quantization:
                # TODO - stats
                qrecs = G.quantization.get_all(pnode.contained_nodes())
                if qrecs:
                    head_rec = qrecs[0]
                    fused_rec = QRec.copy_ktype(
                        head_rec, in_qs=head_rec.in_qs, out_qs=qrecs[-1].out_qs)
                    for node in pnode.contained_nodes():
                        G.quantization.move_to_fusion(node, pnode)
                    G.quantization[NodeId(pnode)] = fused_rec

            # Capture the boundary edges before the chain nodes are removed.
            in_edges = G.in_edges(linear.name)
            out_edges = G.out_edges(tail.name)
            for node in chain.order:
                G.remove(node)
            for edge in in_edges:
                G.add_edge(NNEdge(edge.from_node, pnode,
                                  from_idx=edge.from_idx, to_idx=edge.to_idx))
            for edge in out_edges:
                G.add_edge(NNEdge(pnode, edge.to_node,
                                  from_idx=edge.from_idx, to_idx=edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return fused_any