Exemplo n.º 1
0
 def replace_pattern(graph: Graph, match: dict):
     """Turn a matched Transpose node into a Permute with a static order.

     The permutation order is read from the constant value on input port 1,
     stored as the node's 'order' attribute, and input port 1 is then
     disconnected.

     :param graph: graph being transformed (not referenced in this body).
     :param match: pattern-match mapping; must contain the key 'transpose'.
     :raises ValueError: if input port 1 carries no constant value.
     """
     node = match['transpose']
     order = node.in_port(1).data.get_value()
     # Explicit check rather than ``assert``: asserts are stripped under
     # ``python -O``, which would turn a missing value into an obscure
     # AttributeError on ``order.copy()`` below.
     if order is None:
         raise ValueError(
             "Transpose node has no constant permutation order on input port 1")
     # Copy so the attribute does not alias the port's stored value
     # (get_value presumably returns an ndarray — TODO confirm).
     Permute.update_node_stat(node=node, attrs={'order': order.copy()})
     # NOTE(review): presumably clears a precision override that referred to
     # the order input being disconnected below — confirm.
     node['force_precision_in_ports'] = None
     node.in_port(1).disconnect()
Exemplo n.º 2
0
 def extract(node):
     """Read an int32 blob from ``node.parameters`` and attach it as 'indexes'.

     A size token is read first, then that many int32 values. One is
     subtracted from every value. The result is embedded via ``embed_input``
     as input 1 named 'indexes', and the node is updated as a Permute whose
     'infer' is ``copy_shape_infer``.

     :param node: node whose ``parameters`` stream is consumed.
     :return: the enclosing class's ``enabled`` attribute.
     """
     stream = node.parameters
     blob_len = read_binary_integer32_token(stream)
     # NOTE(review): the -1 presumably converts 1-based stored indexes to
     # 0-based ones — confirm against the file format.
     zero_based = read_blob(stream, blob_len, dtype=np.int32) - 1
     node_attrs = dict(infer=copy_shape_infer)
     embed_input(node_attrs, 1, 'indexes', zero_based)
     Permute.update_node_stat(node, node_attrs)
     return __class__.enabled
Exemplo n.º 3
0
    def extract(node):
        """Translate an ONNX Transpose node's attributes into Permute attributes.

        The optional 'perm' attribute is read as a list of ints. When it is
        present, it becomes the int64 'order' array. When it is absent,
        'order' is None and 'reverse_order' is True, because ONNX Transpose
        reverses the dimensions by default.

        :param node: node whose ONNX attributes are read and then updated.
        :return: the enclosing class's ``enabled`` attribute.
        """
        perm = onnx_attr(node, 'perm', 'ints', default=None)
        perm_missing = perm is None
        attrs = {
            'order': None if perm_missing else np.array(perm, dtype=np.int64),
            'reverse_order': perm_missing,
        }
        Permute.update_node_stat(node, attrs)
        return __class__.enabled