示例#1
0
    def extract(cls, node):
        """Populate *node* as an ``AttributedGather`` using its ONNX ``axis`` attribute.

        The ``axis`` value is looked up through ``onnx_attr`` with type code
        ``'i'`` and a default of 0, then wrapped in a NumPy array of dtype
        ``np.int64`` before being stored on the node.

        :param node: graph node to update in place (presumably wraps an ONNX
            Gather node — confirm against the caller).
        :return: ``cls.enabled``.
        """
        # NOTE(review): 'i' presumably selects the integer field of the ONNX
        # attribute proto; confirm against onnx_attr's implementation.
        raw_axis = onnx_attr(node, 'axis', 'i', default=0)
        gather_attrs = dict(axis=np.array(raw_axis, dtype=np.int64))
        AttributedGather.update_node_stat(node, gather_attrs)
        return cls.enabled
示例#2
0
    def find_and_replace_pattern(self, graph: Graph):
        """Replace every ``Gather`` op with an ``AttributedGather`` carrying ``axis``.

        For each ``Gather`` node, the value available on input port 2 is read
        and stored as the ``axis`` attribute of a newly created
        ``AttributedGather`` node with the same name. The original node's
        output consumers are moved to the new node, and its input ports 0 and 1
        are re-wired onto the new node's ports 0 and 1.

        :param graph: graph to transform in place.
        :raises AssertionError: if a ``Gather`` node has no connected input
            port 2, already has a valid ``axis`` attribute, or the value on
            port 2 is not known (presumably not a constant — confirm).
        """
        for node in graph.get_op_nodes(op='Gather'):
            name = node.soft_get('name', node.id)

            # Explicit checks instead of `assert` so validation is not stripped
            # under `python -O`; AssertionError is kept so callers see the same
            # exception type as before.
            if 2 not in node.in_ports() or node.in_port(2).disconnected():
                raise AssertionError(
                    'Gather node "{}" must have a connected input port 2 (axis)'.format(name))
            if node.has_valid('axis'):
                raise AssertionError(
                    'Gather node "{}" is not expected to have an "axis" attribute'.format(name))

            axis = node.in_port(2).data.get_value()
            if axis is None:
                raise AssertionError(
                    'Value of the axis input (port 2) of Gather node "{}" is unknown'.format(name))

            attributed_gather = AttributedGather(graph, {
                'axis': axis,
                'name': name
            }).create_node()

            node.out_port(0).get_connection().set_source(
                attributed_gather.out_port(0))
            node.in_port(0).get_connection().set_destination(
                attributed_gather.in_port(0))
            node.in_port(1).get_connection().set_destination(
                attributed_gather.in_port(1))

            # shape inference (before cleaning this node up) will fail due to disconnected input ports
            node['need_shape_inference'] = False
示例#3
0
 def extract(cls, node):
     """Populate *node* as an ``AttributedGather`` whose axis is always 0.

     No attribute is read from the source node; the axis is fixed at 0.

     :param node: graph node to update in place.
     :return: ``cls.enabled``.
     """
     AttributedGather.update_node_stat(node, dict(axis=0))
     return cls.enabled
 def extract(cls, node: Node):
     """Populate *node* as an ``AttributedGather`` from its MXNet layer attributes.

     The layer attributes are parsed from ``node.symbol_dict`` with
     ``get_mxnet_layer_attrs``. The ``axis`` entry is read as an integer with
     ``0`` passed as the second argument (presumably the fallback when the
     attribute is absent — confirm).

     :param node: graph node to update in place.
     :return: ``cls.enabled``.
     """
     layer_attrs = get_mxnet_layer_attrs(node.symbol_dict)
     gather_axis = layer_attrs.int('axis', 0)
     AttributedGather.update_node_stat(node, dict(axis=gather_axis))
     return cls.enabled