def test_multi(self):
    """FakeOutputResolver on a graph where one op feeds two FakeOutput nodes.

    The input graph routes ``some_op`` into ``fake_output1`` via ``connect`` and
    into ``fake_output2`` via the explicitly declared data node ``some_op_d2``.
    The expected graph has each FakeOutput replaced by an ``Add`` whose port 1
    is fed by an int64 zero constant (``const1`` / ``const2``). The graphs are
    compared starting from ``result1``.
    """
    def fake_output_attrs(out_name):
        # Fresh dict on every call, just like the separate literals it replaces.
        return {'type': None, 'kind': 'op', 'op': 'FakeOutput', 'name': out_name}

    def add_attrs(out_name):
        return {'type': None, 'kind': 'op', 'op': 'Add', 'name': out_name}

    node_attrs = {
        **regular_op_with_empty_data('input', {'type': 'Parameter'}),
        **regular_op_with_empty_data('some_op', {'type': 'SomeOp', 'name': 'some_op_name'}),
        **empty_data('some_op_d2'),
        **regular_op_with_empty_data('fake_output1', fake_output_attrs('my_output_name1')),
        **regular_op_with_empty_data('fake_output2', fake_output_attrs('my_output_name2')),
        **const_with_data('const1', int64_array(0)),
        **const_with_data('const2', int64_array(0)),
        **regular_op_with_empty_data('add1', add_attrs('my_output_name1')),
        **regular_op_with_empty_data('add2', add_attrs('my_output_name2')),
        **result('result1'),
        **result('result2'),
    }

    # Before the transformation: some_op -> FakeOutput{1,2} -> Result{1,2}.
    input_edges = [
        *connect('input', 'some_op'),
        *connect('some_op', 'fake_output1'),
        ('some_op', 'some_op_d2'),
        ('some_op_d2', 'fake_output2'),
        *connect('fake_output1', 'result1'),
        *connect('fake_output2', 'result2'),
    ]
    graph = build_graph(node_attrs, input_edges)

    # After: every FakeOutput is an Add with a zero constant on port 1.
    expected_edges = [
        *connect('input', 'some_op'),
        *connect('some_op', '0:add1'),
        *connect('const1', '1:add1'),
        ('some_op', 'some_op_d2'),
        ('some_op_d2', 'add2', {'in': 0}),
        *connect('const2', '1:add2'),
        *connect('add1', 'result1'),
        *connect('add2', 'result2'),
    ]
    graph_ref = build_graph(node_attrs, expected_edges)

    FakeOutputResolver().find_and_replace_pattern(graph)

    flag, resp = compare_graphs(graph, graph_ref, 'result1')
    self.assertTrue(flag, resp)
def test_convert_slice_to_strided_slice(self, input_shape, start, end, axes, steps,
                                        ss_begin_parts: tuple, ss_end_parts: tuple,
                                        ss_steps, ss_begin_mask, ss_end_mask):
    """ConvertSlice rewrites a 5-input Slice into an equivalent StridedSlice.

    In the expected graph, the StridedSlice begin input is
    Concat(axis=0)(begin_first_part, Convert(start -> int64), begin_last_part).
    The end input is built the same way from ``end``. The steps come from the
    ``ss_steps`` constant. new_axis/shrink_axis/ellipsis masks are all zeros of
    length ``len(input_shape)``, and begin/end masks come from the parameters.
    Op attributes are compared as well.
    """
    rank = len(input_shape)

    def parameter_node():
        return regular_op_with_shaped_data('input', input_shape, {'type': 'Parameter'})

    def cast_to_int64():
        # New dict per node so convert_start / convert_end never share attrs.
        return {'op': 'Cast', 'type': 'Convert', 'dst_type': np.int64}

    def concat_on_axis0():
        return {'type': 'Concat', 'op': 'Concat', 'axis': 0}

    slice_nodes = {
        **parameter_node(),
        **valued_const_with_data('start', start),
        **valued_const_with_data('end', end),
        **valued_const_with_data('axes', axes),
        **valued_const_with_data('steps', steps),
        **regular_op_with_empty_data('slice', {'type': None, 'op': 'Slice'}),
        **result('result'),
    }
    slice_edges = [
        *connect('input', 'slice'),
        *connect('start', '1:slice'),
        *connect('end', '2:slice'),
        *connect('axes', '3:slice'),
        *connect('steps', '4:slice'),
        *connect('slice', 'result'),
    ]
    graph = build_graph(nodes_attrs=slice_nodes, edges=slice_edges)

    strided_slice_attrs = {
        'op': 'StridedSlice',
        'type': 'StridedSlice',
        'begin_mask': ss_begin_mask,
        'end_mask': ss_end_mask,
        'new_axis_mask': np.zeros(rank, dtype=np.int64),
        'shrink_axis_mask': np.zeros(rank, dtype=np.int64),
        'ellipsis_mask': np.zeros(rank, dtype=np.int64),
    }
    expected_nodes = {
        **parameter_node(),
        **valued_const_with_data('start', start),
        **valued_const_with_data('begin_first_part', ss_begin_parts[0]),
        **valued_const_with_data('begin_last_part', ss_begin_parts[1]),
        **regular_op_with_empty_data('convert_start', cast_to_int64()),
        **regular_op_with_empty_data('ss_begin', concat_on_axis0()),
        **valued_const_with_data('end', end),
        **valued_const_with_data('end_first_part', ss_end_parts[0]),
        **valued_const_with_data('end_last_part', ss_end_parts[1]),
        **regular_op_with_empty_data('convert_end', cast_to_int64()),
        **regular_op_with_empty_data('ss_end', concat_on_axis0()),
        **const('ss_steps', ss_steps),
        **empty_data('ss_steps_d'),
        **regular_op_with_empty_data('ss', strided_slice_attrs),
        **result('result'),
    }
    expected_edges = [
        *connect('input', 'ss'),
        *connect('begin_first_part', 'ss_begin'),
        *connect('start', 'convert_start'),
        *connect('convert_start', '1:ss_begin'),
        *connect('begin_last_part', '2:ss_begin'),
        *connect('ss_begin', '1:ss'),
        *connect('end_first_part', 'ss_end'),
        *connect('end', 'convert_end'),
        *connect('convert_end', '1:ss_end'),
        *connect('end_last_part', '2:ss_end'),
        *connect('ss_end', '2:ss'),
        *connect('ss_steps', '3:ss'),
        *connect('ss', 'result'),
    ]
    ref_graph = build_graph(nodes_attrs=expected_nodes, edges=expected_edges)

    ConvertSlice().find_and_replace_pattern(graph)

    flag, resp = compare_graphs(graph, ref_graph, 'result', check_op_attrs=True)
    self.assertTrue(flag, resp)
nodes = { **regular_op_with_shaped_data('placeholder_0', [1, 227, 227, 3], { 'type': 'Parameter' }), **regular_op_with_shaped_data('placeholder_1', [1, 227, 227, 3], { 'type': 'Parameter' }), **regular_op_with_empty_data( 'identityN', { 'op': 'IdentityN', 'type': None, 'data_types': [np.int32, np.float], 'name': 'my_identity' }), **empty_data('identityN_1_d'), **regular_op_with_empty_data( 'identity0', { 'op': 'Identity', 'type': None, 'data_type': np.int32, 'name': 'my_identity/0_port' }), **regular_op_with_empty_data( 'identity1', { 'op': 'Identity', 'type': None, 'data_type': np.float, 'name': 'my_identity/1_port' }), **result('output0'),
def test_convert_slice_to_strided_slice_without_axes_and_steps(self):
    """ConvertSlice on a Slice that has only start/end inputs (no axes/steps).

    The input is shaped [2, 5, 10]. The expected StridedSlice gets begin/end
    from Concat(axis=0) of an empty int64 prefix, the int64-converted start/end,
    and an empty int64 suffix. Steps are [1, 1, 1], begin/end masks are all
    ones, and the remaining masks are zeros of length 3. Op attributes are
    compared as well.
    """
    def parameter_node():
        return regular_op_with_shaped_data('input', int64_array([2, 5, 10]), {'type': 'Parameter'})

    def cast_to_int64():
        # New dict per node so convert_start / convert_end never share attrs.
        return {'op': 'Cast', 'type': 'Convert', 'dst_type': np.int64}

    def concat_on_axis0():
        return {'type': 'Concat', 'op': 'Concat', 'axis': 0}

    slice_nodes = {
        **parameter_node(),
        **valued_const_with_data('start', np.array([0, 0, 0])),
        **valued_const_with_data('end', np.array([1, 3, 5])),
        **regular_op_with_empty_data('slice', {'type': None, 'op': 'Slice'}),
        **result('result'),
    }
    slice_edges = [
        *connect('input', 'slice'),
        *connect('start', '1:slice'),
        *connect('end', '2:slice'),
        *connect('slice', 'result'),
    ]
    graph = build_graph(nodes_attrs=slice_nodes, edges=slice_edges)

    strided_slice_attrs = {
        'op': 'StridedSlice',
        'type': 'StridedSlice',
        'begin_mask': int64_array([1, 1, 1]),
        'end_mask': int64_array([1, 1, 1]),
        'new_axis_mask': np.zeros(3, dtype=np.int64),
        'shrink_axis_mask': np.zeros(3, dtype=np.int64),
        'ellipsis_mask': np.zeros(3, dtype=np.int64),
    }
    expected_nodes = {
        **parameter_node(),
        **valued_const_with_data('start', np.array([0, 0, 0])),
        **valued_const_with_data('begin_first_part', int64_array([])),
        **valued_const_with_data('begin_last_part', int64_array([])),
        **regular_op_with_empty_data('convert_start', cast_to_int64()),
        **regular_op_with_empty_data('ss_begin', concat_on_axis0()),
        **valued_const_with_data('end', np.array([1, 3, 5])),
        **valued_const_with_data('end_first_part', int64_array([])),
        **valued_const_with_data('end_last_part', int64_array([])),
        **regular_op_with_empty_data('convert_end', cast_to_int64()),
        **regular_op_with_empty_data('ss_end', concat_on_axis0()),
        **const('ss_steps', int64_array([1, 1, 1])),
        **empty_data('ss_steps_d'),
        **regular_op_with_empty_data('ss', strided_slice_attrs),
        **result('result'),
    }
    expected_edges = [
        *connect('input', 'ss'),
        *connect('begin_first_part', 'ss_begin'),
        *connect('start', 'convert_start'),
        *connect('convert_start', '1:ss_begin'),
        *connect('begin_last_part', '2:ss_begin'),
        *connect('ss_begin', '1:ss'),
        *connect('end_first_part', 'ss_end'),
        *connect('end', 'convert_end'),
        *connect('convert_end', '1:ss_end'),
        *connect('end_last_part', '2:ss_end'),
        *connect('ss_end', '2:ss'),
        *connect('ss_steps', '3:ss'),
        *connect('ss', 'result'),
    ]
    ref_graph = build_graph(nodes_attrs=expected_nodes, edges=expected_edges)

    ConvertSlice().find_and_replace_pattern(graph)

    flag, resp = compare_graphs(graph, ref_graph, 'result', check_op_attrs=True)
    self.assertTrue(flag, resp)