def extract(cls, node):
    """Fill AttributedSplit attributes from the node's wrapped module.

    ``node.module.num_splits`` is stored as ``num_splits`` and
    ``node.module.dim`` is stored as ``axis``.

    Returns:
        ``cls.enabled``.
    """
    module = node.module
    AttributedSplit.update_node_stat(node, {
        'num_splits': module.num_splits,
        'axis': module.dim,
    })
    return cls.enabled
def extract(cls, node: Node):
    """Fill AttributedSplit attributes from the node's protobuf attribute map.

    The integer attributes ``axis`` and ``num`` of ``node.pb.attr`` become
    ``axis`` and ``num_splits``; ``squeeze_axis`` is always set to True
    (presumably because each output drops the split dimension — TODO confirm
    against the op this extractor is registered for).

    Returns:
        ``cls.enabled``.
    """
    proto_attrs = node.pb.attr
    split_attrs = dict(
        axis=proto_attrs['axis'].i,
        num_splits=proto_attrs['num'].i,
        squeeze_axis=True,
    )
    AttributedSplit.update_node_stat(node, split_attrs)
    return cls.enabled
def test_split_dynamic_shape_infer(self):
    """Split a ``[2, <dynamic>]`` input along axis 1 into two parts.

    The input has no value, so both outputs are expected to have shape
    ``[2, <dynamic>]`` and no value.
    """
    # test configuration
    input_shape = [2, dynamic_dimension_value]
    input_value = None
    axis = 1
    num_splits = 2
    output_shape = [2, dynamic_dimension_value]
    output_value = [None, None]

    # action
    graph = build_graph(
        self.nodes, self.edges,
        {
            'split_input_data': {
                'shape': shape_array(input_shape),
                'value': input_value
            },
            'split_op': {
                'axis': np.array(axis),
                'num_splits': np.array(num_splits)
            },
        })
    split_op = Node(graph, 'split_op')
    AttributedSplit.infer(split_op)

    # reference
    graph_ref = build_graph(
        self.nodes, self.edges,
        {
            'split_input_data': {
                'shape': shape_array(input_shape),
                'value': input_value
            },
            'split_op': {
                'axis': np.array(axis),
                'num_splits': np.array(num_splits)
            },
            'split_output_0_data': {
                'shape': shape_array(output_shape),
                'value': output_value[0]
            },
            'split_output_1_data': {
                'shape': shape_array(output_shape),
                'value': output_value[1]
            },
        })

    # check
    (flag, resp) = compare_graphs(graph, graph_ref, 'split_input_data')
    self.assertTrue(flag, resp)
    # An explicit strict shape check is added on top of compare_graphs,
    # presumably because compare_graphs may not distinguish dynamic dims.
    # Apply it to every output, not only the first one.
    for output_name in ('split_output_0_data', 'split_output_1_data'):
        self.assertTrue(
            strict_compare_tensors(
                Node(graph, output_name).shape,
                shape_array(output_shape)))
def test_split_value_infer(self):
    """Split a constant ``[2, 10]`` tensor along axis 1 into two halves.

    Checks that inference yields two ``[2, 5]`` outputs whose values are the
    left and right halves of each input row.
    """
    # test configuration
    in_shape = [2, 10]
    in_value = [list(range(10)), list(range(10, 20))]
    split_axis = 1
    parts = 2
    half = 5
    out_shape = [2, 5]
    left = [row[:half] for row in in_value]
    right = [row[half:] for row in in_value]

    def make_updates():
        # Fresh dicts/arrays on every call so the two graphs share no objects.
        return {
            'split_input_data': {
                'shape': int64_array(in_shape),
                'value': int64_array(in_value)
            },
            'split_op': {
                'axis': np.array(split_axis),
                'num_splits': np.array(parts)
            },
        }

    # action
    graph = build_graph(self.nodes, self.edges, make_updates())
    AttributedSplit.infer(Node(graph, 'split_op'))

    # reference
    ref_updates = make_updates()
    ref_updates['split_output_0_data'] = {
        'shape': int64_array(out_shape),
        'value': int64_array(left)
    }
    ref_updates['split_output_1_data'] = {
        'shape': int64_array(out_shape),
        'value': int64_array(right)
    }
    graph_ref = build_graph(self.nodes, self.edges, ref_updates)

    # check
    flag, resp = compare_graphs(graph, graph_ref, 'split_input_data')
    self.assertTrue(flag, resp)
def extract(cls, node):
    """Fill split attributes from the node's ``axis`` and ``split`` attributes.

    ``axis`` defaults to 0. If a ``split`` ints attribute is present, the node
    becomes an AttributedVariadicSplit with those ``size_splits``. Otherwise
    it becomes an AttributedSplit whose ``num_splits`` equals the node's
    output count, as reported by ``onnx_get_num_outputs``.

    Returns:
        ``cls.enabled``.
    """
    split_axis = onnx_attr(node, 'axis', 'i', default=0, dst_type=np.int64)
    explicit_sizes = onnx_attr(node, 'split', 'ints', default=None, dst_type=int64_array)

    if explicit_sizes is not None:
        AttributedVariadicSplit.update_node_stat(
            node, {'axis': split_axis, 'size_splits': explicit_sizes})
    else:
        AttributedSplit.update_node_stat(
            node, {'axis': split_axis, 'num_splits': onnx_get_num_outputs(node)})
    return cls.enabled
def extract(cls, node):
    """Fill AttributedSplit attributes from the mxnet layer attributes.

    Reads ``axis`` (default 1), ``num_outputs`` (default 0, stored as
    ``num_splits``) and ``squeeze_axis`` (default False) from the attributes
    parsed out of ``node.symbol_dict``.

    Returns:
        ``cls.enabled``.
    """
    layer_attrs = get_mxnet_layer_attrs(node.symbol_dict)
    split_axis = layer_attrs.int("axis", 1)
    output_count = layer_attrs.int("num_outputs", 0)
    squeeze = layer_attrs.bool('squeeze_axis', False)

    AttributedSplit.update_node_stat(node, {
        'axis': split_axis,
        'squeeze_axis': squeeze,
        'num_splits': output_count,
    })
    return cls.enabled