def test_split_reverse_infer(self):
    """Reverse inference must rebuild the input shape from the output shapes.

    The input data node starts with no shape, and each of the two outputs
    has one dynamic dimension in a different position. After
    ``AttributedSplit.reverse_infer`` the shape read from input port 0 is
    expected to equal ``[7, 4, 6]``.
    """
    expected_input_shape = [7, 4, 6]
    split_axis = 2
    parts = 2
    first_out_shape = [dynamic_dimension, 4, 3]
    second_out_shape = [7, dynamic_dimension, 3]

    attr_overrides = {
        'split_input_data': {'shape': None, 'value': None},
        'split_op': {
            'axis': np.array(split_axis),
            'num_splits': np.array(parts),
        },
        'split_output_0_data': {
            'shape': shape_array(first_out_shape),
            'value': None,
        },
        'split_output_1_data': {
            'shape': shape_array(second_out_shape),
            'value': None,
        },
    }
    graph = build_graph(TestSplitOp.nodes, TestSplitOp.edges, attr_overrides)

    node = Node(graph, 'split_op')
    AttributedSplit.reverse_infer(node)

    inferred_shape = node.in_port(0).data.get_shape()
    self.assertTrue(
        strict_compare_tensors(expected_input_shape, inferred_shape))
def extract(cls, node):
    """Copy the split parameters of ``node.module`` onto the node.

    ``num_splits`` is taken from ``module.num_splits`` and ``axis`` from
    ``module.dim``; both are stored through
    ``AttributedSplit.update_node_stat``.

    Returns:
        The value of ``cls.enabled``.
    """
    module = node.module
    AttributedSplit.update_node_stat(node, {
        'num_splits': module.num_splits,
        'axis': module.dim,
    })
    return cls.enabled
def extract(cls, node: Node):
    """Populate split attributes from the node's protobuf description.

    The integer fields ``attr['axis'].i`` and ``attr['num'].i`` of
    ``node.pb`` become ``axis`` and ``num_splits``; ``squeeze_axis`` is
    always set to ``True``.

    Returns:
        The value of ``cls.enabled``.
    """
    proto_attrs = node.pb.attr
    split_attrs = {
        'axis': proto_attrs['axis'].i,
        'num_splits': proto_attrs['num'].i,
        'squeeze_axis': True,
    }
    AttributedSplit.update_node_stat(node, split_attrs)
    return cls.enabled
def test_split_dynamic_shape_infer(self):
    """Forward inference on an input whose split axis is dynamic.

    Splits a ``[2, dynamic]`` input with no value into two parts along
    axis 1 and expects both outputs to have shape ``[2, dynamic]`` and no
    value.
    """
    # test configuration
    input_shape = [2, dynamic_dimension_value]
    input_value = None
    axis = 1
    num_splits = 2
    output_shape = [2, dynamic_dimension_value]
    output_value = [None, None]

    def input_and_op_attrs():
        # Build fresh attribute objects on every call so the graph under
        # test and the reference graph never share arrays.
        return {
            'split_input_data': {
                'shape': shape_array(input_shape),
                'value': input_value,
            },
            'split_op': {
                'axis': np.array(axis),
                'num_splits': np.array(num_splits),
            },
        }

    # action
    graph = build_graph(self.nodes, self.edges, input_and_op_attrs())
    AttributedSplit.infer(Node(graph, 'split_op'))

    # reference
    ref_attrs = input_and_op_attrs()
    for idx, expected_value in enumerate(output_value):
        ref_attrs['split_output_{}_data'.format(idx)] = {
            'shape': shape_array(output_shape),
            'value': expected_value,
        }
    graph_ref = build_graph(self.nodes, self.edges, ref_attrs)

    # check
    is_equal, message = compare_graphs(graph, graph_ref, 'split_input_data')
    self.assertTrue(is_equal, message)

    first_output_shape = Node(graph, 'split_output_0_data').shape
    self.assertTrue(
        strict_compare_tensors(first_output_shape, shape_array(output_shape)))
def test_split_value_infer(self):
    """Forward inference must propagate constant values through the split.

    A ``[2, 10]`` input holding the values 0..19 is split into two parts
    along axis 1; each output is expected to have shape ``[2, 5]`` and to
    hold the matching half of every row.
    """
    # test configuration
    input_shape = [2, 10]
    input_value = [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
                   [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]]
    axis = 1
    num_splits = 2
    output_shape = [2, 5]
    output_value = [
        [[0, 1, 2, 3, 4], [10, 11, 12, 13, 14]],
        [[5, 6, 7, 8, 9], [15, 16, 17, 18, 19]],
    ]

    def input_and_op_attrs():
        # Build fresh attribute objects on every call so the graph under
        # test and the reference graph never share arrays.
        return {
            'split_input_data': {
                'shape': int64_array(input_shape),
                'value': int64_array(input_value),
            },
            'split_op': {
                'axis': np.array(axis),
                'num_splits': np.array(num_splits),
            },
        }

    # action
    graph = build_graph(self.nodes, self.edges, input_and_op_attrs())
    AttributedSplit.infer(Node(graph, 'split_op'))

    # reference
    ref_attrs = input_and_op_attrs()
    for idx, expected_value in enumerate(output_value):
        ref_attrs['split_output_{}_data'.format(idx)] = {
            'shape': int64_array(output_shape),
            'value': int64_array(expected_value),
        }
    graph_ref = build_graph(self.nodes, self.edges, ref_attrs)

    # check
    is_equal, message = compare_graphs(graph, graph_ref, 'split_input_data')
    self.assertTrue(is_equal, message)
def extract(cls, node):
    """Fill split attributes from the node's ``symbol_dict``.

    Values are read through ``get_mxnet_layer_attrs``: ``axis`` (default
    1), ``num_outputs`` (default 0, stored as ``num_splits``) and
    ``squeeze_axis`` (default ``False``).

    Returns:
        The value of ``cls.enabled``.
    """
    layer_params = get_mxnet_layer_attrs(node.symbol_dict)

    split_axis = layer_params.int("axis", 1)
    outputs_count = layer_params.int("num_outputs", 0)
    squeeze = layer_params.bool('squeeze_axis', False)

    # update the attributes of the node
    AttributedSplit.update_node_stat(node, {
        'axis': split_axis,
        'squeeze_axis': squeeze,
        'num_splits': outputs_count,
    })
    return cls.enabled
def extract(cls, node):
    """Configure the node as an even or a variadic split.

    Reads the ``axis`` attribute (default 0) and the optional ``split``
    attribute. When ``split`` is present the node is updated through
    ``AttributedVariadicSplit`` with those sizes; otherwise it is updated
    through ``AttributedSplit`` with ``num_splits`` taken from
    ``onnx_get_num_outputs(node)``.

    Returns:
        The value of ``cls.enabled``.
    """
    axis = onnx_attr(node, 'axis', 'i', default=0, dst_type=np.int64)
    size_splits = onnx_attr(
        node, 'split', 'ints', default=None, dst_type=int64_array)

    if size_splits is not None:
        # Explicit part sizes were given: use the variadic form.
        AttributedVariadicSplit.update_node_stat(node, {
            'axis': axis,
            'size_splits': size_splits,
        })
        return cls.enabled

    # No explicit sizes: split evenly, one part per output.
    AttributedSplit.update_node_stat(node, {
        'axis': axis,
        'num_splits': onnx_get_num_outputs(node),
    })
    return cls.enabled