Beispiel #1
0
 def test_spatial_3d_split_concat_2(self):
     """Check the pass on the 3D spatial fixture with a 3-way Split/Concat on axis 3.

     The input graph is the 3D spatial fixture with the split, axis constant,
     concat and data nodes overridden below. After running
     SplitConcatPairToInterpolate, the graph must compare equal to the
     ``ref_graph_node_attrs_for_3d_spatial_case_2`` reference.
     """
     def axis_scalar():
         # A fresh 0-d int64 array per use, so nodes never share one array object.
         return np.array(3, dtype=np.int64)

     def data_node(*dims):
         # Data node with no constant value and the given int64 shape.
         return {'value': None, 'shape': int64_array(list(dims)), 'kind': 'data'}

     overrides = {
         'split': {'type': 'Split', 'kind': 'op', 'op': 'Split', 'num_splits': 3},
         'split_axis_const': {
             'kind': 'op',
             'value': axis_scalar(),
             'op': 'Const',
             'type': 'Const',
         },
         'split_axis_const_data': {
             'value': axis_scalar(),
             'shape': axis_scalar().shape,
             'kind': 'data',
         },
         'concat': {'type': 'Concat', 'kind': 'op', 'axis': 3},
         'split_data_0': data_node(1, 3, 100, 40, 150),
         'split_data_1': data_node(1, 3, 100, 40, 150),
         'split_data_2': data_node(1, 3, 100, 40, 150),
         'concat_data': data_node(1, 3, 100, 240, 150),
         'abs_data': data_node(1, 3, 100, 240, 150),
     }
     actual = build_graph(
         nodes_attrs=graph_node_attrs_for_3d_spatial_case,
         edges=graph_edges,
         update_attributes=overrides,
     )
     expected = build_graph(
         nodes_attrs=ref_graph_node_attrs_for_3d_spatial_case_2,
         edges=ref_graph_edges_opset4,
     )

     transformation = SplitConcatPairToInterpolate()
     transformation.find_and_replace_pattern(actual)

     # compare_graphs yields a (bool, message) pair; the message is shown on failure.
     matched, details = compare_graphs(actual, expected, 'output')
     self.assertTrue(matched, details)
 def test_spatial_2d_split_concat_1(self):
     """Check the pass on the unmodified 2D spatial fixture.

     After SplitConcatPairToInterpolate runs, the graph must compare equal to
     the ``ref_graph_node_attrs_for_2d_spatial_case_1_opset4`` reference.
     """
     actual = build_graph(
         nodes_attrs=graph_node_attrs_for_2d_spatial_case,
         edges=graph_edges,
     )
     expected = build_graph(
         nodes_attrs=ref_graph_node_attrs_for_2d_spatial_case_1_opset4,
         edges=ref_graph_edges_opset4,
     )

     transformation = SplitConcatPairToInterpolate()
     transformation.find_and_replace_pattern(actual)

     # compare_graphs yields a (bool, message) pair; the message is shown on failure.
     matched, details = compare_graphs(actual, expected, 'output')
     self.assertTrue(matched, details)
 def test_two_splits_one_concat(self):
     """Check the pass on a graph with two splits feeding one concat.

     The reference graph is built from exactly the same node attributes and
     edges as the input, so the test requires the graph after the pass to
     compare equal to a freshly built copy of itself.
     """
     # Two separate build_graph calls: the reference must not share objects
     # with the graph handed to the transformation.
     actual = build_graph(
         nodes_attrs=graph_node_attrs_when_there_are_two_splits_one_concat,
         edges=graph_edges_when_there_are_two_splits_one_concat,
     )
     expected = build_graph(
         nodes_attrs=graph_node_attrs_when_there_are_two_splits_one_concat,
         edges=graph_edges_when_there_are_two_splits_one_concat,
     )

     transformation = SplitConcatPairToInterpolate()
     transformation.find_and_replace_pattern(actual)

     matched, details = compare_graphs(actual, expected, 'output')
     self.assertTrue(matched, details)
Beispiel #4
0
 def test_spatial_3d_split_concat_1(self):
     """Check the pass on the unmodified 3D spatial fixture.

     The reference is ``ref_graph_node_attrs_for_3d_spatial_case_1`` with its
     'axes' node overridden to shape [1] and value [4]. After
     SplitConcatPairToInterpolate runs, the graph must compare equal to it.
     """
     actual = build_graph(
         nodes_attrs=graph_node_attrs_for_3d_spatial_case,
         edges=graph_edges,
     )

     reference_overrides = {
         'axes': {'shape': int64_array([1]), 'value': int64_array([4])},
     }
     expected = build_graph(
         nodes_attrs=ref_graph_node_attrs_for_3d_spatial_case_1,
         edges=ref_graph_edges_opset4,
         update_attributes=reference_overrides,
     )

     transformation = SplitConcatPairToInterpolate()
     transformation.find_and_replace_pattern(actual)

     # compare_graphs yields a (bool, message) pair; the message is shown on failure.
     matched, details = compare_graphs(actual, expected, 'output')
     self.assertTrue(matched, details)