def test_1(self):
    """
    Testing case with non-constant path and multiple slicing dimensions.

    The Slice node gets 'start' and 'end' values for all three dimensions of
    a [4, 5, 6] input and 'axis' set to None. After ConvertSlice has run, the
    node looked up by the name 'slice_node' is expected to have type 'Crop',
    and the resulting graph is compared against a reference graph that
    contains a 'crop' node.
    """
    graph = build_graph(
        nodes_attributes,
        [
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_1_data', 'slice'),
            ('slice', 'slice_data'),
            ('slice_data', 'output_op'),
            ('output_op', 'output_data'),
            ('output_data', 'op_output'),
        ],
        {
            'placeholder_1_data': {'shape': np.array([4, 5, 6])},
            'slice': {
                'start': np.array([1, 2, 3]),
                'end': np.array([3, 4, 4]),
                'axis': None,
            },
        },
        nodes_with_edges_only=True,
    )

    # Run Slice inference before the transformation, as the original test does.
    slice_node = Node(graph, 'slice')
    Slice.infer(slice_node)

    pattern = ConvertSlice()
    pattern.find_and_replace_pattern(graph)
    graph.clean_up()

    ss_node = Node(graph, graph.get_node_id_by_name('slice_node'))
    # Use the TestCase assertion rather than a bare ``assert``: bare asserts
    # are removed when Python runs with -O, which would silently skip the check.
    self.assertEqual(ss_node.type, 'Crop', 'Something wrong with transformed Slice node')

    graph_ref = build_graph(
        nodes_attributes,
        [
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_1_data', 'crop'),
            ('crop', 'slice_data'),
            ('slice_data', 'output_op'),
            ('output_op', 'output_data'),
            ('output_data', 'op_output'),
        ],
        {
            'placeholder_1_data': {'shape': np.array([4, 5, 6])},
            'crop': {
                'axis': np.array([0, 1, 2]),
                'offset': np.array([1, 2, 3]),
                # Matches end - start of the Slice above for every axis.
                'dim': np.array([2, 2, 1]),
            },
        },
        nodes_with_edges_only=True,
    )

    (flag, resp) = compare_graphs(graph, graph_ref, 'output_op', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_no_steps_no_axes(self):
    """
    Testing Slice with 'starts' and 'ends' inputs only (no steps, no axes).

    The Slice node receives the data on input 0, 'starts' on input 1 and
    'ends' on input 2. After ConvertSlice has run, the node looked up by the
    name 'slice_node' is expected to have type 'StridedSlice', and the graph
    is compared against a reference in which 'strided_slice' additionally has
    a 'strides' input on port 3.
    """
    input_shape = int64_array([5, 10, 20])
    starts_value = int64_array([3, 2, 7])
    ends_value = int64_array([5, 8, 15])
    masks_value = np.zeros([len(input_shape)], dtype=np.int64)

    graph = build_graph(
        self.nodes_attributes,
        [
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_1_data', 'slice', {'in': 0}),
            ('starts', 'starts_data'),
            ('starts_data', 'slice', {'in': 1}),
            ('ends', 'ends_data'),
            ('ends_data', 'slice', {'in': 2}),
            ('slice', 'slice_data'),
            ('slice_data', 'output_op'),
            ('output_op', 'output_data'),
            ('output_data', 'op_output'),
        ],
        {
            'placeholder_1_data': {'shape': input_shape},
            'starts': {'shape': starts_value.shape, 'value': starts_value},
            'starts_data': {'shape': starts_value.shape, 'value': starts_value},
            'ends': {'shape': ends_value.shape, 'value': ends_value},
            'ends_data': {'shape': ends_value.shape, 'value': ends_value},
        },
        nodes_with_edges_only=True,
    )

    # Run Slice inference before the transformation, as the original test does.
    slice_node = Node(graph, 'slice')
    Slice.infer(slice_node)

    pattern = ConvertSlice()
    pattern.find_and_replace_pattern(graph)

    ss_node = Node(graph, graph.get_node_id_by_name('slice_node'))
    # Use the TestCase assertion rather than a bare ``assert``: bare asserts
    # are removed when Python runs with -O, which would silently skip the check.
    self.assertEqual(ss_node.type, 'StridedSlice', 'Something wrong with transformed Slice node')

    graph_ref = build_graph(
        self.nodes_attributes,
        [
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_1_data', 'strided_slice', {'in': 0}),
            ('starts', 'starts_data'),
            ('starts_data', 'strided_slice', {'in': 1}),
            ('ends', 'ends_data'),
            ('ends_data', 'strided_slice', {'in': 2}),
            ('strides', 'strides_data'),
            ('strides_data', 'strided_slice', {'in': 3}),
            ('strided_slice', 'slice_data'),
            ('slice_data', 'output_op'),
            ('output_op', 'output_data'),
            ('output_data', 'op_output'),
        ],
        {
            'placeholder_1_data': {'shape': input_shape},
            'strided_slice': {
                'new_axis_mask': masks_value,
                'shrink_axis_mask': masks_value,
                'ellipsis_mask': masks_value,
                'begin_mask': np.ones([3]),
                'end_mask': np.ones([3]),
            },
            # Matches ends_value - starts_value for every axis.
            'slice_data': {'shape': int64_array([2, 6, 8])},
        },
        nodes_with_edges_only=True,
    )

    (flag, resp) = compare_graphs(graph, graph_ref, 'output_op', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_2(self):
    """
    Testing case with constant path and one slicing dimension.

    A Slice with a single 'start'/'end' pair and 'axis' set to None is applied
    to a [4, 5, 6] input; the transformed graph is compared against a
    reference that contains a 'strided_slice' node.
    """
    data_shape = {'shape': np.array([4, 5, 6])}
    output_marker = {'is_output': True}

    source_edges = [
        ('placeholder_1', 'placeholder_1_data'),
        ('placeholder_1_data', 'slice'),
        ('slice', 'slice_data'),
        ('slice_data', 'output_op'),
        ('output_op', 'output_data'),
    ]
    source_attrs = {
        'placeholder_1_data': data_shape,
        'slice': {'start': np.array([1]), 'end': np.array([3]), 'axis': None},
        'output_op': output_marker,
    }
    graph = build_graph(nodes_attributes, source_edges, source_attrs)

    # Infer the Slice node first, then apply the conversion to the graph.
    Slice.infer(Node(graph, 'slice'))
    ConvertSlice().find_and_replace_pattern(graph)

    expected_edges = [
        ('placeholder_1', 'placeholder_1_data'),
        ('placeholder_1_data', 'strided_slice'),
        ('strided_slice', 'slice_data'),
        ('slice_data', 'output_op'),
        ('output_op', 'output_data'),
    ]
    expected_attrs = {
        'placeholder_1_data': {'shape': np.array([4, 5, 6])},
        'strided_slice': {
            # Only the first dimension is narrowed (1:3); the other two
            # slices span the full input extent.
            'slices': np.array([slice(1, 3, 1), slice(0, 5, 1), slice(0, 6, 1)]),
            'shrink_axis_mask': np.array([False, False, False]),
        },
        'output_op': {'is_output': True},
    }
    graph_ref = build_graph(nodes_attributes, expected_edges, expected_attrs)

    is_equal, message = compare_graphs(graph, graph_ref, 'output_op', check_op_attrs=True)
    self.assertTrue(is_equal, message)
def test_1(self):
    """
    Testing case with non-constant path and multiple slicing dimensions.

    A Slice with 'start'/'end' values for all three dimensions and 'axis' set
    to None is applied to a [4, 5, 6] input; the transformed graph is compared
    against a reference that contains a 'crop' node.
    """
    source_edges = [
        ('placeholder_1', 'placeholder_1_data'),
        ('placeholder_1_data', 'slice'),
        ('slice', 'slice_data'),
        ('slice_data', 'output_op'),
        ('output_op', 'output_data'),
    ]
    source_attrs = {
        'placeholder_1_data': {'shape': np.array([4, 5, 6])},
        'slice': {
            'start': np.array([1, 2, 3]),
            'end': np.array([3, 4, 4]),
            'axis': None,
        },
        'output_op': {'is_output': True},
    }
    graph = build_graph(nodes_attributes, source_edges, source_attrs)

    # Infer the Slice node first, then apply the conversion to the graph.
    Slice.infer(Node(graph, 'slice'))
    ConvertSlice().find_and_replace_pattern(graph)

    expected_edges = [
        ('placeholder_1', 'placeholder_1_data'),
        ('placeholder_1_data', 'crop'),
        ('crop', 'slice_data'),
        ('slice_data', 'output_op'),
        ('output_op', 'output_data'),
    ]
    expected_attrs = {
        'placeholder_1_data': {'shape': np.array([4, 5, 6])},
        'crop': {
            'axis': np.array([0, 1, 2]),
            'offset': np.array([1, 2, 3]),
        },
        'output_op': {'is_output': True},
        # NOTE(review): 'dim' is addressed here as its own top-level entry
        # rather than as an attribute of 'crop' -- confirm this is intended.
        'dim': {'dim': np.array([2, 2, 1])},
    }
    graph_ref = build_graph(nodes_attributes, expected_edges, expected_attrs)

    is_equal, message = compare_graphs(graph, graph_ref, 'output_op', check_op_attrs=True)
    self.assertTrue(is_equal, message)
def test_3(self):
    """
    Testing case with constant path and one slicing dimension.

    The Slice node narrows axis 1 (start 1, end 3) of a [1, 5, 6] input and
    the graph layout is set to 'NHWC'. After ConvertSlice has run, the node
    looked up by the name 'slice_node' is expected to have type
    'StridedSlice', and the graph is compared against a reference containing
    a 'strided_slice' node fed by three placeholder data nodes.
    """
    graph = build_graph(
        nodes_attributes,
        [
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_1_data', 'slice'),
            ('slice', 'slice_data'),
            ('slice_data', 'output_op'),
            ('output_op', 'output_data'),
            ('output_data', 'op_output'),
        ],
        {
            'placeholder_1_data': {'shape': np.array([1, 5, 6])},
            'slice': {
                'start': np.array([1]),
                'end': np.array([3]),
                'axis': np.array([1]),
            },
        },
        nodes_with_edges_only=True,
    )
    graph.graph['layout'] = 'NHWC'

    # Run Slice inference before the transformation, as the original test does.
    slice_node = Node(graph, 'slice')
    Slice.infer(slice_node)

    pattern = ConvertSlice()
    pattern.find_and_replace_pattern(graph)
    graph.clean_up()

    ss_node = Node(graph, graph.get_node_id_by_name('slice_node'))
    # Use the TestCase assertion rather than a bare ``assert``: bare asserts
    # are removed when Python runs with -O, which would silently skip the check.
    self.assertEqual(ss_node.type, 'StridedSlice', 'Something wrong with transformed Slice node')

    graph_ref = build_graph(
        nodes_attributes,
        [
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_2', 'placeholder_2_data'),
            ('placeholder_3', 'placeholder_3_data'),
            ('placeholder_1_data', 'strided_slice'),
            ('placeholder_2_data', 'strided_slice'),
            ('placeholder_3_data', 'strided_slice'),
            ('strided_slice', 'slice_data'),
            ('slice_data', 'output_op'),
            ('output_op', 'output_data'),
            ('output_data', 'op_output'),
        ],
        {
            'placeholder_1_data': {'shape': np.array([1, 5, 6])},
            'strided_slice': {
                # Only axis 1 is narrowed (1:3); axes 0 and 2 span the full
                # input extent.
                'slices': np.array([slice(0, 1, 1), slice(1, 3, 1), slice(0, 6, 1)]),
                'shrink_axis_mask': np.array([False, False, False]),
            },
        },
        nodes_with_edges_only=True,
    )

    (flag, resp) = compare_graphs(graph, graph_ref, 'output_op', check_op_attrs=True)
    self.assertTrue(flag, resp)