Пример #1
0
 def match(self, G: GraphView, set_identity: bool = True):
     """Turn transposes recorded on nodes into explicit transpose nodes.

     For every node in ``G`` that is a ``Transposable`` but not itself a
     ``TransposeParameters``, a ``TransposeParameters`` node is inserted on
     each incoming edge when ``transpose_in`` is set and on each outgoing
     edge when ``transpose_out`` is set. The node's ``transpose_in`` /
     ``transpose_out`` attributes are then cleared, and its dimension hints
     (when present) are replaced by the result of
     ``apply_reverse_transpose_to_hint``.

     Args:
         G: graph that is modified in place.
         set_identity: when true, ``self.set_identity(G)`` is called once
             all nodes have been processed.
     """
     # Collect the candidates before touching the graph: inserting nodes
     # below would otherwise change what G.nodes() yields while we iterate.
     tnodes = [n for n in G.nodes()
               if isinstance(n, Transposable)
               and not isinstance(n, TransposeParameters)]
     for node in tnodes:
         if node.transpose_in:
             # Snapshot the edges: insert_node rewires them as we go.
             for idx, edge in enumerate(list(G.in_edges(node.name))):
                 in_params = TransposeParameters("%s_TIN_%s" % (node.name, idx),
                                                 transpose=node.transpose_in)
                 if node.in_dims_hint:
                     in_hint = node.in_dims_hint[edge.to_idx]
                     out_hint = apply_reverse_transpose_to_hint(in_hint, node.transpose_in)
                     in_params.in_dims_hint = [in_hint.copy()]
                     in_params.out_dims_hint = [out_hint.copy()]
                     node.in_dims_hint[edge.to_idx] = out_hint
                 G.insert_node(in_params, edge.from_node.name, edge.to_node.name,
                               from_idx=edge.from_idx, to_idx=edge.to_idx)
             node.transpose_in = None
         if node.transpose_out:
             # Several out edges can leave the same output index (fan-out).
             # The node's hint for that index is rewritten below, so the
             # (reversed, original) pair is computed once per index and
             # reused; re-reading the node's hint for a second edge would
             # reverse an already reversed hint.
             hint_pairs = {}
             for idx, edge in enumerate(list(G.out_edges(node.name))):
                 out_params = TransposeParameters("%s_TOUT_%s" % (node.name, idx),
                                                  transpose=node.transpose_out)
                 if node.out_dims_hint:
                     if edge.from_idx not in hint_pairs:
                         out_hint = node.out_dims_hint[edge.from_idx]
                         in_hint = apply_reverse_transpose_to_hint(out_hint,
                                                                   node.transpose_out)
                         node.out_dims_hint[edge.from_idx] = in_hint
                         hint_pairs[edge.from_idx] = (in_hint, out_hint)
                     in_hint, out_hint = hint_pairs[edge.from_idx]
                     out_params.in_dims_hint = [in_hint.copy()]
                     out_params.out_dims_hint = [out_hint.copy()]
                 G.insert_node(out_params, edge.from_node.name, edge.to_node.name,
                               from_idx=edge.from_idx, to_idx=edge.to_idx)
             node.transpose_out = None
     if set_identity:
         self.set_identity(G)
Пример #2
0
 def match(self, G: GraphView, set_identity: bool = True):
     """Turn per-edge transposes recorded on nodes into explicit nodes.

     For every node in ``G`` that is a ``Transposable`` but not itself a
     ``TransposeParameters``, ``transpose_in`` and ``transpose_out`` are
     indexed by input / output index. For each edge whose index has a
     non-``None`` entry, a ``TransposeParameters`` node is inserted on that
     edge with an ``NNEdge`` edge class. Edges whose index is beyond the
     end of the transpose list, or whose entry is ``None``, are left alone.
     When ``G.quantization`` is truthy,
     ``G.quantization.copy_to_node(node, new_node)`` is called for each
     inserted node. The node's ``transpose_in`` / ``transpose_out`` are
     cleared afterwards, and its dimension hints (when present) are replaced
     by the result of ``apply_reverse_transpose_to_hint``.

     Args:
         G: graph that is modified in place.
         set_identity: when true, ``self.set_identity(G)`` is called once
             all nodes have been processed.

     Returns:
         True if at least one transpose node was inserted, otherwise False.
     """
     # Collect the candidates before touching the graph: inserting nodes
     # below would otherwise change what G.nodes() yields while we iterate.
     tnodes = [n for n in G.nodes()
               if isinstance(n, Transposable)
               and not isinstance(n, TransposeParameters)]
     has_modified_graph = False
     for node in tnodes:
         if node.transpose_in:
             # Snapshot the edges: insert_node rewires them as we go.
             for idx, edge in enumerate(list(G.in_edges(node.name))):
                 if edge.to_idx >= len(node.transpose_in):
                     continue
                 trans = node.transpose_in[edge.to_idx]
                 if trans is None:
                     continue
                 has_modified_graph = True
                 in_params = TransposeParameters("%s_TIN_%s" % (node.name, idx),
                                                 transpose=trans)
                 if node.in_dims_hint and node.in_dims_hint[edge.to_idx]:
                     in_hint = node.in_dims_hint[edge.to_idx]
                     out_hint = apply_reverse_transpose_to_hint(in_hint, trans)
                     in_params.in_dims_hint = [in_hint.copy()]
                     in_params.out_dims_hint = [out_hint.copy()]
                     node.in_dims_hint[edge.to_idx] = out_hint
                 if G.quantization:
                     G.quantization.copy_to_node(node, in_params)
                 G.insert_node(in_params, edge.from_node.name, edge.to_node.name,
                               from_idx=edge.from_idx, to_idx=edge.to_idx,
                               edge_class=NNEdge)
             node.transpose_in = None
         if node.transpose_out:
             # Several out edges can leave the same output index (fan-out).
             # The node's hint for that index is rewritten below, so the
             # (reversed, original) pair is computed once per index and
             # reused; re-reading the node's hint for a second edge would
             # reverse an already reversed hint. None marks an index whose
             # hint is unset, mirroring the guard used on the input side.
             hint_pairs = {}
             for idx, edge in enumerate(list(G.out_edges(node.name))):
                 if edge.from_idx >= len(node.transpose_out):
                     continue
                 trans = node.transpose_out[edge.from_idx]
                 if trans is None:
                     continue
                 has_modified_graph = True
                 out_params = TransposeParameters("%s_TOUT_%s" % (node.name, idx),
                                                  transpose=trans)
                 if node.out_dims_hint:
                     if edge.from_idx not in hint_pairs:
                         out_hint = node.out_dims_hint[edge.from_idx]
                         if out_hint:
                             in_hint = apply_reverse_transpose_to_hint(out_hint, trans)
                             node.out_dims_hint[edge.from_idx] = in_hint
                             hint_pairs[edge.from_idx] = (in_hint, out_hint)
                         else:
                             hint_pairs[edge.from_idx] = None
                     pair = hint_pairs[edge.from_idx]
                     if pair is not None:
                         in_hint, out_hint = pair
                         out_params.in_dims_hint = [in_hint.copy()]
                         out_params.out_dims_hint = [out_hint.copy()]
                 if G.quantization:
                     G.quantization.copy_to_node(node, out_params)
                 G.insert_node(out_params, edge.from_node.name, edge.to_node.name,
                               from_idx=edge.from_idx, to_idx=edge.to_idx,
                               edge_class=NNEdge)
             node.transpose_out = None
     if set_identity:
         self.set_identity(G)
     return has_modified_graph