コード例 #1
0
 def test_4d(self):
     graph = build_graph(nodes_attrs=graph_node_attrs,
                         edges=graph_edges,
                         update_attributes={
                             'placeholder_data': {
                                 'shape': int64_array([1, 8, 32, 32])
                             },
                             'unsqueeze_data': {
                                 'shape': int64_array([1, 8, 1, 32, 32])
                             },
                             'multipliers': {
                                 'value': int64_array([1, 1, 2, 1, 1]),
                                 'shape': int64_array([5])
                             },
                             'multipliers_data': {
                                 'value': int64_array([1, 1, 2, 1, 1]),
                                 'shape': int64_array([5])
                             },
                             'tile_data': {
                                 'shape': int64_array([1, 8, 2, 32, 32])
                             },
                             'reshape_data': {
                                 'shape': int64_array([1, 16, 32, 32]),
                                 'value': None
                             },
                             'shape': {
                                 'value': int64_array([1, 16, 32, 32]),
                                 'shape': int64_array([4])
                             },
                             'shape_data': {
                                 'value': int64_array([1, 16, 32, 32]),
                                 'shape': int64_array([4])
                             },
                             'abs_data': {
                                 'shape': int64_array([1, 16, 32, 32])
                             },
                         })
     ref_graph = build_graph(
         nodes_attrs=ref_graph_node_attrs_with_4_inputs_interpolate,
         edges=ref_graph_edges_attrs_with_4_inputs_interpolate,
         update_attributes={
             'placeholder_data': {
                 'shape': int64_array([1, 8, 32, 32])
             },
             'interpolate_data': {
                 'shape': int64_array([1, 16, 32, 32])
             },
             'abs_data': {
                 'shape': int64_array([1, 16, 32, 32])
             },
             'axes': {
                 'shape': int64_array([1]),
                 'value': int64_array([1])
             },
         })
     UnsqueezeTileReshapeBlockToInterpolate().find_and_replace_pattern(
         graph)
     (flag, resp) = compare_graphs(graph, ref_graph, 'output')
     self.assertTrue(flag, resp)
コード例 #2
0
 def test_5d(self):
     graph = build_graph(nodes_attrs=graph_node_attrs, edges=graph_edges)
     ref_graph = build_graph(
         nodes_attrs=ref_graph_node_attrs_with_4_inputs_interpolate,
         edges=ref_graph_edges_attrs_with_4_inputs_interpolate)
     UnsqueezeTileReshapeBlockToInterpolate().find_and_replace_pattern(
         graph)
     (flag, resp) = compare_graphs(graph, ref_graph, 'output')
     self.assertTrue(flag, resp)
コード例 #3
0
 def test_3d(self):
     graph = build_graph(
         nodes_attrs=graph_node_attrs_when_transformation_is_not_applicable,
         edges=graph_edges_when_transformation_is_not_applicable)
     ref_graph = build_graph(
         nodes_attrs=graph_node_attrs_when_transformation_is_not_applicable,
         edges=graph_edges_when_transformation_is_not_applicable)
     UnsqueezeTileReshapeBlockToInterpolate().find_and_replace_pattern(
         graph)
     (flag, resp) = compare_graphs(graph, ref_graph, 'output')
     self.assertTrue(flag, resp)
コード例 #4
0
 def test_2d(self):
     graph = build_graph(
         nodes_attrs=graph_node_attrs_when_transformation_is_not_applicable,
         edges=graph_edges_when_transformation_is_not_applicable,
         update_attributes={
             'placeholder_data': {
                 'shape': int64_array([5, 8])
             },
             'dim': {
                 'value': int64_array([1])
             },
             'dim_data': {
                 'value': int64_array([1])
             },
             'unsqueeze_data': {
                 'shape': int64_array([5, 1, 8])
             },
             'multipliers': {
                 'value': int64_array([1, 10, 1])
             },
             'multipliers_data': {
                 'value': int64_array([1, 10, 1]),
                 'shape': int64_array([3])
             },
             'tile_data': {
                 'shape': int64_array([5, 10, 8])
             },
             'reshape_data': {
                 'shape': int64_array([50, 8])
             },
             'shape': {
                 'value': int64_array([50, 8]),
                 'shape': int64_array([2])
             },
             'shape_data': {
                 'value': int64_array([50, 8]),
                 'shape': int64_array([2])
             },
             'abs_data': {
                 'shape': int64_array([50, 8])
             },
         })
     ref_graph = build_graph(
         nodes_attrs=graph_node_attrs_when_transformation_is_not_applicable,
         edges=graph_edges_when_transformation_is_not_applicable,
         update_attributes={
             'placeholder_data': {
                 'shape': int64_array([5, 8])
             },
             'dim': {
                 'value': int64_array([1])
             },
             'dim_data': {
                 'value': int64_array([1])
             },
             'unsqueeze_data': {
                 'shape': int64_array([5, 1, 8])
             },
             'multipliers': {
                 'value': int64_array([1, 10, 1])
             },
             'multipliers_data': {
                 'value': int64_array([1, 10, 1]),
                 'shape': int64_array([3])
             },
             'tile_data': {
                 'shape': int64_array([5, 10, 8])
             },
             'reshape_data': {
                 'shape': int64_array([50, 8])
             },
             'shape': {
                 'value': int64_array([50, 8]),
                 'shape': int64_array([2])
             },
             'shape_data': {
                 'value': int64_array([50, 8]),
                 'shape': int64_array([2])
             },
             'abs_data': {
                 'shape': int64_array([50, 8])
             },
         })
     UnsqueezeTileReshapeBlockToInterpolate().find_and_replace_pattern(
         graph)
     (flag, resp) = compare_graphs(graph, ref_graph, 'output')
     self.assertTrue(flag, resp)