def test_spatial_3d_split_concat_2(self):
    """Split/Concat pair on axis 3 of a 5D tensor is rewritten as in reference case 2.

    The source graph's Split node has three outputs, each of shape
    [1, 3, 100, 40, 150]. The Split axis is fed from a Const holding 3, and
    Concat also works on axis 3, producing [1, 3, 100, 240, 150]. After
    SplitConcatPairToInterpolate runs, the graph must match the reference
    built from ref_graph_node_attrs_for_3d_spatial_case_2 with
    ref_graph_edges_opset4.
    """
    split_axis = 3
    attribute_overrides = {
        'split': {'type': 'Split', 'kind': 'op', 'op': 'Split', 'num_splits': 3},
        'split_axis_const': {
            'kind': 'op',
            'value': np.array(split_axis, dtype=np.int64),
            'op': 'Const',
            'type': 'Const',
        },
        'split_axis_const_data': {
            'value': np.array(split_axis, dtype=np.int64),
            'shape': np.array(split_axis, dtype=np.int64).shape,
            'kind': 'data',
        },
        'concat': {'type': 'Concat', 'kind': 'op', 'axis': split_axis},
        # The three Split outputs share the same (per-key freshly built) shape.
        **{
            'split_data_{}'.format(port): {
                'value': None,
                'shape': int64_array([1, 3, 100, 40, 150]),
                'kind': 'data',
            }
            for port in range(3)
        },
        'concat_data': {'value': None, 'shape': int64_array([1, 3, 100, 240, 150]), 'kind': 'data'},
        'abs_data': {'value': None, 'shape': int64_array([1, 3, 100, 240, 150]), 'kind': 'data'},
    }
    source = build_graph(
        nodes_attrs=graph_node_attrs_for_3d_spatial_case,
        edges=graph_edges,
        update_attributes=attribute_overrides,
    )
    expected = build_graph(
        nodes_attrs=ref_graph_node_attrs_for_3d_spatial_case_2,
        edges=ref_graph_edges_opset4,
    )

    transformation = SplitConcatPairToInterpolate()
    transformation.find_and_replace_pattern(source)

    is_equal, message = compare_graphs(source, expected, 'output')
    self.assertTrue(is_equal, message)
def test_spatial_2d_split_concat_1(self):
    """Split/Concat pair in the 2D spatial fixture is rewritten as in reference case 1.

    After SplitConcatPairToInterpolate runs on the graph built from
    graph_node_attrs_for_2d_spatial_case, the graph must match the reference
    built from ref_graph_node_attrs_for_2d_spatial_case_1_opset4 with
    ref_graph_edges_opset4.
    """
    source = build_graph(
        nodes_attrs=graph_node_attrs_for_2d_spatial_case,
        edges=graph_edges,
    )
    expected = build_graph(
        nodes_attrs=ref_graph_node_attrs_for_2d_spatial_case_1_opset4,
        edges=ref_graph_edges_opset4,
    )

    transformation = SplitConcatPairToInterpolate()
    transformation.find_and_replace_pattern(source)

    is_equal, message = compare_graphs(source, expected, 'output')
    self.assertTrue(is_equal, message)
def test_two_splits_one_concat(self):
    """A graph with two Splits feeding one Concat is left untouched.

    The graph under test and the reference are built from identical
    node attributes and edges, so the assertion passes only if
    SplitConcatPairToInterpolate makes no change to this pattern.
    """
    # Build the graph under test first, then the reference (same order as before).
    source, expected = [
        build_graph(
            nodes_attrs=graph_node_attrs_when_there_are_two_splits_one_concat,
            edges=graph_edges_when_there_are_two_splits_one_concat,
        )
        for _ in range(2)
    ]

    transformation = SplitConcatPairToInterpolate()
    transformation.find_and_replace_pattern(source)

    is_equal, message = compare_graphs(source, expected, 'output')
    self.assertTrue(is_equal, message)
def test_spatial_3d_split_concat_1(self):
    """Split/Concat pair in the default 3D spatial fixture is rewritten as in reference case 1.

    The reference graph is built from ref_graph_node_attrs_for_3d_spatial_case_1
    with its 'axes' node overridden to value [4] and shape [1]. After
    SplitConcatPairToInterpolate runs, the source graph must match it.
    """
    source = build_graph(
        nodes_attrs=graph_node_attrs_for_3d_spatial_case,
        edges=graph_edges,
    )
    axes_override = {
        'axes': {'shape': int64_array([1]), 'value': int64_array([4])},
    }
    expected = build_graph(
        nodes_attrs=ref_graph_node_attrs_for_3d_spatial_case_1,
        edges=ref_graph_edges_opset4,
        update_attributes=axes_override,
    )

    transformation = SplitConcatPairToInterpolate()
    transformation.find_and_replace_pattern(source)

    is_equal, message = compare_graphs(source, expected, 'output')
    self.assertTrue(is_equal, message)