Esempio n. 1
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Replace every single-node MatScale match with a MatScaleFusionParameters.

        For each match, the matched node is removed from ``G`` and a fusion
        node named ``<node>_fusion`` wrapping the matched fragment is added.
        The edges that fed the match (taken from ``match['inputs']``) are
        re-targeted to the fusion node at input indexes 0, 1, ... in match
        order. The matched node's output edges are re-sourced from the fusion
        node.

        Returns:
            True if at least one node was fused, False otherwise.
        """
        matcher = GraphMatcher(
            match_function=lambda state, frag: (frag, state['match']))
        matcher.add_node(MatScaleNodeMatch())
        modified = False
        for frag, match in matcher.match_graph(G):
            # capture the feeding edges before the matched node disappears
            inputs = [G.indexed_in_edges(in_node.name)[in_idx]
                      for in_node, in_idx in match['inputs']]
            matched = list(frag.nodes())[0]
            outputs = G.out_edges(matched.name)
            modified = True
            G.remove(matched)
            fusion = MatScaleFusionParameters(
                "{}_fusion".format(matched.name),
                fusion_type=match['type'],
                subgraph=frag,
                input_mapping=[[(matched, 0)], [(matched, 1)]])
            G.add_node(fusion)
            # the existing edge objects are reused, only their endpoints change
            for position, edge in enumerate(inputs):
                edge.to_node = fusion
                edge.to_idx = position
                G.add_edge(edge)
            for edge in outputs:
                edge.from_node = fusion
                G.add_edge(edge)

        if set_identity:
            self.set_identity(G)

        return modified
def reverse_matmul(G: GraphView, params):
    """Swap the two operands of ``params`` and wrap it in (1, 0) transposes.

    Implements A.B = (B^T.A^T)^T on the graph. The first two input edges of
    ``params`` are removed and re-added with their input indexes exchanged.
    If the graph holds a quantization record for ``params``, its first two
    input qtypes are swapped to match. A (1, 0) transpose is then inserted
    in front of each of the two inputs and another after the output.

    Returns:
        ``(in_nodes, tout_params)``: the list of the two input
        TransposeParameters, and the output TransposeParameters.
    """
    operand_edges = G.indexed_in_edges(params.name)[0:2]
    for edge in operand_edges:
        G.remove_edge(edge)
    # first operand goes to input 1, second operand to input 0
    for new_to_idx, edge in zip((1, 0), operand_edges):
        G.add_edge(NNEdge(from_node=edge.from_node,
                          to_node=params,
                          from_idx=edge.from_idx,
                          to_idx=new_to_idx))

    nid = NodeId(params)
    if G.quantization and nid in G.quantization:
        qrec = G.quantization[nid]
        # keep the input qtypes aligned with the swapped operands
        qrec.in_qs[0], qrec.in_qs[1] = qrec.in_qs[1], qrec.in_qs[0]

    transposes_in = []
    for in_idx in range(2):
        transpose_in = TransposeParameters(
            G.unique_name(f"{params.name}_tin{in_idx}"), transpose=(1, 0))
        transposes_in.append(transpose_in)
        G.insert_node_before(transpose_in, params, to_idx=in_idx,
                             edge_class=NNEdge)
    transpose_out = TransposeParameters(
        G.unique_name(f"{params.name}_tout"), transpose=(1, 0))
    G.insert_node_after(params, transpose_out)
    return transposes_in, transpose_out
Esempio n. 3
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Merge duplicated node chains that hang off a common producer.

        Candidates are nodes with exactly one indexed output that fans out to
        more than one edge. For each candidate, ``self.explore`` returns a list
        of duplicated chains. The first chain is kept. Every other chain is
        removed from ``G``, together with its quantization record if there is
        one. The consumers of each removed chain's last node are reconnected to
        the last node of the kept chain.

        Returns:
            True if any chain was removed.
        """
        modified_graph = False
        candidates = [
            node for node in G.nodes()
            if len(G.indexed_out_edges(node.name)) == 1
            and len(G.out_edges(node.name)) > 1
        ]
        while candidates:
            strings = self.explore(G, [candidates.pop(0)])
            if not strings:
                continue
            modified_graph = True
            primary, *duplicates = strings
            for kept in primary:
                if kept in candidates:
                    candidates.remove(kept)
            rewired = []
            for chain in duplicates:
                # consumers of the duplicate chain are moved to the kept chain
                rewired.extend(G.out_edges(chain[-1].name))
                for dup in chain:
                    if dup in candidates:
                        candidates.remove(dup)
                    G.remove(dup)
                    dup_id = NodeId(dup)
                    if G.quantization and dup_id in G.quantization:
                        del G.quantization[dup_id]
                LOG.info(
                    f'removed duplicates from {primary[0].name} {",".join(dup.name for dup in chain)}')
            last_kept = primary[-1]
            for edge in rewired:
                G.add_edge(NNEdge(from_node=last_kept,
                                  to_node=edge.to_node,
                                  to_idx=edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return modified_graph
Esempio n. 4
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Absorb a transpose on the second input of a matmul into the matmul.

        Handles every node returned by ``G.nodes(node_classes=MatMulOpParameters)``.
        MatMulTransposedParameters is presumably a subclass, since it is
        handled inside the same loop; TODO confirm. If input 1 of such a node
        is driven by a TransposeParameters node, the transpose is removed. The
        matmul is then replaced by its counterpart (MatMulOpParameters <->
        MatMulTransposedParameters) fed directly from the transpose's producer.

        The transpose is only absorbed when this matmul is its sole consumer.
        Removing a transpose that also feeds other nodes would silently
        disconnect them.

        Returns:
            True if any matmul was rewritten.
        """
        has_modified_graph = False

        # materialise the list: nodes are replaced/removed inside the loop
        for node in list(G.nodes(node_classes=MatMulOpParameters)):
            in_edges = G.indexed_in_edges(node.name)
            if len(in_edges) < 2 or in_edges[1] is None:
                continue
            trans_node = in_edges[1].from_node
            if not isinstance(trans_node, TransposeParameters):
                continue
            # the transpose is deleted below; only do that when this matmul
            # is the only thing reading it
            if len(G.out_edges(trans_node.name)) != 1:
                continue
            if isinstance(node, MatMulTransposedParameters):
                new_node = MatMulOpParameters(node.name)
            else:
                new_node = MatMulTransposedParameters(node.name)

            in_trans_edge = G.indexed_in_edges(trans_node.name)[0]
            G.replace_node(node.name, new_node)
            G.remove(trans_node)
            G.add_edge(
                NNEdge(in_trans_edge.from_node,
                       new_node,
                       from_idx=in_trans_edge.from_idx,
                       to_idx=1))
            has_modified_graph = True

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
Esempio n. 5
0
 def match(self, G: GraphView, set_identity: bool = True):
     """Insert copy nodes between split outputs and the concats they feed.

     ``self.find_split_concat`` returns, per split output index, a bundle of
     edge sets, or ``None`` when the split is not a candidate. For each
     non-empty bundle a CopyParameters node named ``<split>_copy_<idx>`` is
     inserted:

     * the first edge of every edge set is removed and re-created from the
       copy node;
     * the copy node is fed from split output ``idx``.

     When the graph is quantized, the copy node gets a record of the split
     record's class. Its single input qtype and single output qtype are deep
     copies of the split's output qtype ``idx``, each wrapped in a list like
     every other ``in_qs``/``out_qs``.

     ``set_identity`` is accepted for interface compatibility and is not used.

     Returns:
         True if any copy node was inserted.
     """
     split_nodes = [
         node for node in G.nodes() if isinstance(node, SplitParameters)
     ]
     has_modified_graph = False
     for node in split_nodes:
         # traverse reshapes or transposes that do nothing - check gen
         # find edges connected to concats
         res = self.find_split_concat(G, node)
         if res is None:
             continue
         # TODO(martin) - group edges that have adjacent inputs and outputs
         qrec = G.quantization[NodeId(node)] if G.quantization else None
         for idx, bundle in enumerate(res):
             if not bundle:
                 continue
             has_modified_graph = True
             copy_node = CopyParameters("%s_copy_%s" % (node.name, idx))
             for edge_set in bundle:
                 first_edge = edge_set[0]
                 G.remove_edge(first_edge)
                 G.add_edge(
                     NNEdge(copy_node,
                            first_edge.to_node,
                            to_idx=first_edge.to_idx))
             G.add_edge(NNEdge(node, copy_node, from_idx=idx))
             if qrec is not None:
                 # in_qs/out_qs are lists of qtypes (one per input/output);
                 # the copy has one of each, matching split output idx
                 split_out_q = qrec.out_qs[idx]
                 G.quantization[NodeId(copy_node)] = qrec.__class__(
                     in_qs=[deepcopy(split_out_q)],
                     out_qs=[deepcopy(split_out_q)])
     return has_modified_graph
    def _match(self,
               G: GraphView,
               set_identity: bool = True,
               **kwargs) -> bool:
        """Bypass and delete every node for which ``self.node_does_nothing`` holds.

        The candidate nodes are collected up front. For each one:

        * its first input edge is removed;
        * each output edge is replaced by an edge from that input's producer
          (same ``from_idx``) straight to the consumer (same ``to_idx``);
        * the node itself is removed from ``G``.

        Returns:
            True if any node was removed.
        """
        removable = [node for node in G.nodes()
                     if self.node_does_nothing(G, node)]
        for node in removable:
            source = G.in_edges(node.name)[0]
            G.remove_edge(source)
            for sink in G.out_edges(node.name):
                G.remove_edge(sink)
                G.add_edge(NNEdge(source.from_node,
                                  sink.to_node,
                                  from_idx=source.from_idx,
                                  to_idx=sink.to_idx))
            LOG.info(f'removing {node.name} that does nothing')
            G.remove(node)

        if set_identity:
            self.set_identity(G)

        return bool(removable)
Esempio n. 7
0
    def match(self, G: GraphView, set_identity: bool = True):
        """Give each consumer of a shared constant its own ConstantInputParameters.

        For each ConstantInputParameters node with more than one output edge,
        the first output edge is left in place. Every further edge is moved
        onto a new ConstantInputParameters named ``<name>_<n>`` (n = 1, 2, ...)
        that has the same shape and a copy of the value.

        Returns:
            True if any constant was duplicated.
        """
        has_modified = False
        for node in G.nodes(node_classes=ConstantInputParameters):
            out_edges = G.out_edges(node.name)
            if len(out_edges) < 2:
                continue
            has_modified = True
            LOG.info(
                'node %s has more than one out edge and will be duplicated',
                node.name)
            # NOTE(review): '<name>_<n>' is not checked against existing node
            # names - confirm a clash is impossible here
            for copy_idx, out_edge in enumerate(out_edges[1:], start=1):
                duplicate = ConstantInputParameters(
                    f'{node.name}_{copy_idx}',
                    dims=Dim.unnamed(node.dims.shape),
                    value=node.value.copy())
                G.remove_edge(out_edge)
                G.add_edge(NNEdge(from_node=duplicate,
                                  to_node=out_edge.to_node,
                                  to_idx=out_edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return has_modified
Esempio n. 8
0
 def match_function(self, G: GraphView):
     """Match a PadParameters node feeding a FilterLikeParameters node.

     The filter must also satisfy ``self.has_no_padding``. Returns whatever
     ``G.match_fragment`` returns for the two-node pattern pad -> filter.
     """
     def is_pad(node):
         return isinstance(node, PadParameters)

     def is_unpadded_filter(node):
         return (isinstance(node, FilterLikeParameters)
                 and self.has_no_padding(node))

     pattern = GraphView()
     pattern.add_node(MatchNode('0', matcher=is_pad))
     pattern.add_node(MatchNode('1', matcher=is_unpadded_filter))
     pattern.add_edge(Edge('0', '1'))
     return G.match_fragment(pattern)
Esempio n. 9
0
 def match_function(self, G: GraphView):
     """Match a linear layer directly followed by an activation.

     Node '0' must be an FcParameters accepted by ``self.valid_linear``.
     Node '1' must be an ActivationParameters accepted by
     ``self.valid_activation``. Returns whatever ``G.match_fragment`` returns
     for the edge '0' -> '1'.
     """
     def linear_ok(node):
         return isinstance(node, FcParameters) and self.valid_linear(node)

     def activation_ok(node):
         return (isinstance(node, ActivationParameters)
                 and self.valid_activation(node))

     pattern = GraphView()
     pattern.add_node(MatchNode('0', matcher=linear_ok))
     pattern.add_node(MatchNode('1', matcher=activation_ok))
     pattern.add_edge(Edge('0', '1'))
     return G.match_fragment(pattern)
Esempio n. 10
0
def split_down_from(cur_g, node, res_g=None):
    """Split cur_g into 2 graphs: everything from ``node`` down, and the rest.

    ``node`` and, recursively, every node reached through its output edges
    are removed from ``cur_g`` (mutated in place). They are collected, with
    clones of the traversed edges, into ``res_g`` (a new GraphView when not
    given), which is returned.

    NOTE(review): a descendant reachable along two paths is visited twice. If
    ``GraphView.remove`` also drops incident edges, the second path's edge
    into it can be missing from the result - confirm this is intended.
    """
    target = GraphView() if res_g is None else res_g
    downstream = cur_g.out_edges(node.name)
    cur_g.remove(node)
    if node not in target.nodes():
        target.add_node(node)
    for edge in downstream:
        target.add_edge(edge.clone())
        split_down_from(cur_g, edge.to_node, res_g=target)
    return target
Esempio n. 11
0
    def match_function(self, G: GraphView):
        """Match filter -> matrix add <- constant.

        The pattern is a FilterParameters node feeding input 0 of a
        MatrixAddParameters node, whose input 1 is fed by a
        ConstantInputParameters node. Returns whatever ``G.match_fragment``
        returns for that pattern.
        """
        def is_instance_of(node_cls):
            return lambda node: isinstance(node, node_cls)

        pattern = GraphView()
        for node_id, node_cls in (('0', FilterParameters),
                                  ('1', MatrixAddParameters),
                                  ('2', ConstantInputParameters)):
            pattern.add_node(MatchNode(node_id,
                                       matcher=is_instance_of(node_cls)))
        pattern.add_edge(Edge('0', '1', to_idx=0))
        pattern.add_edge(Edge('2', '1', to_idx=1))

        return G.match_fragment(pattern)
Esempio n. 12
0
def construct_subgraph(G, nodes):
    """Construct a subgraph of ``G`` from ``nodes`` and the edges between them.

    Every node in ``nodes`` is added to a new GraphView, together with a clone
    of each edge of ``G`` whose two endpoints are both in ``nodes``.
    ``nodes`` may be any iterable. It is copied first, so the caller's list
    is no longer emptied as a side effect.

    Returns:
        The new GraphView.
    """
    sub_g = GraphView()
    # private work list: the loop consumes it, and consuming the caller's
    # list would silently empty it
    remaining = list(nodes)
    while remaining:
        node = remaining.pop(0)
        if node not in sub_g.nodes():
            sub_g.add_node(node)
        # an internal edge is cloned when the first of its endpoints is
        # processed, because the other endpoint is then still in `remaining`
        # (assuming no node is listed twice)
        for edge in G.out_edges(node.name):
            if edge.to_node in remaining:
                sub_g.add_edge(edge.clone())
        for edge in G.in_edges(node.name):
            if edge.from_node in remaining:
                sub_g.add_edge(edge.clone())
    return sub_g
Esempio n. 13
0
    def match(self, G: GraphView, set_identity: bool = True):
        """Replace each node set from ``self.find_sets`` with an ExpressionFusionParameters.

        For every set, ``group_edges`` splits the edges into external inputs,
        external outputs and internal edges. The internal edges form the fused
        subgraph. The fragment is swapped for a fusion node named
        ``expr_<n>`` via ``G.replace_fragment``; ``n`` is ``self._expr_num``,
        which is incremented after each replacement.

        Returns:
            True if at least one set was fused.
        """
        has_modified_graph = False
        for node_set in self.find_sets(G):
            has_modified_graph = True
            # in_edges / out_edges are dict-like. Their keys unpack as
            # (from_node, from_idx) and their values are sets of edges (see
            # the dims lists and set.union below).
            in_edges, out_edges, internal_edges = group_edges(G, node_set)
            frag = GraphView()
            for edge in internal_edges:
                frag.add_edge(edge)
            # in_mapping, in_dims, constant_inputs and in_edge_mapping all
            # iterate in_edges, so they share the same input ordering
            in_mapping = [[(edge.to_node, edge.to_idx) for edge in edge_group]
                          for edge_group in in_edges.values()]
            in_dims = [
                from_node.out_dims[from_idx]
                for from_node, from_idx in in_edges
            ]
            out_dims = [
                from_node.out_dims[from_idx]
                for from_node, from_idx in out_edges
            ]
            out_mapping = list(out_edges.keys())
            # producer nodes of inputs that are constants
            constant_inputs = [
                node_edge_idx[0] for node_edge_idx in in_edges
                if isinstance(node_edge_idx[0], ConstantInputParameters)
            ]
            LOG.info('matched expression - creating expression %s',
                     self._expr_num)
            expr = ExpressionFusionParameters(f"expr_{self._expr_num}",
                                              subgraph=frag,
                                              input_mapping=in_mapping,
                                              output_mapping=out_mapping,
                                              in_dims=in_dims,
                                              out_dims=out_dims,
                                              constant_inputs=constant_inputs)
            in_edge_mapping = list(in_edges.keys())
            out_edge_mapping = [[(edge.to_node, edge.to_idx)
                                 for edge in edge_set]
                                for edge_set in out_edges.values()]

            # NOTE(review): set.union(*...) raises TypeError when in_edges or
            # out_edges is empty - presumably find_sets never yields such a
            # set; confirm.
            G.replace_fragment(
                frag,
                expr,
                frag_in_edges=list(set.union(*in_edges.values())),
                frag_out_edges=list(set.union(*out_edges.values())),
                edge_in_mapping=in_edge_mapping,
                edge_out_mapping=out_edge_mapping)
            self._expr_num += 1

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
Esempio n. 14
0
    def _match(self,
               G: GraphView,
               set_identity: bool = True,
               **kwargs) -> bool:
        """Remove concat -> split pairs that exactly undo each other.

        A pair is removed when all of the following hold:

        * the split has exactly one input, driven by a concat;
        * the concat's only output goes to that split;
        * neither node transposes;
        * both nodes use the same axis;
        * the split produces one output per concat input, with matching sizes
          along that axis.

        The producer of each concat input ``i`` is then connected directly to
        every consumer of split output ``i``.

        Returns:
            True if any pair was removed.
        """
        has_modified_graph = False
        # a plain list keeps G's node order; iterating a set of nodes gives
        # a hash-dependent order that can vary between runs
        split_nodes = [node for node in G.nodes()
                       if isinstance(node, SplitParameters)]
        for split_node in split_nodes:
            in_edges = G.in_edges(split_node.name)
            # also skips a split with no input, which would make in_edges[0]
            # raise IndexError
            if len(in_edges) != 1:
                continue
            in_edge = in_edges[0]
            if not isinstance(in_edge.from_node, ConcatParameters):
                continue
            concat_node = in_edge.from_node
            if len(G.out_edges(concat_node.name)) > 1:
                continue
            if concat_node.transpose_out or split_node.transpose_in:
                continue
            if concat_node.axis != split_node.axis:
                continue
            axis = concat_node.axis
            split_out_sizes = [
                out_shape[axis] for out_shape in split_node.out_shapes
            ]
            if len(split_out_sizes) != len(concat_node.in_dims):
                continue
            if not all(split_out_sizes[idx] == in_dim.shape[axis]
                       for idx, in_dim in enumerate(concat_node.in_dims)):
                continue
            has_modified_graph = True
            LOG.info("removing unnecessary concat/split pair %s/%s",
                     concat_node.name, split_node.name)
            concat_in_edges = G.indexed_in_edges(concat_node.name)
            # indexed_out_edges is indexed by output; each entry is iterated
            # as the edges leaving that output
            split_out_edges = G.indexed_out_edges(split_node.name)
            G.remove(split_node)
            G.remove(concat_node)
            for idx, concat_in_edge in enumerate(concat_in_edges):
                for out_edge in split_out_edges[idx]:
                    G.add_edge(
                        NNEdge(from_node=concat_in_edge.from_node,
                               from_idx=concat_in_edge.from_idx,
                               to_node=out_edge.to_node,
                               to_idx=out_edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
Esempio n. 15
0
    def match(self, G: GraphView, set_identity: bool = True):
        """Fuse ReverseParameters nodes around RNN nodes into the RNN's ``revert`` flag.

        A reverse is considered only if it has exactly one consumer and that
        consumer is an RNNBaseParameters. If it feeds RNN input 0:

        * the reverse is removed and its producers are wired straight to the
          RNN;
        * ``revert`` is set on the RNN;
        * any ReverseParameters fed from RNN output 0 is removed as well, with
          its consumers wired directly to the RNN.

        A reverse on any other RNN input or output is only logged.

        Returns:
            True if any reverse was fused.
        """
        # Only works for reverses connected to one RNN node
        reverse_nodes = set([
            node for node in G.nodes()
            if (isinstance(node, ReverseParameters)
                and len(G.out_edges(node.name)) == 1 and isinstance(
                    G.out_edges(node.name)[0].to_node, RNNBaseParameters))
        ])

        has_modified_graph = False
        for reverse_node in reverse_nodes:
            in_edges = G.in_edges(reverse_node.name)
            rnn_edge = G.out_edges(reverse_node.name)[0]
            if rnn_edge.to_idx != 0:
                LOG.warning("reverse on rnn input %s", rnn_edge.to_idx)
                continue
            # NOTE(review): assert is stripped under `python -O`, so this
            # double-reverse check disappears in optimised runs
            assert not rnn_edge.to_node.revert, "RNN node is already reversed!"
            # presumably makes the RNN walk its sequence in reverse - confirm
            rnn_edge.to_node.revert = True
            LOG.info("fusing reverses into node %s", rnn_edge.to_node.name)
            has_modified_graph = True
            G.remove(reverse_node)
            for edge in in_edges:
                G.add_edge(
                    NNEdge(edge.from_node,
                           rnn_edge.to_node,
                           from_idx=edge.from_idx,
                           to_idx=rnn_edge.to_idx))

            # drop matching reverses on the RNN's output 0
            for edge in G.out_edges(rnn_edge.to_node.name):
                if not isinstance(edge.to_node, ReverseParameters):
                    continue
                if edge.from_idx != 0:
                    LOG.warning("reverse on rnn output %s", edge.from_idx)
                    continue
                rev_edges = G.out_edges(edge.to_node.name)
                G.remove(edge.to_node)
                for rev_edge in rev_edges:
                    G.add_edge(
                        NNEdge(edge.from_node,
                               rev_edge.to_node,
                               from_idx=edge.from_idx,
                               to_idx=rev_edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
Esempio n. 16
0
    def match(self, G: GraphView, set_identity: bool = True):
        """Fuse connected groups of FUSE_NODES into ExpressionFusionParameters nodes.

        Candidates are nodes that are instances of FUSE_NODES, plus
        ConstantInputParameters whose output 0 has size 1. ``group_nodes``
        groups them into connected sets, and sets made only of constants are
        dropped. Each remaining set is replaced by an ExpressionFusionParameters
        named ``expr_<self._expr_num>``.

        Note: ``self._expr_num`` is read but not incremented here.

        Returns:
            True if any set was fused.
        """
        has_modified_graph = False
        # collect connected node sets
        node_sets = group_nodes(G, [
            node for node in G.nodes() if isinstance(node, FUSE_NODES) or (
                isinstance(node, ConstantInputParameters)
                and node.out_dims[0].size() == 1)
        ])
        # remove sets that are only ConstantInputs
        node_sets = [
            node_set for node_set in node_sets if not all(
                isinstance(node, ConstantInputParameters) for node in node_set)
        ]
        for node_set in node_sets:
            has_modified_graph = True
            # in_edges / out_edges are dict-like. Keys are indexable with [0]
            # (node) and [1] (index), and values are sets of edges (see uses
            # below).
            in_edges, out_edges, internal_edges = group_edges(G, node_set)
            frag = GraphView()
            for edge in internal_edges:
                frag.add_edge(edge)
            in_mapping = [[(edge.to_node, edge.to_idx) for edge in edge_group]
                          for edge_group in in_edges.values()]
            out_mapping = list(out_edges.keys())
            # one bool per input: True when that input's producer is a constant
            constant_inputs = [
                isinstance(node_edge_idx[0], ConstantInputParameters)
                for node_edge_idx in in_edges
            ]
            expr = ExpressionFusionParameters("expr_%s" % self._expr_num,
                                              subgraph=frag,
                                              input_mapping=in_mapping,
                                              output_mapping=out_mapping,
                                              constant_inputs=constant_inputs)
            # NOTE(review): edge_in_mapping is sorted by key[1], while
            # in_mapping/constant_inputs follow in_edges' own iteration order.
            # If the two orders differ, fusion inputs may be paired with the
            # wrong external edges - confirm what replace_fragment expects.
            G.replace_fragment(
                frag,
                expr,
                frag_in_edges=list(set.union(*in_edges.values())),
                frag_out_edges=list(set.union(*out_edges.values())),
                edge_in_mapping=sorted(list(in_edges.keys()),
                                       key=lambda x: x[1]),
                edge_out_mapping=[[(edge.to_node, edge.to_idx)
                                   for edge in edge_set]
                                  for edge_set in out_edges.values()])

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Fuse relu6(x) * (1/6) into a single HSigmoidActivationParameters.

        Looks for the following pattern:

        * a ReluActivationParameters with ``upper_bound == 6``;
        * its only consumer is a MatrixMulParameters with exactly two inputs;
        * the multiply's other input is a ConstantInputParameters used only by
          that multiply and equal to 1/6 according to ``check_equals``.

        The relu, multiply and constant are replaced by an
        HSigmoidActivationParameters (``offset=0``) placed between the relu's
        producers and the multiply's consumers. On a quantized graph, the
        constant's record is deleted. The new node gets a record copied from
        the relu's, with the relu's ``in_qs`` and the multiply's ``out_qs``.

        When ``set_identity`` is true, ``self.set_identity(G)`` is called
        before returning, as in the other ``_match`` passes.

        Returns:
            True if any fusion was performed.
        """
        something_changed = False
        for relu_node in [node for node in G.nodes(node_classes=ReluActivationParameters) if node.upper_bound == 6]:
            out_edges = G.out_edges(relu_node)
            if len(out_edges) != 1 or not isinstance(out_edges[0].to_node, MatrixMulParameters):
                continue
            mul_node = out_edges[0].to_node
            in_edges = G.in_edges(mul_node)
            if len(in_edges) != 2:
                continue
            # the multiply input that does not come from the relu
            other_edge = (set(in_edges) - {out_edges[0]}).pop()
            constant_node = other_edge.from_node
            if len(G.out_edges(constant_node)) != 1:
                continue
            if (not isinstance(constant_node, ConstantInputParameters) or
                    not check_equals(G, constant_node, 1.0/6.0)):
                continue

            something_changed = True
            activation = HSigmoidActivationParameters(
                G.unique_name(f'{mul_node.name}_hsigmoid'), offset=0)

            in_edges = G.in_edges(relu_node)
            out_edges = G.out_edges(mul_node)

            nodes_to_replace = [relu_node, mul_node, constant_node]

            LOG.info(f'fusing {", ".join(node.name for node in nodes_to_replace)} into HSIGMOID {activation.name}')
            G.remove_all(nodes_to_replace)

            for in_edge in in_edges:
                G.add_edge(NNEdge.clone(in_edge, to_node=activation, to_idx=0))
            for out_edge in out_edges:
                G.add_edge(NNEdge.clone(
                    out_edge, from_node=activation, from_idx=0))

            if G.quantization:
                reluqrec = G.quantization[NodeId(relu_node)]
                mulqrec = G.quantization[NodeId(mul_node)]
                del G.quantization[NodeId(constant_node)]
                pqrec = QRec.copy_ktype(
                    reluqrec, in_qs=reluqrec.in_qs, out_qs=mulqrec.out_qs)
                G.quantization[NodeId(activation)] = pqrec

        if set_identity:
            self.set_identity(G)

        return something_changed
Esempio n. 18
0
    def move_constant(cls, G: GraphView, params, in_qs):
        """Find a constant operand of a matmul, moving it to input 1 if required.

        Returns ``(constant_node, in_qs)``, depending on which input is
        constant:

        * Input 1 is a ConstantInputParameters: that node and ``in_qs``
          unchanged.
        * Otherwise, input 0 is a ConstantInputParameters, and either there is
          no third input or the third input's size equals dimension 1 of input
          0's shape. Then:

          - the first two input edges are re-added with their indexes swapped;
          - (1, 0) transposes are inserted before both inputs and after the
            output, using the identity A.B = (B^T.A^T)^T;
          - a warning is logged;
          - the input-0 node is returned, with the first two entries of
            ``in_qs`` swapped.
        * In every other case: ``(None, in_qs)`` and the graph is untouched.
        """
        edges = G.indexed_in_edges(params.name)
        lhs = edges[0].from_node
        rhs = edges[1].from_node

        if isinstance(rhs, ConstantInputParameters):
            return rhs, in_qs
        if not isinstance(lhs, ConstantInputParameters):
            return None, in_qs

        if len(params.in_dims) > 2:
            # a bias only survives the operand swap if its length matches
            # dimension 1 of the first input
            bias_size = params.in_dims[2].size()
            if params.in_dims[0].shape[1] != bias_size:
                return None, in_qs

        operands = edges[:2]
        for edge in operands:
            G.remove_edge(edge)
        # re-add the operand edges with input indexes 0 and 1 exchanged so
        # the constant lands on input 2
        for new_to_idx, edge in zip((1, 0), operands):
            G.add_edge(NNEdge(from_node=edge.from_node,
                              to_node=edge.to_node,
                              from_idx=edge.from_idx,
                              to_idx=new_to_idx))

        # A.B = (B^T.A^T)^T
        tin1, tin2, tout = (
            TransposeParameters(G.unique_name(f'{params.name}_{suffix}'),
                                transpose=(1, 0))
            for suffix in ('tin1', 'tin2', 'tout'))
        G.insert_node_before(tin1, params)
        G.insert_node_before(tin2, params, to_idx=1)
        G.insert_node_after(params, tout)
        LOG.warning('transposes inserted on %s - rerun adjust',
                    params.name)
        return lhs, [in_qs[1], in_qs[0]] + in_qs[2:]
Esempio n. 19
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Fuse each matched MatScale node pair into a MatScaleFusionParameters.

        Matches come from ``MatScalePairMatchFactory().get_matcher()``. For
        each match:

        * the fragment's nodes are removed from ``G``;
        * a ``vec_scalar`` fusion node named ``<first>_<last>_fusion`` is
          added;
        * the matched input edges are cloned onto the fusion node;
        * the output edges of the fragment's last node are re-sourced from
          the fusion node.

        The fusion node's ``in_dims_hint`` is filled from the producers'
        ``out_dims_hint`` where available.

        Returns:
            True if any pair was fused.
        """
        fac = MatScalePairMatchFactory()
        has_modified_graph = False

        for frag, match in fac.get_matcher().match_graph(G):
            # match is iterated as (node, input index) pairs naming the edges
            # that feed the fragment from outside
            match_edges = [
                G.indexed_in_edges(node.name)[idx] for node, idx in match
            ]
            first_node = frag.inputs()[0]
            last_node = frag.outputs()[0]
            out_edges = G.out_edges(last_node.name)
            for node in frag.nodes():
                G.remove(node)

            input_mapping = MatScaleFusionParameters.get_mapping_from_edges(
                match_edges)

            fnode = MatScaleFusionParameters(
                "{}_{}_fusion".format(first_node.name, last_node.name),
                fusion_type="vec_scalar",
                subgraph=frag,
                input_mapping=MatScaleFusionParameters.convert_input_mapping(
                    input_mapping))
            has_modified_graph = True
            G.add_node(fnode)
            # NOTE(review): assumes the fusion never has more than 3 inputs
            fnode.in_dims_hint = [None] * 3

            for idx, edge in enumerate(match_edges):
                new_edge = edge.clone(
                    to_node=fnode,
                    to_idx=list(input_mapping[edge.to_node].keys())[0])
                # NOTE(review): the hint is stored at the enumeration position
                # idx, while the edge is attached at new_edge.to_idx; if these
                # differ the hint describes the wrong input - confirm.
                if new_edge.from_node.out_dims_hint:
                    fnode.in_dims_hint[idx] = new_edge.from_node.out_dims_hint[
                        edge.from_idx]
                G.add_edge(new_edge)
            for edge in out_edges:
                new_edge = edge.clone(from_node=fnode)
                G.add_edge(new_edge)

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
Esempio n. 20
0
def find_concats_up(G, concat, subgraph: GraphView = None):
    """Collect the axis-0 concats feeding ``concat`` into a subgraph.

    Produces a subgraph of concats operating on axis 0 separated by copies or
    reshapes. The output node will be the final concat; the input nodes will
    be all the inputs to a condensed concat that can replace this subgraph.

    For each in edge of ``concat``, ``traverse_to_concat`` is called. When it
    returns a non-empty edge path, those edges are added to the subgraph.
    Otherwise the edge's producer is represented by a ``DummyInput`` named
    ``"<producer name>_<from_idx>"`` that feeds the same input of ``concat``.

    Args:
        G: graph containing ``concat``.
        concat: the concat node to start from.
        subgraph: subgraph to add edges to; a new ``GraphView`` is created
            when None.

    Returns:
        The populated subgraph (the same object as ``subgraph`` when one was
        passed in).
    """
    if subgraph is None:
        subgraph = GraphView()
    for edge in G.indexed_in_edges(concat.name):
        edge_path = traverse_to_concat(G, edge, subgraph)
        if edge_path:
            for inter_edge in edge_path:
                subgraph.add_edge(inter_edge)
        else:
            # no concat chain upstream of this edge: it becomes a leaf input
            # of the condensed concat
            dummy = DummyInput(f"{edge.from_node.name}_{edge.from_idx}", edge)
            subgraph.add_edge(
                NNEdge(from_node=dummy,
                       to_node=edge.to_node,
                       to_idx=edge.to_idx))
    return subgraph
Esempio n. 21
0
    def match(self, G: GraphView, set_identity: bool = True):
        """Remove ReLUs made redundant by non-negative constant inputs.

        Each graph input is tagged with a status that is handed to
        ``find_redundant_relus`` through ``visited_edges``. A
        ``ConstantInputParameters`` input gets ``(True, max_value)`` when its
        value is entirely non-negative; the value is dequantized first when
        the graph has quantized parameters. Every other input gets
        ``(False, False)``. Each node reported by ``find_redundant_relus`` is
        removed, and its single producer is connected directly to its
        consumers.

        Args:
            G: graph to modify in place.
            set_identity: when True, call ``self.set_identity(G)`` at the end.

        Returns:
            True if at least one node was removed.
        """
        visited_edges = {}
        nodes_to_remove = []
        has_modified_graph = False
        for node in G.inputs():
            # Default status for any input without a known non-negative range.
            # Set unconditionally so that a constant input whose value is None
            # does not reuse the previous input's status. It also avoids an
            # UnboundLocalError when such an input is the first one.
            status = (False, False)
            # check if constantinput. if is then check if positive and check max value
            if isinstance(node, ConstantInputParameters) and node.value is not None:
                if G.has_quantized_parameters:
                    qrec = G.quantization[NodeId(node)]
                    qtype = qrec.out_qs[0]
                    # unwrap a wrapped quantization type before dequantizing
                    if hasattr(qtype, 'wrapped'):
                        qtype = qtype.wrapped
                    val = qtype.dequantize(node.value)
                else:
                    val = node.value
                if val.min() >= 0:
                    status = (True, val.max())

            for edge in G.out_edges(node.name):
                visited_edges[edge] = status
                nodes_to_remove += find_redundant_relus(
                    G, edge.to_node, visited_edges)
        for node in nodes_to_remove:
            has_modified_graph = True
            # Only relus so only one in edge
            in_edge = G.in_edges(node.name)[0]
            # Capture the consumers and remove the relu before wiring its
            # producer to them. This way a consumer input is never connected
            # to both the relu and its producer at the same time.
            out_edges = G.out_edges(node.name)
            G.remove(node)
            for edge in out_edges:
                G.add_edge(
                    NNEdge(from_node=in_edge.from_node,
                           from_idx=in_edge.from_idx,
                           to_node=edge.to_node,
                           to_idx=edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
Esempio n. 22
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Remove no-op reshapes and collapse unnecessary reshape chains.

        Repeats two passes until neither changes the graph:

        1. Every ``ReshapeParameters`` without a transpose whose new shape
           equals its old shape is removed. Its producer is reconnected to its
           consumers, and its qrec, if any, is deleted.
        2. The first reshape for which ``self.validate_reshape`` returns a
           result is handled. The candidate nodes in that result are removed,
           their consumers are re-attached to the reshape named in the result,
           and that reshape's shape is set to the returned output shape.

        Args:
            G: graph to modify in place.
            set_identity: when True, call ``self.set_identity(G)`` at the end.

        Returns:
            True if any node was removed or any reshape was rewritten.
        """
        has_modified_graph = False
        modified_graph = True
        while modified_graph:
            modified_graph = False
            # iterate over a snapshot because reshapes are removed as we go
            for reshape in list(G.nodes(node_classes=(ReshapeParameters, ))):
                if not reshape.has_transpose and reshape.shape.shape == reshape.old_shape.shape:
                    modified_graph = True
                    LOG.info('removing reshape that does nothing %s',
                             reshape.name)
                    G.remove_and_reconnect(reshape, edge_class=NNEdge)
                    nid = NodeId(reshape)
                    if G.quantization and nid in G.quantization:
                        del G.quantization[nid]
            for reshape in G.nodes(node_classes=(ReshapeParameters, )):
                res = self.validate_reshape(G, reshape)
                if res:
                    LOG.info('unnecessary reshape found after %s',
                             reshape.name)
                    modified_graph = True
                    (reshape, candidates, out_shape) = res
                    for candidate in candidates:
                        LOG.info(
                            'removing unnecessary reshape or transpose %s',
                            candidate.name)
                        edges = G.out_edges(candidate.name)
                        G.remove(candidate)
                        nid = NodeId(candidate)
                        if G.quantization and nid in G.quantization:
                            del G.quantization[nid]
                        for edge in edges:
                            G.add_edge(
                                NNEdge(from_node=reshape,
                                       to_node=edge.to_node,
                                       to_idx=edge.to_idx))
                    reshape.shape = Dim.unnamed(out_shape)
                    # the graph changed under this iteration; rescan from the
                    # top of the while loop
                    break
            if modified_graph:
                # The loop only exits after a pass that changes nothing, so
                # modified_graph is always False afterwards. Record any change
                # separately so the return value reports it.
                has_modified_graph = True

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
Esempio n. 23
0
    def match_function(self, G: GraphView):
        """Match the fragment ``relu6 -> matmul <- constant(1/6)`` in ``G``.

        The pattern has three match nodes: '0' is a
        ``ReluActivationParameters`` with ``upper_bound == 6``, '1' is a
        ``MatrixMulParameters``, and '2' is a ``ConstantInputParameters`` for
        which ``check_equals(G, node, 1.0 / 6.0)`` holds. Node '0' feeds input
        0 of '1' and node '2' feeds input 1.

        Returns:
            Whatever ``G.match_fragment`` returns for this pattern.
        """
        def is_relu6(node):
            return isinstance(node, ReluActivationParameters) and node.upper_bound == 6

        def is_matmul(node):
            return isinstance(node, MatrixMulParameters)

        def is_one_sixth_constant(node):
            return (isinstance(node, ConstantInputParameters)
                    and check_equals(G, node, 1.0 / 6.0))

        pattern = GraphView()
        for match_name, predicate in (('0', is_relu6),
                                      ('1', is_matmul),
                                      ('2', is_one_sixth_constant)):
            pattern.add_node(MatchNode(match_name, matcher=predicate))
        # relu6 drives the first matmul operand, the constant the second
        for source_name, operand_idx in (('0', 0), ('2', 1)):
            pattern.add_edge(Edge(source_name, '1', to_idx=operand_idx))

        return G.match_fragment(pattern)
Esempio n. 24
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Fuse node chains described by ``VALID_FUSIONS`` into single nodes.

        For every node whose class is a key of ``VALID_FUSIONS``,
        ``self.get_node_list`` is asked for a fusable chain. A chain of two or
        more nodes is replaced by a node of class ``node_list.fusions_class``,
        which holds the chain as a linear subgraph. When the graph is
        quantized, the contained nodes' qrecs are moved into the fusion. A new
        qrec for the fusion is then built from the first qrec's inputs and the
        last qrec's outputs.

        Args:
            G: graph to modify in place.
            set_identity: when True, call ``self.set_identity(G)`` at the end.

        Returns:
            True if at least one chain was fused.
        """
        has_modified_graph = False

        # NOTE(review): chain nodes are removed from G inside this loop;
        # confirm G.nodes() yields a snapshot and that a node already fused
        # into an earlier chain cannot be visited again here
        for node in G.nodes(node_classes=tuple(VALID_FUSIONS.keys())):
            node_list = self.get_node_list(G, node,
                                           FusionMatch(self._default_ktype))
            if node_list is None or len(node_list.order) < 2:
                continue
            LOG.info("fusing nodes %s", ",".join(
                (node.name for node in node_list.order)))
            has_modified_graph = True
            # link the chain nodes in sequence: order[0] -> order[1] -> ...
            subgraph = GraphView()
            last_node = None
            for snode in node_list.order:
                if last_node is not None:
                    subgraph.add_edge(
                        NNEdge(from_node=last_node, to_node=snode))
                last_node = snode
            # assumption here is that the first node could have multiple inputs but definitely has only
            # one output
            input_mapping = [[
                (node_list.node, idx)
            ] for idx in range(G.num_in_edges(node_list.node.name))]
            output_mapping = [(last_node, 0)]
            pnode = node_list.fusions_class(node_list.node.name + '_fusion',
                                            fusion_type=node_list.fusion_type,
                                            subgraph=subgraph,
                                            input_mapping=input_mapping,
                                            output_mapping=output_mapping)
            if G.quantization:
                # TODO - stats
                qrecs = G.quantization.get_all(pnode.contained_nodes())
                if qrecs:
                    # fused qrec: built with QRec.copy_ktype from the head
                    # qrec, taking the head's in_qs and the tail's out_qs
                    prec = QRec.copy_ktype(qrecs[0],
                                           in_qs=qrecs[0].in_qs,
                                           out_qs=qrecs[-1].out_qs)
                    for fnode in pnode.contained_nodes():
                        G.quantization.move_to_fusion(fnode, pnode)
                    G.quantization[NodeId(pnode)] = prec
            # capture the chain's external edges before its nodes are removed
            in_edges = G.in_edges(node_list.node.name)
            out_edges = G.out_edges(last_node.name)
            for snode in node_list.order:
                G.remove(snode)
            for edge in in_edges:
                G.add_edge(
                    NNEdge(edge.from_node,
                           pnode,
                           from_idx=edge.from_idx,
                           to_idx=edge.to_idx))
            for edge in out_edges:
                G.add_edge(
                    NNEdge(pnode,
                           edge.to_node,
                           from_idx=edge.from_idx,
                           to_idx=edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Hoist a duplicated operation from a split's outputs to its input.

        Handles each ``SplitParameters`` node for which
        ``self.moveable_same_operation_edges`` reports edges. Every consumer on
        the split's out edges is removed, and the split's outputs are wired
        straight to those consumers' own consumers. The consumer of the
        reported edge whose consumer name sorts first is then re-inserted on
        the split's single input edge. The removed duplicates lose their
        qrecs, and the graph is re-quantized when it carries quantization.

        Args:
            G: graph to modify in place.
            set_identity: when True, call ``self.set_identity(G)`` at the end.

        Returns:
            True if at least one split was rewritten.
        """
        modified = False
        for split in G.nodes(node_classes=SplitParameters):
            dup_edges = self.moveable_same_operation_edges(G, split)
            if not dup_edges:
                continue
            modified = True
            split_in_edges = G.in_edges(split.name)
            assert len(split_in_edges) == 1
            # sort by name to ensure that operation is repeatable
            dup_edges.sort(key=lambda dup_edge: dup_edge.to_node.name)
            survivor = dup_edges[0].to_node
            LOG.info('split node %s has duplicate operations on its out edges',
                     split.name)
            LOG.info('moving %s before split node %s', survivor.name,
                     split.name)
            for split_edge in G.out_edges(split.name):
                consumer = split_edge.to_node
                consumer_out_edges = G.out_edges(consumer.name)
                G.remove(consumer)
                if consumer != survivor:
                    LOG.info('deleting duplicate node %s', consumer.name)
                    if G.quantization:
                        consumer_nid = NodeId(consumer)
                        if consumer_nid in G.quantization:
                            del G.quantization[consumer_nid]
                # bypass the removed consumer: split output -> its consumers
                for downstream in consumer_out_edges:
                    G.add_edge(NNEdge(from_node=split,
                                      from_idx=split_edge.from_idx,
                                      to_node=downstream.to_node,
                                      to_idx=downstream.to_idx))
            G.insert_node_at_edge(survivor, split_in_edges[0], edge_class=NNEdge)
            if G.quantization:
                NewQuantizer.from_quantized_graph(G).quantize()

        if set_identity:
            self.set_identity(G)

        return modified
Esempio n. 26
0
    def match(self, G: GraphView, set_identity: bool = True):
        """Fuse each Conv2D with the pool/activation nodes that follow it.

        ``self.get_node_list`` supplies the chain (``node_list.order``) and its
        ``fusion_type``. A chain of two or more nodes is replaced by a
        ``ConvFusionParameters`` node that holds the chain as a linear
        subgraph. When the graph is quantized, the contained qrecs are moved
        into the fusion. A fused qrec is then created from the first qrec's
        in_qs and the last qrec's out_qs.

        Args:
            G: graph to modify in place.
            set_identity: when True, call ``self.set_identity(G)`` at the end.

        Returns:
            True if at least one chain was fused.
        """
        has_modified_graph = False
        for conv_node in [params for params in G.nodes() if isinstance(params, Conv2DParameters)]:
            node_list = self.get_node_list(G, conv_node)
            if node_list is None or len(node_list.order) < 2:
                continue
            if node_list.fusion_type == 'conv_active_pool':
                # an average pool is left out of the fusion: keep only the
                # first two chain nodes (presumably conv and activation)
                # NOTE(review): fusion_type is not updated here - confirm it
                # is derived from node_list.pool and reflects pool=None
                if node_list.pool.pool_type == "average":
                    node_list.order = node_list.order[:2:]
                    node_list.pool = None
            elif node_list.fusion_type == 'conv_pool_active':
                # an average pool followed by a non-relu activation is not fused
                if node_list.pool.pool_type == "average" and node_list.active.activation != "relu":
                    continue
            LOG.info("fusing nodes %s", ",".join((node.name for node in node_list.order)))
            has_modified_graph = True
            # link the chain nodes in sequence: order[0] -> order[1] -> ...
            subgraph = GraphView()
            last_node = None
            for node in node_list.order:
                if last_node is not None:
                    subgraph.add_edge(NNEdge(from_node=last_node, to_node=node))
                last_node = node
            # fusion inputs 0..2 map onto conv inputs 0..2 (presumably input,
            # weights and biases - confirm)
            input_mapping = [[(node_list.conv, idx)] for idx in range(3)]
            output_mapping = [(last_node, 0)]
            pnode = ConvFusionParameters(
                node_list.conv.name + '_fusion',
                fusion_type=node_list.fusion_type,
                subgraph=subgraph,
                in_dims_hint=node_list.conv.in_dims_hint,
                out_dims_hint=node_list.conv.out_dims_hint,
                input_mapping=input_mapping,
                output_mapping=output_mapping)
            if G.quantization:
                qrecs = G.quantization.get_all(pnode.contained_nodes())
                if qrecs:
                    # the fused record type is chosen from the first qrec's class
                    # NOTE(review): prec stays None for any other qrec class and
                    # None is then stored as the fusion's qrec - confirm that
                    # cannot happen
                    prec = None
                    if isinstance(qrecs[0], (SymmetricQuantizationRecord, SymmetricScalableFilterQuantizationRecord)):
                        prec = SymmetricQuantizationRecord(
                            in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                    elif isinstance(qrecs[0], (MultQuantizationRecord, MultScalableFilterQuantizationRecord)):
                        prec = MultQuantizationRecord(in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                    elif isinstance(qrecs[0], (Float32QuantizationRecord, Float32ScalableFilterQuantizationRecord)):
                        prec = Float32QuantizationRecord(
                            in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                    for node in pnode.contained_nodes():
                        G.quantization.move_to_fusion(node, pnode)
                    G.quantization[NodeId(pnode)] = prec
            # capture the chain's external edges before its nodes are removed
            in_edges = G.in_edges(node_list.conv.name)
            out_edges = G.out_edges(last_node.name)
            for node in node_list.order:
                G.remove(node)
            for edge in in_edges:
                G.add_edge(NNEdge(edge.from_node, pnode, from_idx=edge.from_idx, to_idx=edge.to_idx))
            for edge in out_edges:
                G.add_edge(NNEdge(pnode, edge.to_node, from_idx=edge.from_idx, to_idx=edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
Esempio n. 27
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Remove ReLUs made redundant by non-negative constant inputs.

        Each graph input is tagged with a status that is handed to
        ``find_redundant_relus`` through ``visited_edges``. A
        ``ConstantInputParameters`` input with a value whose dequantized form
        (``dqvalue``) is entirely non-negative gets ``(True, max_value)``.
        Every other input gets ``(False, False)``. Each node reported by
        ``find_redundant_relus`` is removed, and its single producer is
        connected directly to its consumers.

        Args:
            G: graph to modify in place.
            set_identity: when True, call ``self.set_identity(G)`` at the end.

        Returns:
            True if at least one node was removed.
        """
        visited_edges = {}
        redundant_relus = []
        for graph_input in G.inputs():
            # check if constantinput. if is then check if positive and check max value
            status = (False, False)
            if isinstance(graph_input, ConstantInputParameters) and graph_input.value is not None:
                dq_val = graph_input.dqvalue
                if np.min(dq_val) >= 0:
                    status = (True, np.max(dq_val))
            for input_edge in G.out_edges(graph_input.name):
                visited_edges[input_edge] = status
                redundant_relus += find_redundant_relus(
                    G, input_edge.to_node, visited_edges)

        has_modified_graph = False
        for relu in redundant_relus:
            has_modified_graph = True
            LOG.info("removing redundant relu %s", relu.name)
            # Only relus so only one in edge
            producer_edge = G.in_edges(relu.name)[0]
            consumer_edges = G.out_edges(relu.name)
            G.remove(relu)
            for consumer_edge in consumer_edges:
                G.add_edge(NNEdge(from_node=producer_edge.from_node,
                                  from_idx=producer_edge.from_idx,
                                  to_node=consumer_edge.to_node,
                                  to_idx=consumer_edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return has_modified_graph
Esempio n. 28
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Eliminate unpack/slice nodes after multi-cell RNN outputs.

        Collects an unpack for every ``RNNBaseParameters`` node with more than
        one output cell (via ``self.find_unpack``), then filters the result
        through ``validate_slices`` and ``validate_multi_branch``. For each
        remaining unpack node, ``process_path`` is run on its RNN paths and
        the unpack itself is removed. Its producer is reconnected to its
        consumers, either directly or through a new ``ReshapeParameters``
        when the strided slice changes the output shape.

        Args:
            G: graph to modify in place.
            set_identity: when True, call ``self.set_identity(G)`` at the end.

        Returns:
            False if no unpack qualified, otherwise True.
        """
        rnn_nodes = [
            self.find_unpack(G, node) for node in G.nodes()
            if isinstance(node, RNNBaseParameters) and node.n_output_cells > 1
        ]
        rnn_nodes_by_slice = self.validate_slices(G, rnn_nodes)
        rnn_nodes_by_slice = self.validate_multi_branch(G, rnn_nodes_by_slice)
        if not rnn_nodes_by_slice:
            return False

        for unpack_node, rnn_unpacks in rnn_nodes_by_slice.items():
            modified_nodes = set()
            for rnn_unpack in rnn_unpacks:
                self.process_path(G, rnn_unpack, modified_nodes)
            # since process path will have removed all unnecessary nodes the edges will be correct here
            out_edges = G.out_edges(unpack_node.name)
            in_edges = G.in_edges(unpack_node.name)
            assert len(in_edges
                       ) == 1, "expecting unpack node to have only one in edge"
            in_edge = in_edges[0]
            changes_shape = unpack_node.changes_shape if isinstance(
                unpack_node, StridedSliceParameters) else False

            LOG.info("Eliminating last cell unpack: %s", unpack_node.name)
            G.remove(unpack_node)

            # Here the strided slice can change the output shape of the RNN
            # so insert a reshape to do the shape change
            if changes_shape:
                reshape = ReshapeParameters(
                    unpack_node.name + '_reshape',
                    old_shape=Dim.unnamed(unpack_node.post_slice_shape),
                    shape=Dim.unnamed(unpack_node.out_shape))
                G.add_edge(
                    NNEdge(from_node=in_edge.from_node,
                           to_node=reshape,
                           from_idx=in_edge.from_idx))
                for out_edge in out_edges:
                    G.add_edge(
                        NNEdge(from_node=reshape,
                               to_node=out_edge.to_node,
                               to_idx=out_edge.to_idx))
                if G.quantization:
                    # The reshape takes over the unpack's qrec. That qrec is
                    # still present here because it is only deleted below.
                    G.quantization[NodeId(reshape)] = G.quantization[NodeId(
                        unpack_node)]
            else:
                for out_edge in out_edges:
                    G.add_edge(
                        NNEdge(from_node=in_edge.from_node,
                               to_node=out_edge.to_node,
                               from_idx=in_edge.from_idx,
                               to_idx=out_edge.to_idx))
            if G.quantization:
                del G.quantization[NodeId(unpack_node)]

        if set_identity:
            self.set_identity(G)

        return True
Esempio n. 29
0
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Merge groups of global pooling reductions into a single node.

        For each ``GlobalPoolingParameters`` node, ``self.reductions`` returns
        a group of nodes. A group with more than one node is folded with
        ``reduce_reduction`` (starting from None) into the combined reduction
        axes, output shape and keep-dims flag. The group's first node receives
        the combined axes and keep-dims flag, and the other nodes (and their
        qrecs) are removed. If keep-dims is set and the combined shape's rank
        differs from the first node's input rank, a ``ReshapeParameters`` to
        that shape is appended. The last group node's out edges are then
        reconnected to the resulting node.

        Args:
            G: graph to modify in place.
            set_identity: when True, call ``self.set_identity(G)`` at the end.

        Returns:
            True if at least one group was merged.
        """
        nodes = list(G.nodes(node_classes=GlobalPoolingParameters))
        modified_graph = False
        while nodes:
            node = nodes.pop()
            # NOTE(review): nodes removed below as part of a group are not
            # dropped from ``nodes`` and may be popped later - confirm
            # self.reductions copes with a node no longer in G
            node_group = self.reductions(G, node)
            if len(node_group) <= 1:
                continue
            modified_graph = True
            reduction_axes, new_shape, has_keepdims, _ = reduce(
                reduce_reduction, node_group, None)
            new_node = node_group[0]
            new_node.axis = sorted(list(reduction_axes))
            new_node.keep_dims = has_keepdims
            # the group's consumers hang off its last node
            out_edges = G.out_edges(node_group[-1].name)
            if G.quantization:
                # the merged node takes over the last node's output quantization
                last_qrec = G.quantization[NodeId(node_group[-1])]
                G.quantization[NodeId(new_node)].out_qs = last_qrec.out_qs
            for node in node_group[1::]:
                G.remove(node.name)
                nid = NodeId(node)
                if G.quantization and nid in G.quantization:
                    del G.quantization[nid]
            if has_keepdims and len(new_shape) != len(
                    new_node.in_dims[0].shape):
                rparams = ReshapeParameters(
                    G.unique_name(f'{new_node.name}_reshape'),
                    shape=Dim.unnamed(new_shape))
                if G.quantization:
                    # last_qrec was bound above under the same condition
                    G.quantization.copy_qrec(last_qrec, 'out', 0, rparams)
                G.add_edge(NNEdge(new_node, rparams))
                new_node = rparams
            for edge in out_edges:
                G.add_edge(NNEdge(new_node, edge.to_node, to_idx=edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return modified_graph
    def _match(self, G: GraphView, set_identity: bool = True, **kwargs):
        """Fuse each fully connected node with the nodes that follow it.

        The set of acceptable activations depends on
        ``kwargs['group_identity']``: ``VALID_ACTIVATIONS_POW2`` for
        ``'pow2_match_group'``, otherwise ``VALID_ACTIVATIONS_SQ8``. For every
        ``FcParameters`` node, ``self.get_node_list`` supplies the chain. A
        chain of two or more nodes is replaced by a ``LinearFusionParameters``
        node that holds it as a linear subgraph. When the graph is quantized,
        the contained qrecs are moved into the fusion, and a fused qrec is
        built from the first qrec's in_qs and the last qrec's out_qs.

        Args:
            G: graph to modify in place.
            set_identity: when True, call ``self.set_identity(G)`` at the end.

        Returns:
            True if at least one chain was fused.
        """
        if kwargs.get('group_identity') == 'pow2_match_group':
            valid_activations = VALID_ACTIVATIONS_POW2
        else:
            valid_activations = VALID_ACTIVATIONS_SQ8
        has_modified_graph = False
        fc_nodes = [params for params in G.nodes() if isinstance(params, FcParameters)]
        for fc_node in fc_nodes:
            node_list = self.get_node_list(G, fc_node, valid_activations)
            if node_list is None or len(node_list.order) < 2:
                continue
            chain = list(node_list.order)
            LOG.info("fusing nodes %s", ",".join(
                (member.name for member in chain)))
            has_modified_graph = True
            # link consecutive chain members: chain[0] -> chain[1] -> ...
            subgraph = GraphView()
            for upstream, downstream in zip(chain, chain[1:]):
                subgraph.add_edge(NNEdge(from_node=upstream, to_node=downstream))
            tail = chain[-1]
            pnode = LinearFusionParameters(
                node_list.linear.name + '_fusion',
                fusion_type=node_list.fusion_type,
                subgraph=subgraph,
                input_mapping=[[(node_list.linear, idx)] for idx in range(3)],
                output_mapping=[(tail, 0)])
            if G.quantization:
                # TODO - stats
                qrecs = G.quantization.get_all(pnode.contained_nodes())
                if qrecs:
                    fused_qrec = QRec.copy_ktype(
                        qrecs[0], in_qs=qrecs[0].in_qs, out_qs=qrecs[-1].out_qs)
                    for contained in pnode.contained_nodes():
                        G.quantization.move_to_fusion(contained, pnode)
                    G.quantization[NodeId(pnode)] = fused_qrec
            # grab the external edges before the chain disappears from G
            head_in_edges = G.in_edges(node_list.linear.name)
            tail_out_edges = G.out_edges(tail.name)
            for member in chain:
                G.remove(member)
            for edge in head_in_edges:
                G.add_edge(NNEdge(edge.from_node, pnode,
                                  from_idx=edge.from_idx, to_idx=edge.to_idx))
            for edge in tail_out_edges:
                G.add_edge(NNEdge(pnode, edge.to_node,
                                  from_idx=edge.from_idx, to_idx=edge.to_idx))

        if set_identity:
            self.set_identity(G)

        return has_modified_graph