示例#1
0
    def test_multi(self):
        """FakeOutputResolver should replace each FakeOutput with ``Add(x, 0)``.

        Two FakeOutput ops consume data produced by ``some_op`` (one through
        ``connect``, one through the extra data node ``some_op_d2``).  The
        reference graph expects each of them to become an ``Add`` whose
        second input is a zero constant and whose ``name`` is the name the
        FakeOutput carried.
        """
        def fake_output_node(node_id, out_name):
            # FakeOutput op (plus its data node) carrying the output name.
            return regular_op_with_empty_data(
                node_id, {'type': None, 'kind': 'op', 'op': 'FakeOutput', 'name': out_name})

        def add_node(node_id, out_name):
            # Add op expected in place of a FakeOutput, under the same name.
            return regular_op_with_empty_data(
                node_id, {'type': None, 'kind': 'op', 'op': 'Add', 'name': out_name})

        node_attrs = {
            **regular_op_with_empty_data('input', {'type': 'Parameter'}),
            **regular_op_with_empty_data('some_op', {'type': 'SomeOp', 'name': 'some_op_name'}),
            **empty_data('some_op_d2'),
            **fake_output_node('fake_output1', 'my_output_name1'),
            **fake_output_node('fake_output2', 'my_output_name2'),

            **const_with_data('const1', int64_array(0)),
            **const_with_data('const2', int64_array(0)),
            **add_node('add1', 'my_output_name1'),
            **add_node('add2', 'my_output_name2'),
            **result('result1'),
            **result('result2'),
        }

        # Edge order is kept as-is: un-annotated edges may get their port
        # numbers from their position in the list.
        actual_edges = [
            *connect('input', 'some_op'),
            *connect('some_op', 'fake_output1'),
            ('some_op', 'some_op_d2'),
            ('some_op_d2', 'fake_output2'),
            *connect('fake_output1', 'result1'),
            *connect('fake_output2', 'result2'),
        ]
        graph = build_graph(node_attrs, actual_edges)

        expected_edges = [
            *connect('input', 'some_op'),
            *connect('some_op', '0:add1'),
            *connect('const1', '1:add1'),
            ('some_op', 'some_op_d2'),
            ('some_op_d2', 'add2', {'in': 0}),
            *connect('const2', '1:add2'),
            *connect('add1', 'result1'),
            *connect('add2', 'result2'),
        ]
        graph_ref = build_graph(node_attrs, expected_edges)

        FakeOutputResolver().find_and_replace_pattern(graph)

        is_equal, message = compare_graphs(graph, graph_ref, 'result1')
        self.assertTrue(is_equal, message)
 def test_convert_slice_to_strided_slice(self, input_shape, start, end,
                                         axes, steps, ss_begin_parts: tuple,
                                         ss_end_parts: tuple, ss_steps,
                                         ss_begin_mask, ss_end_mask):
     """ConvertSlice should rewrite a 5-input Slice into a StridedSlice.

     Expected reference graph:
       * begin = Concat(begin_first_part, Convert(start -> int64), begin_last_part)
       * end   = Concat(end_first_part,   Convert(end -> int64),   end_last_part)
       * strides = the ``ss_steps`` constant
       * begin/end masks taken from the parameters; new-axis, shrink-axis and
         ellipsis masks are int64 zero vectors of length ``len(input_shape)``.
     """
     def parameter_input():
         # Fresh Parameter node for each graph.
         return regular_op_with_shaped_data('input', input_shape, {
             'type': 'Parameter'
         })

     def int64_cast(node_id):
         # Cast inserted in front of the start/end values.
         return regular_op_with_empty_data(node_id, {
             'op': 'Cast',
             'type': 'Convert',
             'dst_type': np.int64
         })

     def concat_on_axis0(node_id):
         # Concat that stitches the first part, converted value and last part.
         return regular_op_with_empty_data(node_id, {
             'type': 'Concat',
             'op': 'Concat',
             'axis': 0
         })

     def zero_mask():
         # A new array on every call, so each mask attribute owns its own one.
         return np.zeros(len(input_shape), dtype=np.int64)

     def connect_all(*pairs):
         # Flatten connect(src, dst) over all pairs, keeping their order.
         return [edge for src, dst in pairs for edge in connect(src, dst)]

     graph = build_graph(
         nodes_attrs={
             **parameter_input(),
             **valued_const_with_data('start', start),
             **valued_const_with_data('end', end),
             **valued_const_with_data('axes', axes),
             **valued_const_with_data('steps', steps),
             **regular_op_with_empty_data('slice', {'type': None, 'op': 'Slice'}),
             **result('result'),
         },
         edges=connect_all(
             ('input', 'slice'),
             ('start', '1:slice'),
             ('end', '2:slice'),
             ('axes', '3:slice'),
             ('steps', '4:slice'),
             ('slice', 'result'),
         ))

     ref_graph = build_graph(
         nodes_attrs={
             **parameter_input(),
             **valued_const_with_data('start', start),
             **valued_const_with_data('begin_first_part', ss_begin_parts[0]),
             **valued_const_with_data('begin_last_part', ss_begin_parts[1]),
             **int64_cast('convert_start'),
             **concat_on_axis0('ss_begin'),
             **valued_const_with_data('end', end),
             **valued_const_with_data('end_first_part', ss_end_parts[0]),
             **valued_const_with_data('end_last_part', ss_end_parts[1]),
             **int64_cast('convert_end'),
             **concat_on_axis0('ss_end'),
             **const('ss_steps', ss_steps),
             **empty_data('ss_steps_d'),
             **regular_op_with_empty_data('ss', {
                 'op': 'StridedSlice',
                 'type': 'StridedSlice',
                 'begin_mask': ss_begin_mask,
                 'end_mask': ss_end_mask,
                 'new_axis_mask': zero_mask(),
                 'shrink_axis_mask': zero_mask(),
                 'ellipsis_mask': zero_mask(),
             }),
             **result('result'),
         },
         edges=connect_all(
             ('input', 'ss'),
             ('begin_first_part', 'ss_begin'),
             ('start', 'convert_start'),
             ('convert_start', '1:ss_begin'),
             ('begin_last_part', '2:ss_begin'),
             ('ss_begin', '1:ss'),
             ('end_first_part', 'ss_end'),
             ('end', 'convert_end'),
             ('convert_end', '1:ss_end'),
             ('end_last_part', '2:ss_end'),
             ('ss_end', '2:ss'),
             ('ss_steps', '3:ss'),
             ('ss', 'result'),
         ))

     ConvertSlice().find_and_replace_pattern(graph)
     matched, details = compare_graphs(graph, ref_graph, 'result',
                                       check_op_attrs=True)
     self.assertTrue(matched, details)
示例#3
0
nodes = {
    **regular_op_with_shaped_data('placeholder_0', [1, 227, 227, 3], {
                                      'type': 'Parameter'
                                  }),
    **regular_op_with_shaped_data('placeholder_1', [1, 227, 227, 3], {
                                      'type': 'Parameter'
                                  }),
    **regular_op_with_empty_data(
        'identityN', {
            'op': 'IdentityN',
            'type': None,
            'data_types': [np.int32, np.float],
            'name': 'my_identity'
        }),
    **empty_data('identityN_1_d'),
    **regular_op_with_empty_data(
        'identity0', {
            'op': 'Identity',
            'type': None,
            'data_type': np.int32,
            'name': 'my_identity/0_port'
        }),
    **regular_op_with_empty_data(
        'identity1', {
            'op': 'Identity',
            'type': None,
            'data_type': np.float,
            'name': 'my_identity/1_port'
        }),
    **result('output0'),
 def test_convert_slice_to_strided_slice_without_axes_and_steps(self):
     """ConvertSlice on a Slice that only has start and end inputs.

     For the [2, 5, 10] input, the reference StridedSlice gets empty
     first/last begin and end parts, a ``[1, 1, 1]`` strides constant,
     begin/end masks of ``[1, 1, 1]`` and zero new-axis, shrink-axis and
     ellipsis masks.
     """
     # Factories return new arrays on every call, so the test graph and the
     # reference graph never share array objects.
     def parameter_input():
         return regular_op_with_shaped_data('input', int64_array([2, 5, 10]), {
             'type': 'Parameter'
         })

     def start_const():
         return valued_const_with_data('start', np.array([0, 0, 0]))

     def end_const():
         return valued_const_with_data('end', np.array([1, 3, 5]))

     def empty_part(node_id):
         # Constant holding an empty int64 vector.
         return valued_const_with_data(node_id, int64_array([]))

     def int64_cast(node_id):
         # Cast inserted in front of the start/end values.
         return regular_op_with_empty_data(node_id, {
             'op': 'Cast',
             'type': 'Convert',
             'dst_type': np.int64
         })

     def concat_on_axis0(node_id):
         # Concat that stitches the first part, converted value and last part.
         return regular_op_with_empty_data(node_id, {
             'type': 'Concat',
             'op': 'Concat',
             'axis': 0
         })

     def zero_mask():
         return np.zeros(3, dtype=np.int64)

     def connect_all(*pairs):
         # Flatten connect(src, dst) over all pairs, keeping their order.
         return [edge for src, dst in pairs for edge in connect(src, dst)]

     graph = build_graph(
         nodes_attrs={
             **parameter_input(),
             **start_const(),
             **end_const(),
             **regular_op_with_empty_data('slice', {'type': None, 'op': 'Slice'}),
             **result('result'),
         },
         edges=connect_all(
             ('input', 'slice'),
             ('start', '1:slice'),
             ('end', '2:slice'),
             ('slice', 'result'),
         ))

     ref_graph = build_graph(
         nodes_attrs={
             **parameter_input(),
             **start_const(),
             **empty_part('begin_first_part'),
             **empty_part('begin_last_part'),
             **int64_cast('convert_start'),
             **concat_on_axis0('ss_begin'),
             **end_const(),
             **empty_part('end_first_part'),
             **empty_part('end_last_part'),
             **int64_cast('convert_end'),
             **concat_on_axis0('ss_end'),
             **const('ss_steps', int64_array([1, 1, 1])),
             **empty_data('ss_steps_d'),
             **regular_op_with_empty_data('ss', {
                 'op': 'StridedSlice',
                 'type': 'StridedSlice',
                 'begin_mask': int64_array([1, 1, 1]),
                 'end_mask': int64_array([1, 1, 1]),
                 'new_axis_mask': zero_mask(),
                 'shrink_axis_mask': zero_mask(),
                 'ellipsis_mask': zero_mask(),
             }),
             **result('result'),
         },
         edges=connect_all(
             ('input', 'ss'),
             ('begin_first_part', 'ss_begin'),
             ('start', 'convert_start'),
             ('convert_start', '1:ss_begin'),
             ('begin_last_part', '2:ss_begin'),
             ('ss_begin', '1:ss'),
             ('end_first_part', 'ss_end'),
             ('end', 'convert_end'),
             ('convert_end', '1:ss_end'),
             ('end_last_part', '2:ss_end'),
             ('ss_end', '2:ss'),
             ('ss_steps', '3:ss'),
             ('ss', 'result'),
         ))

     ConvertSlice().find_and_replace_pattern(graph)
     matched, details = compare_graphs(graph, ref_graph, 'result',
                                       check_op_attrs=True)
     self.assertTrue(matched, details)