def test_4d(self):
    """Run the transformation on a graph with a [1, 8, 32, 32] placeholder.

    The source graph is built from ``graph_node_attrs`` / ``graph_edges``
    with shapes and values overridden for a 4D input.  The expected graph
    is built from the ``*_with_4_inputs_interpolate`` reference fixtures.
    The test passes when ``compare_graphs`` reports both graphs as equal
    starting from the ``'output'`` node.
    """
    # Each array is created separately on purpose: no array object is
    # shared between nodes, exactly as in a literal-only definition.
    source_overrides = {
        'placeholder_data': {'shape': int64_array([1, 8, 32, 32])},
        'unsqueeze_data': {'shape': int64_array([1, 8, 1, 32, 32])},
        'multipliers': {
            'value': int64_array([1, 1, 2, 1, 1]),
            'shape': int64_array([5]),
        },
        'multipliers_data': {
            'value': int64_array([1, 1, 2, 1, 1]),
            'shape': int64_array([5]),
        },
        'tile_data': {'shape': int64_array([1, 8, 2, 32, 32])},
        'reshape_data': {
            'shape': int64_array([1, 16, 32, 32]),
            'value': None,
        },
        'shape': {
            'value': int64_array([1, 16, 32, 32]),
            'shape': int64_array([4]),
        },
        'shape_data': {
            'value': int64_array([1, 16, 32, 32]),
            'shape': int64_array([4]),
        },
        'abs_data': {'shape': int64_array([1, 16, 32, 32])},
    }
    actual_graph = build_graph(
        nodes_attrs=graph_node_attrs,
        edges=graph_edges,
        update_attributes=source_overrides,
    )

    reference_overrides = {
        'placeholder_data': {'shape': int64_array([1, 8, 32, 32])},
        'interpolate_data': {'shape': int64_array([1, 16, 32, 32])},
        'abs_data': {'shape': int64_array([1, 16, 32, 32])},
        'axes': {
            'shape': int64_array([1]),
            'value': int64_array([1]),
        },
    }
    expected_graph = build_graph(
        nodes_attrs=ref_graph_node_attrs_with_4_inputs_interpolate,
        edges=ref_graph_edges_attrs_with_4_inputs_interpolate,
        update_attributes=reference_overrides,
    )

    transformation = UnsqueezeTileReshapeBlockToInterpolate()
    transformation.find_and_replace_pattern(actual_graph)

    graphs_match, mismatch_report = compare_graphs(
        actual_graph, expected_graph, 'output')
    self.assertTrue(graphs_match, mismatch_report)
def test_5d(self):
    """Run the transformation on the unmodified default fixture graph.

    The source graph uses ``graph_node_attrs`` / ``graph_edges`` as-is and
    the expected graph uses ``ref_graph_node_attrs`` / ``ref_graph_edges``
    as-is.  NOTE(review): the fixtures are defined outside this chunk;
    going by the test name they presumably describe a 5D input - confirm.
    """
    actual_graph = build_graph(
        nodes_attrs=graph_node_attrs,
        edges=graph_edges,
    )
    expected_graph = build_graph(
        nodes_attrs=ref_graph_node_attrs,
        edges=ref_graph_edges,
    )

    transformation = UnsqueezeTileReshapeBlockToInterpolate()
    transformation.find_and_replace_pattern(actual_graph)

    graphs_match, mismatch_report = compare_graphs(
        actual_graph, expected_graph, 'output')
    self.assertTrue(graphs_match, mismatch_report)
def test_3d(self):
    """Check that the "not applicable" fixture graph is left unchanged.

    Both the graph under test and the expected graph are built from the
    same ``*_when_transformation_is_not_applicable`` fixtures, so the test
    passes only if the transformed graph still compares equal to an
    untouched copy.
    """
    # Two independent builds from identical inputs: the first is
    # transformed, the second serves as the reference.
    actual_graph, expected_graph = [
        build_graph(
            nodes_attrs=graph_node_attrs_when_transformation_is_not_applicable,
            edges=graph_edges_when_transformation_is_not_applicable,
        )
        for _ in range(2)
    ]

    transformation = UnsqueezeTileReshapeBlockToInterpolate()
    transformation.find_and_replace_pattern(actual_graph)

    graphs_match, mismatch_report = compare_graphs(
        actual_graph, expected_graph, 'output')
    self.assertTrue(graphs_match, mismatch_report)
def test_2d(self):
    """Check that a graph with a [5, 8] placeholder is left unchanged.

    The graph under test and the expected graph are built from the same
    ``*_when_transformation_is_not_applicable`` fixtures with the same
    attribute overrides, so the test passes only if the transformed graph
    still compares equal to an untouched copy.
    """

    def make_overrides():
        # Return a fresh dict with fresh arrays on every call so that the
        # two graphs never share attribute objects.
        return {
            'placeholder_data': {'shape': int64_array([5, 8])},
            'dim': {'value': int64_array([1])},
            'dim_data': {'value': int64_array([1])},
            'unsqueeze_data': {'shape': int64_array([5, 1, 8])},
            'multipliers': {'value': int64_array([1, 10, 1])},
            'multipliers_data': {
                'value': int64_array([1, 10, 1]),
                'shape': int64_array([3]),
            },
            'tile_data': {'shape': int64_array([5, 10, 8])},
            'reshape_data': {'shape': int64_array([50, 8])},
            'shape': {
                'value': int64_array([50, 8]),
                'shape': int64_array([2]),
            },
            'shape_data': {
                'value': int64_array([50, 8]),
                'shape': int64_array([2]),
            },
            'abs_data': {'shape': int64_array([50, 8])},
        }

    def make_graph():
        return build_graph(
            nodes_attrs=graph_node_attrs_when_transformation_is_not_applicable,
            edges=graph_edges_when_transformation_is_not_applicable,
            update_attributes=make_overrides(),
        )

    actual_graph = make_graph()
    expected_graph = make_graph()

    transformation = UnsqueezeTileReshapeBlockToInterpolate()
    transformation.find_and_replace_pattern(actual_graph)

    graphs_match, mismatch_report = compare_graphs(
        actual_graph, expected_graph, 'output')
    self.assertTrue(graphs_match, mismatch_report)