def test_negative_2(self):
    """MulFakeQuantizeFuse must not fuse a Mul flagged ``can_be_fused=False``.

    The graph under test and the reference are built from the same
    ``nodes``/``edges`` with the same attribute override, so the comparison
    passes only when the transformation leaves the graph unchanged.
    """
    def build_non_fusable():
        # A fresh override dict per call, just like two separate literals.
        return build_graph(nodes, edges, {'mul': {'can_be_fused': False}},
                           nodes_with_edges_only=True)

    graph = build_non_fusable()
    graph.stage = 'middle'
    graph_ref = build_non_fusable()

    MulFakeQuantizeFuse().find_and_replace_pattern(graph)

    flag, resp = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_negative_fq_unacceptable_levels(self, levels):
    """CompressQuantizeWeights must leave this FakeQuantize untouched.

    ``levels`` is supplied by the test parametrisation (not visible here).
    The reference is a plain copy of the input graph, so the test passes
    only if the pass makes no change.
    """
    nodes = nodes_dict(np.float32, None, levels)

    # FQ inputs in port order: data, input low/high, output low/high.
    fq_sources = ('weights', 'il', 'ih', 'ol', 'oh')
    graph_edges = []
    for in_port, source in enumerate(fq_sources):
        graph_edges.extend(connect(f'{source}:0', f'{in_port}:FQ'))
    graph_edges.extend(connect('FQ:0', 'output'))

    graph = build_graph(nodes, graph_edges, nodes_with_edges_only=True)
    graph_ref = graph.copy()

    CompressQuantizeWeights().find_and_replace_pattern(graph)

    flag, resp = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_nonreplacement(self):
    """CorrectRollAxes must not change a Roll whose ``input_rank_changed`` is False.

    The reference is built from the same description, so any modification
    made by the pass makes the comparison fail.
    """
    def build_roll_graph():
        return build_graph(nodes_attrs=graph_node_attrs,
                           edges=graph_edges,
                           update_attributes={'roll': {'input_rank_changed': False}})

    graph = build_roll_graph()
    graph.stage = 'front'
    CorrectRollAxes().find_and_replace_pattern(graph)

    ref_graph = build_roll_graph()
    flag, resp = compare_graphs(graph, ref_graph, 'output', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_mean_values_explicit_and_scale_values_explicit(self):
    """Explicit mean and scale values insert Add(mean) followed by Mul(scale).

    Expected topology: parameter -> add_mean -> mul_scale -> result.
    """
    expected_edges = [
        *connect('parameter', '0:add_mean'),
        *connect('mean', '1:add_mean'),
        *connect('add_mean', '0:mul_scale'),
        *connect('scale', '1:mul_scale'),
        *connect('mul_scale', 'result'),
    ]
    graph_ref = build_graph(nodes, expected_edges)

    mean_values = np.array([1., 2., 3.])
    scale_values = np.array([1., 2., 3.])
    cli_args = Namespace(mean_scale_values=[[mean_values, scale_values]])
    graph = build_graph(nodes, [*connect('parameter', 'result')],
                        nodes_with_edges_only=True, cli=cli_args)

    for g in (graph, graph_ref):
        self.set_graph_attrs(g, ['parameter'])
    graph.graph['layout'] = 'NCHW'

    AddMeanScaleValues().find_and_replace_pattern(graph)

    flag, resp = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
    self.assertTrue(flag, resp)
    self.check_graph_attrs(graph, graph_ref, ['parameter'])
def test_div_test_2(self):
    """A Div whose two inputs come from the same placeholder is rewritten to Mul.

    The reference is ``mul(placeholder_1, reciprocal(placeholder_1, minus_one))``.
    The resulting Multiply node must be named ``my_div``.
    """
    # Both Div inputs are fed from placeholder_1 (the second edge via connect_data).
    graph = build_graph(nodes, [
        *connect('placeholder_1:0', '0:div'),
        *connect_data('placeholder_1:0', '1:div'),
        *connect('div', 'output'),
    ], nodes_with_edges_only=True)
    Div().find_and_replace_pattern(graph)

    graph_ref = build_graph(nodes, [
        *connect('placeholder_1:0', '0:mul'),
        *connect_data('placeholder_1:0', '0:reciprocal'),
        *connect('minus_one', '1:reciprocal'),
        *connect('reciprocal', '1:mul'),
        *connect('mul', 'output'),
    ], nodes_with_edges_only=True)

    flag, resp = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
    self.assertTrue(flag, resp)

    multiply_nodes = graph.get_nodes_with_attributes(type='Multiply')
    # assertEqual reports the actual name on failure; assertTrue(a == b) only
    # reports "False is not true".
    self.assertEqual(graph.node[multiply_nodes[0]]['name'], 'my_div')
def test(self):
    """OneHotDepthNormalizer routes the OneHot ``depth`` input through a Reshape.

    In the reference, ``depth`` and the empty-shape ``reshape_dims`` constant
    feed ``reshape``, and the output of ``reshape`` becomes the OneHot depth
    input.
    """
    nodes = {
        **regular_op('input', {'type': 'Parameter'}),
        **const('depth', int64_array([2])),
        **regular_op('onehot', {'type': 'OneHot', 'kind': 'op', 'op': 'OneHot'}),
        **regular_op('reshape', {'type': 'Reshape', 'kind': 'op', 'op': 'Reshape'}),
        **const('reshape_dims', int64_array([])),
        **result('result'),
    }

    graph = build_graph(nodes, [
        ('input', 'onehot'),
        ('depth', 'onehot'),
        ('onehot', 'result'),
    ])
    graph.graph['layout'] = 'NCHW'
    graph.stage = 'front'

    graph_ref = build_graph(nodes, [
        ('input', 'onehot'),
        ('depth', 'reshape'),
        ('reshape_dims', 'reshape'),
        ('reshape', 'onehot'),
        ('onehot', 'result'),
    ])

    OneHotDepthNormalizer().find_and_replace_pattern(graph)

    flag, resp = compare_graphs(graph, graph_ref, 'result')
    self.assertTrue(flag, resp)
def test_simple_pooling(self):
    """Splice -> Pooling gets wrapped in Reshape/Transpose pairs on the front stage.

    The pooling input is reshaped to ``concat(gather_batch, t, div, h)`` (by
    port index). The values are taken from the Splice output shape, with the
    ``gather_h`` element divided by ``th``. The result is transposed, pooled,
    transposed back and reshaped to ``reshape_out_shape``.
    """
    graph = build_graph(self.nodes, [
        *connect_front('input', 'splice'),
        *connect_front('splice', 'pool'),
        *connect_front('pool', 'out_op'),
    ], nodes_with_edges_only=True)
    graph.stage = 'front'

    # The pass is invoked on the class itself, not on an instance.
    AddReshapeTransposeAroundConvPool.find_and_replace_pattern(graph)

    entry = [
        *connect_front('input', 'splice'),
        *connect_front('splice', '0:reshape_in'),
    ]
    splice_shape = [
        *connect_front('splice', 'shapeof'),
        *connect_front('shapeof:0', '0:gather_batch'),
        *connect_front('ind', '1:gather_batch'),
        *connect_front('axis', '2:gather_batch'),
        *connect_front('shapeof:0', '0:gather_h'),
        *connect_front('ind_h', '1:gather_h'),
        *connect_front('axis', '2:gather_h'),
        *connect_front('gather_h', '0:div'),
        *connect_front('th', '1:div'),
    ]
    new_shape = [
        *connect_front('gather_batch', '0:concat'),
        *connect_front('t', '1:concat'),
        *connect_front('h', '3:concat'),
        *connect_front('div', '2:concat'),
        *connect_front('concat', '1:reshape_in'),
    ]
    around_pool = [
        *connect_front('reshape_in', '0:transpose_in'),
        *connect_front('transpose_in_order', '1:transpose_in'),
        *connect_front('transpose_in', 'pool'),
        *connect_front('pool', '0:transpose_out'),
        *connect_front('transpose_out_order', '1:transpose_out'),
        *connect_front('transpose_out', '0:reshape_out'),
        *connect_front('reshape_out_shape', '1:reshape_out'),
        *connect_front('reshape_out', 'out_op'),
    ]
    ref_graph = build_graph(self.ref_nodes, entry + splice_shape + new_shape + around_pool)

    flag, resp = compare_graphs(graph, ref_graph, 'out_op', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_conv_reshape_pool(self):
    """FuseReshapesSequenceKaldi drops the reshape_out ... reshape_in pair.

    After the pass, ``transpose_out`` must feed ``transpose_in`` directly.
    The ShapeOf/Gather/Concat sub-graph that computed the intermediate shape
    disappears.
    """
    conv_side = [
        *connect('conv', '0:transpose_out'),
        *connect('transpose_out_order', '1:transpose_out'),
        *connect('transpose_out', '0:reshape_out'),
        *connect('reshape_out_shape', '1:reshape_out'),
    ]
    shape_calc = [
        *connect('reshape_out', 'shapeof'),
        *connect('shapeof', '0:gather_batch'),
        *connect('ind', '1:gather_batch'),
        *connect('axis', '2:gather_batch'),
        *connect('shapeof', '0:gather_h', skip_data=True),
        *connect('ind_h', '1:gather_h'),
        *connect('axis', '2:gather_h', skip_data=True),
        *connect('gather_h', '0:div'),
        *connect('th', '1:div'),
    ]
    reshape_in_side = [
        *connect('gather_batch', '0:concat'),
        *connect('t', '1:concat'),
        *connect('h', '2:concat'),
        *connect('div', '3:concat'),
        *connect('concat', '1:reshape_in'),
        *connect('reshape_out', '0:reshape_in', skip_data=True),
    ]
    pool_side = [
        *connect('reshape_in', '0:transpose_in'),
        *connect('transpose_in_order', '1:transpose_in'),
        *connect('transpose_in', 'pool'),
    ]
    graph = build_graph(self.nodes, conv_side + shape_calc + reshape_in_side + pool_side,
                        nodes_with_edges_only=True)

    FuseReshapesSequenceKaldi().find_and_replace_pattern(graph)

    ref_graph = build_graph(self.ref_nodes, [
        *connect('conv', '0:transpose_out'),
        *connect('transpose_out_order', '1:transpose_out'),
        *connect('transpose_out', '0:transpose_in'),
        *connect('transpose_in_order', '1:transpose_in'),
        *connect('transpose_in', 'pool'),
    ])

    flag, resp = compare_graphs(graph, ref_graph, 'pool')
    self.assertTrue(flag, resp)
def test6_not_constant(self):
    """normalize_eltwise_inputs must keep this non-constant fan-out topology.

    A placeholder feeds three eltwise ops, which all feed a concat. The
    reference has exactly the same edges.

        ,--------------->consumer3                  ,->consumer3
       data---(new_shape1)-->consumer1     =>   data----->consumer1
          `-(new_shape1)-->consumer2                 `-->consumer2
    """
    consumer_ids = (1, 2, 3)

    def topology():
        # Same edge order as the explicit list: fan-out, op->data, data->concat.
        topo = [('placeholder_1', 'placeholder_1_data')]
        topo += [('placeholder_1_data', f'eltwise_{i}') for i in consumer_ids]
        topo += [(f'eltwise_{i}', f'eltwise_{i}_data') for i in consumer_ids]
        topo += [(f'eltwise_{i}_data', 'concat') for i in consumer_ids]
        return topo

    shapes = {'placeholder_1_data': {'shape': int64_array([1, 3])}}
    shapes.update({f'eltwise_{i}_data': {'shape': int64_array([1, 3])} for i in consumer_ids})
    graph = build_graph(nodes_attributes, topology(), shapes, nodes_with_edges_only=True)

    graph_ref = build_graph(nodes_attributes, topology(),
                            {'placeholder_1_data': {'shape': int64_array([1, 3])}},
                            nodes_with_edges_only=True)

    normalize_eltwise_inputs(graph)

    flag, resp = compare_graphs(graph, graph_ref, 'concat', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_attributed_slice_replacer(self, attributed_slice_attrs):
    """AttributedSlice is replaced by a Slice fed by start/end/axis constants.

    ``attributed_slice_attrs`` comes from the test parametrisation (not
    visible here). Reference: input -> slice(start, end, axis) -> output.
    """
    replacement_nodes = {
        **const('start', np.array([0, 0])),
        **const('end', np.array([1, -1])),
        **const('axis', np.array([0, 1])),
        **regular_op_with_empty_data('slice', {'op': 'Slice', 'type': None}),
    }
    nodes = {
        **regular_op_with_empty_data('input', {'type': 'Parameter'}),
        **regular_op_with_empty_data('attributed_slice', attributed_slice_attrs),
        **result(),
        **replacement_nodes,
    }

    graph = build_graph(nodes_attrs=nodes,
                        edges=[('input', 'attributed_slice'), ('attributed_slice', 'output')],
                        nodes_with_edges_only=True)
    graph.stage = 'front'

    AttributedSliceToSliceReplacer().find_and_replace_pattern(graph)

    expected_edges = [
        ('input', 'slice'),
        *connect_front('start', '1:slice'),
        *connect_front('end', '2:slice'),
        *connect_front('axis', '3:slice'),
        ('slice', 'output'),
    ]
    graph_ref = build_graph(nodes_attrs=nodes, edges=expected_edges, nodes_with_edges_only=True)

    flag, resp = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_instance_normalization_test_1(self):
    """InstanceNormalization is decomposed into MVN -> Mul(scale) -> Add(B).

    In the reference, MVN takes its axes from ``mvn_axes``, which is fed by
    ``start``, ``rank`` (which consumes the input) and ``step``. The expected
    MVN attributes (eps 0.123, ``inside_sqrt``, normalize_variance 1) are
    written into the reference. However, op attributes are not compared here
    because ``check_op_attrs=False``.
    """
    graph = build_graph(nodes_attributes,
                        [('input', 'node'), ('scale', 'node'), ('B', 'node'), ('node', 'out')],
                        {'node': {'epsilon': 0.123}},
                        nodes_with_edges_only=True)

    ref_edges = [
        ('input', 'mvn', {'out': 0}),
        ('input', 'rank', {'out': 0}),
        ('start', 'mvn_axes'),
        ('rank', 'mvn_axes'),
        ('step', 'mvn_axes'),
        ('mvn_axes', 'mvn'),
        ('mvn', 'mul'),
        ('scale', 'mul'),
        ('mul', 'add'),
        ('B', 'add'),
        ('add', 'out'),
    ]
    ref_attrs = {'mvn': {'eps': 0.123, 'eps_mode': 'inside_sqrt', 'normalize_variance': 1}}
    graph_ref = build_graph(nodes_ref_attributes, ref_edges, ref_attrs, nodes_with_edges_only=True)

    graph.stage = 'front'
    InstanceNormalization().find_and_replace_pattern(graph)

    flag, resp = compare_graphs(graph, graph_ref, 'out', check_op_attrs=False)
    self.assertTrue(flag, resp)
def test_v10_group_convolution_resolver(self):
    """A grouped v10 Convolution becomes GroupConvolution with reshaped weights.

    Weights of shape [3, 8, 7, 7] are routed through ``reshape`` with target
    ``dim`` = [3, 8, 1, 7, 7]. The reference node has type GroupConvolution
    and no ``group`` attribute.
    """
    nodes = {
        **regular_op_with_shaped_data('input', [1, 3, 224, 224], {'type': 'Parameter'}),
        **valued_const_with_data('weights', np.ones([3, 8, 7, 7])),
        **valued_const_with_data('dim', int64_array([3, 8, 1, 7, 7])),
        **regular_op_with_empty_data('reshape', {'type': 'Reshape'}),
        **regular_op_with_shaped_data('convolution', None,
                                      {'type': 'Convolution', 'group': 3, 'output': 24}),
        **result(),
    }
    graph = build_graph(nodes, [
        *connect('input', '0:convolution'),
        *connect('weights', '1:convolution'),
        *connect('convolution', 'output'),
    ], nodes_with_edges_only=True)

    V10ConvolutionWithGroupsResolver().find_and_replace_pattern(graph)

    # Edit the node description in place into its expected post-pass form;
    # this must happen after the graph under test has been built.
    conv_desc = nodes['convolution']
    conv_desc['type'] = 'GroupConvolution'
    del conv_desc['group']

    graph_ref = build_graph(nodes, [
        *connect('input', '0:convolution'),
        *connect('weights', '0:reshape'),
        *connect('dim', '1:reshape'),
        *connect('reshape', '1:convolution'),
        *connect('convolution', 'output'),
    ], nodes_with_edges_only=True)

    flag, resp = compare_graphs(graph, graph_ref, last_node='output', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_scaleshift2_axis1_to_mul(self):
    """A ScaleShift (axis=1) with a second placeholder input becomes a Mul.

    The length-227 scale tensor is reshaped to [1, 227, 1, 1] before it
    reaches ``mul_1``. The graph layout is set to NHWC.
    """
    edges = [
        ('placeholder_1', 'placeholder_1_data'),
        ('placeholder_2', 'placeholder_2_data'),
        ('placeholder_1_data', 'scaleshift_1'),
        ('placeholder_2_data', 'scaleshift_1'),
        ('scaleshift_1', 'scaleshift_1_data'),
        ('scaleshift_1_data', 'op_output'),
    ]
    attrs = {
        'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
        'placeholder_2_data': {'shape': np.array([227])},
        'scaleshift_1': {'axis': 1},
        'scaleshift_1_data': {},
    }
    graph = build_graph(nodes_attributes, edges, attrs)

    ref_edges = [
        ('placeholder_1', 'placeholder_1_data'),
        ('placeholder_2', 'placeholder_2_data'),
        ('placeholder_2_data', 'placeholder_2/Reshape_'),
        ('placeholder_2/Reshape_const', 'placeholder_2/Reshape_const_data'),
        ('placeholder_2/Reshape_const_data', 'placeholder_2/Reshape_'),
        ('placeholder_2/Reshape_', 'placeholder_2/Reshape_data'),
        ('placeholder_1_data', 'mul_1'),
        ('placeholder_2/Reshape_data', 'mul_1'),
        ('mul_1', 'scaleshift_1_data'),
        ('scaleshift_1_data', 'op_output'),
    ]
    ref_attrs = {
        'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
        'placeholder_2_data': {'shape': np.array([227])},
        'placeholder_2/Reshape_const': {'value': np.array([1, 227, 1, 1]), 'shape': [4]},
        'placeholder_2/Reshape_const_data': {'value': np.array([1, 227, 1, 1]), 'shape': [4]},
        'placeholder_2/Reshape_data': {'shape': np.array([1, 227, 1, 1])},
        'mul_1': {'can_be_fused': True},
        'scaleshift_1_data': {},
    }
    graph_ref = build_graph(nodes_attributes, ref_edges, ref_attrs)

    graph.graph['layout'] = 'NHWC'
    convert_scale_shift_to_mul_add(graph)
    graph.clean_up()

    flag, resp = compare_graphs(graph, graph_ref, 'placeholder_1')
    self.assertTrue(flag, resp)
def test_set_ports_split2(self):
    """SetPortsPattern compacts sparse Split out ports 0, 4, 6 to 0, 1, 2.

    The consumers keep their relative order: op2 stays on port 0, op4 moves
    to port 1 and op3 moves to port 2. The Split declares
    ``out_ports_count`` = 3.
    """
    nodes = {
        **regular_op('op1', {}),
        **regular_op('split', {'op': 'Split'}),
        **regular_op('op2', {}),
        **regular_op('op3', {}),
        **regular_op('op4', {}),
    }

    def edge_attrs(**extra):
        # Each edge carries an empty debug-info dict plus an optional out_port.
        return {'fw_tensor_debug_info': {}, **extra}

    graph = build_graph(nodes, [
        ('op1', 'split', edge_attrs()),
        ('split', 'op2', edge_attrs(out_port=0)),
        ('split', 'op4', edge_attrs(out_port=4)),
        ('split', 'op3', edge_attrs(out_port=6)),
    ], nodes_with_edges_only=True)
    graph.stage = 'front'
    graph.nodes()['split']['out_ports_count'] = 3

    ref_graph = build_graph(nodes, [
        *connect_front('op1:0', '0:split'),
        *connect_front('split:0', '0:op2'),
        *connect_front('split:1', '0:op4'),
        *connect_front('split:2', '0:op3'),
    ], nodes_with_edges_only=True)

    SetPortsPattern().find_and_replace_pattern(graph)

    flag, resp = compare_graphs(graph, ref_graph, 'op4', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_CTCGreedyDecoderSingle_negative(self):
    """CTCGreedyDecoderSingleReplacement must leave this decoder sub-graph intact.

    Both graphs are built from the same edge list, so the comparison succeeds
    only if the replacement did not modify the graph.
    """
    def link(src, dst, out_port, in_port=None):
        # Port attributes are {'out': ..., 'in': ...}; 'in' is omitted when not given.
        port_attrs = {'out': out_port}
        if in_port is not None:
            port_attrs['in'] = in_port
        return src, dst, port_attrs

    edges = [
        link('logits', 'decoder', 0, 0),
        link('seq_len', 'decoder', 0, 1),
        link('decoder', 'sparse_to_dense', 0, 0),
        link('decoder', 'cast', 1, 0),
        link('cast', 'sparse_to_dense', 0),
        link('sparse_to_dense', 'last', 0, 0),
    ]

    graph = build_graph(self.nodes_attributes, edges, nodes_with_edges_only=True)
    graph.stage = 'front'
    CTCGreedyDecoderSingleReplacement().find_and_replace_pattern(graph)

    graph_ref = build_graph(self.nodes_attributes, edges, nodes_with_edges_only=True)
    flag, resp = compare_graphs(graph, graph_ref, 'last', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_ti_reverse(self):
    """ReverseTensorIteratorLSTM strips the reverse nodes around the TensorIterator.

    The input is reversed by ``direct_reverse`` before ``ti``, and the output
    is reversed back by ``inverse_reverse``. Each reverse gets its second
    input from a ShapeOf/Gather/Broadcast sub-graph. The reference is simply
    parameter -> ti -> some_op -> output.
    """
    direct_reverse_part = [
        *connect('parameter:0', '0:direct_reverse'),
        *connect('parameter:0', 'shapeof', skip_data=True),
        *connect('shapeof:0', '0:gather_batch'),
        *connect('gather_batch_ind', '1:gather_batch'),
        *connect('gather_axis', '2:gather_batch'),
        *connect('shapeof:0', '0:gather_seq', skip_data=True),
        *connect('gather_seq_ind', '1:gather_seq'),
        *connect('gather_axis', '2:gather_seq'),
        *connect('gather_seq', '0:broadcast'),
        *connect('gather_batch', '1:broadcast'),
        *connect('broadcast', '1:direct_reverse'),
    ]
    ti_part = [
        *connect('direct_reverse', '0:ti'),
        *connect('init_hidden', '1:ti'),
    ]
    inverse_reverse_part = [
        *connect('ti', 'inverse_shapeof'),
        *connect('inverse_shapeof:0', '0:inverse_gather_batch'),
        *connect('gather_batch_ind', '1:inverse_gather_batch'),
        *connect('gather_axis', '2:inverse_gather_batch'),
        *connect('inverse_shapeof:0', '0:inverse_gather_seq', skip_data=True),
        *connect('gather_seq_ind', '1:inverse_gather_seq'),
        *connect('gather_axis', '2:inverse_gather_seq'),
        *connect('inverse_gather_seq', '0:inverse_broadcast'),
        *connect('inverse_gather_batch', '1:inverse_broadcast'),
        *connect('ti', '0:inverse_reverse', skip_data=True),
        *connect('inverse_broadcast', '1:inverse_reverse'),
    ]
    tail = [
        *connect('inverse_reverse', 'some_op'),
        *connect('some_op', 'output'),
    ]
    graph = build_graph(nodes, direct_reverse_part + ti_part + inverse_reverse_part + tail,
                        nodes_with_edges_only=True)

    ReverseTensorIteratorLSTM().find_and_replace_pattern(graph)
    graph.clean_up()

    ref_graph = build_graph(ref_nodes, [
        *connect('parameter', '0:ti'),
        *connect('init_hidden', '1:ti'),
        *connect('ti', 'some_op'),
        *connect('some_op', 'output'),
    ])

    flag, resp = compare_graphs(graph, ref_graph, 'output', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_image_scaler_test_3(self):
    """An ImageScaler with an all-zero bias reduces to a single Mul.

    Scale is 2.0 and bias is zeros reshaped to [3, 1, 1]. The reference is
    placeholder_1 -> mul_1 (weights 2.0) -> last, with no Add.
    """
    scaler_attrs = {
        'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
        'im_scaler': {
            'scale': np.array(2.0),
            'bias': np.reshape(np.array([0, 0, 0]), [3, 1, 1]),
        },
    }
    graph = build_graph(nodes_attributes, [
        ('placeholder_1', 'placeholder_1_data'),
        ('placeholder_1_data', 'im_scaler'),
        ('im_scaler', 'im_scaler_data'),
        ('im_scaler_data', 'last'),
    ], scaler_attrs, nodes_with_edges_only=True)

    ref_attrs = {
        'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
        'const_mul_1_w': {'shape': np.array(2.0).shape, 'value': np.array(2.0)},
    }
    graph_ref = build_graph(nodes_attributes, [
        ('placeholder_1', 'placeholder_1_data'),
        ('placeholder_1_data', 'mul_1'),
        ('const_mul_1_w', 'mul_1_w'),
        ('mul_1_w', 'mul_1'),
        ('mul_1', 'mul_1_data'),
        ('mul_1_data', 'last'),
    ], ref_attrs, nodes_with_edges_only=True)

    graph.graph['layout'] = 'NCHW'
    graph.stage = 'middle'
    ImageScaler().find_and_replace_pattern(graph)

    flag, resp = compare_graphs(graph, graph_ref, 'last')
    self.assertTrue(flag, resp)
def test_pattern_does_not_satisfy(self, input_shape, scales):
    """UpsampleToResample must leave the graph unchanged for these inputs.

    ``input_shape`` and ``scales`` come from the test parametrisation (not
    visible here). The reference is built from the same description.
    """
    def build_upsample_graph():
        # Fresh arrays on every call, so the two graphs share no objects.
        return build_graph(graph_node_attrs, graph_edges, {
            'placeholder_data': {'shape': int64_array(input_shape)},
            'scales': {'value': int64_array(scales), 'shape': int64_array(scales).shape},
            'scales_data': {'value': int64_array(scales), 'shape': int64_array(scales).shape},
            'upsample_data': {'shape': int64_array(input_shape) * int64_array(scales)},
        })

    graph = build_upsample_graph()
    graph.graph['layout'] = 'NCHW'
    ref_graph = build_upsample_graph()

    UpsampleToResample().find_and_replace_pattern(graph)

    flag, resp = compare_graphs(graph, ref_graph, 'output')
    self.assertTrue(flag, resp)
def test_all_dynamic_inputs(self):
    """ClampNormalizer turns a Clamp with Parameter min/max into Maximum then Minimum.

    Reference: maximum(placeholder, min) -> minimum(., max) -> result.
    """
    def shaped_op(name, attrs):
        # All ops built here share the [1, 3, 20, 20] output shape.
        return regular_op_with_shaped_data(name, [1, 3, 20, 20], attrs)

    nodes = {
        **shaped_op('placeholder', {'type': 'Parameter'}),
        **shaped_op('min', {'type': 'Parameter'}),
        **shaped_op('max', {'type': 'Parameter'}),
        **shaped_op('a_clamp', {'type': None, 'op': 'Clamp'}),
        **shaped_op('maximum', {'type': 'Maximum', 'op': 'Maximum'}),
        **shaped_op('minimum', {'type': 'Minimum', 'op': 'Minimum'}),
        **result('result'),
    }

    graph = build_graph(nodes, [
        *connect('placeholder', '0:a_clamp'),
        *connect('min', '1:a_clamp'),
        *connect('max', '2:a_clamp'),
        *connect('a_clamp', 'result'),
    ])
    ClampNormalizer().find_and_replace_pattern(graph)

    ref_graph = build_graph(nodes, [
        *connect('placeholder', '0:maximum'),
        *connect('min', '1:maximum'),
        *connect('maximum', '0:minimum'),
        *connect('max', '1:minimum'),
        *connect('minimum', 'result'),
    ])

    flag, resp = compare_graphs(graph, ref_graph, 'result')
    self.assertTrue(flag, resp)
def test_positive(self, input_shape, axes, layout):
    """L2NormToNorm replaces the matched sub-graph with a NormalizeL2 node.

    ``input_shape``, ``axes`` and ``layout`` come from the test
    parametrisation (not visible here). The reference axes constant
    (``weights_node_data``) holds ``axes`` in sorted order. The created
    NormalizeL2 node must be named ``l2_norm_name``.
    """
    graph = build_graph_with_attrs(nodes + [
        ('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
        ('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
        ('square_data', dict(kind='data', shape=input_shape)),
        ('sum_axes_data', dict(kind='data', value=axes, shape=None)),
    ], edges, nodes_with_edges_only=True)
    graph.stage = 'middle'
    graph.graph['layout'] = layout

    L2NormToNorm().find_and_replace_pattern(graph)

    # np.sort returns a sorted copy. The previous ``axes.sort()`` sorted the
    # parametrised object in place and returned None, so the reference value
    # was None instead of the sorted axes.
    sorted_axes = np.sort(axes)
    graph_ref = build_graph_with_attrs(nodes + [
        ('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
        ('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
        ('weights_node_data', dict(kind='data', value=sorted_axes)),
    ], edges_after_replacement, nodes_with_edges_only=True)

    flag, resp = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
    # Report the structural diff first; the name check is only meaningful after it.
    self.assertTrue(flag, resp)

    normalize_nodes = graph.get_nodes_with_attributes(type='NormalizeL2')
    self.assertEqual(graph.node[normalize_nodes[0]]['name'], 'l2_norm_name')
def test_negative(self, input_shape, axes, layout):
    """L2NormToNorm must leave the graph unchanged for these parameters.

    ``input_shape``, ``axes`` and ``layout`` come from the test
    parametrisation (not visible here). The reference is built from exactly
    the same node and edge description.
    """
    def build_l2_graph():
        return build_graph_with_attrs(nodes + [
            ('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
            ('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
            ('square_data', dict(kind='data', shape=input_shape)),
            ('sum_axes_data', dict(kind='data', value=axes, shape=None)),
        ], edges, nodes_with_edges_only=True)

    graph = build_l2_graph()
    graph.stage = 'middle'
    graph.graph['layout'] = layout

    L2NormToNorm().find_and_replace_pattern(graph)

    graph_ref = build_l2_graph()
    flag, resp = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_mean_values_with_data_name(self):
    """Mean values parsed from '(1,2,3)' with empty scale values insert only Add(mean).

    Expected topology: parameter -> add_mean -> result.
    """
    graph_ref = build_graph(nodes, [
        *connect('parameter', '0:add_mean'),
        *connect('mean', '1:add_mean'),
        *connect('add_mean', 'result'),
    ])

    mean_scale = get_mean_scale_dictionary(parse_tuple_pairs('(1,2,3)'),
                                           parse_tuple_pairs(''),
                                           None)
    graph = build_graph(nodes, [*connect('parameter', 'result')],
                        nodes_with_edges_only=True,
                        cli=Namespace(mean_scale_values=mean_scale))

    for g in (graph, graph_ref):
        self.set_graph_attrs(g, ['parameter'])
    graph.graph['layout'] = 'NCHW'

    AddMeanScaleValues().find_and_replace_pattern(graph)

    flag, resp = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
    self.assertTrue(flag, resp)
    self.check_graph_attrs(graph, graph_ref, ['parameter'])
def test_insert_old_api_map(self):
    """InsertReverseChannels on a Parameter carrying an old_api_map_order entry.

    placeholder1 (built via ``get_nodes([1, 10, 10, 3])``) is tagged with an
    OldAPIMapOrder whose transpose is [0, 2, 3, 1]. The pass must insert
    ``reverse_channels`` right after placeholder1. The reference node is
    given the very same rt_info object.
    """
    graph = build_graph(get_nodes([1, 10, 10, 3]), [
        *connect('placeholder1', '0:mul'),
        *connect('placeholder2', '1:mul'),
        *connect('mul', 'result'),
    ], nodes_with_edges_only=True, cli=Namespace(reverse_input_channels=True))

    param = Node(graph, 'placeholder1')
    old_api_map = OldAPIMapOrder(version=0)
    rt_key = ('old_api_map_order', old_api_map.get_version())
    param.rt_info.info[rt_key] = old_api_map
    param.rt_info.info[rt_key].old_api_transpose_parameter([0, 2, 3, 1])

    InsertReverseChannels().find_and_replace_pattern(graph)

    graph_ref = build_graph(get_nodes([1, 10, 10, 3], 3), [
        *connect('placeholder1', 'reverse_channels'),
        *connect('reverse_channels', '0:mul'),
        *connect('placeholder2', '1:mul'),
        *connect('mul', 'result'),
    ])
    Node(graph_ref, 'placeholder1').rt_info = param.rt_info

    flag, resp = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_identityN_unused_ports(self):
    """IdentityN with only output 0 consumed becomes a single Identity.

    The reference contains only placeholder_0 -> identity0 -> output0.
    """
    graph = build_graph(nodes, [
        *connect('placeholder_0', '0:identityN'),
        *connect('placeholder_1', '1:identityN'),
        *connect('identityN:0', 'output0'),
    ], nodes_with_edges_only=True)

    IdentityN_to_Identity().find_and_replace_pattern(graph)

    expected = [
        *connect('placeholder_0', 'identity0'),
        *connect('identity0', 'output0'),
    ]
    graph_ref = build_graph(nodes, expected, nodes_with_edges_only=True)

    flag, resp = compare_graphs(graph, graph_ref, 'output0', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_3(self):
    """SliceLikeToStridedSlice with ``axes`` = [1, 2].

    The input is [1, 224, 224, 3] and the shape-like input is
    [2, 2, 2, 2, 2]. The expected topology is ``input_part_shape_edges``.
    """
    overrides = {
        'input_data': {'shape': int64_array([1, 224, 224, 3])},
        'shape_like_input_data': {'shape': int64_array([2, 2, 2, 2, 2])},
        'slice_like': {'axes': int64_array([1, 2])},
    }
    graph = build_graph(nodes_attributes, edges, update_attributes=overrides,
                        nodes_with_edges_only=True)

    SliceLikeToStridedSlice().find_and_replace_pattern(graph)

    ref_graph = build_graph(nodes_attributes, input_part_shape_edges,
                            nodes_with_edges_only=True)
    flag, resp = compare_graphs(graph, ref_graph, 'result')
    self.assertTrue(flag, resp)
def test_ScatterElementsUpdate_has_axis_and_3_inputs(self):
    """ScatterNormalizer turns the ``axis`` attribute into an input on port 3.

    The graph node carries ``axis`` = 1. The reference adds an ``axis``
    constant with value ``np.int64(1)`` connected to port 3.
    """
    graph = build_graph(nodes, edges, {'node': {'axis': 1}}, nodes_with_edges_only=True)

    ScatterNormalizer().find_and_replace_pattern(graph)

    expected_edges = [*edges, *connect('axis', '3:node')]
    graph_ref = build_graph(nodes, expected_edges, {'axis': {'value': np.int64(1)}},
                            nodes_with_edges_only=True)

    flag, resp = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_gelu_p3(self):
    """GeLUMergerErf fuses this erf-based sub-graph into a single Gelu node.

    After the pass the graph must match ``ref_nodes``/``ref_edges``, and the
    Gelu node must use ``approximation_mode == 'erf'``. Exactly one node is
    named ``final_mul``, and it must be the Gelu op.
    """
    # NOTE(review): in this edge list 'div' receives only 'div_param' (no edge
    # from 'input'), and 'mul' receives three producers (input, add, mul_param).
    # Confirm this is the intended wiring of pattern 3 and not a mis-wired edge.
    edges = [('input', 'mul'),
             ('div', 'erf'),
             ('erf', 'add'),
             ('add', 'mul'),
             ('mul', 'mul0'),
             ('mul_param', 'mul'),
             ('div_param', 'div'),
             ('add_param', 'add'),
             ('mul0', 'result')]
    graph = build_graph(self.nodes, edges)
    graph_ref = build_graph(ref_nodes, ref_edges)
    graph.stage = 'front'

    GeLUMergerErf().find_and_replace_pattern(graph)
    graph.clean_up()

    flag, resp = compare_graphs(graph, graph_ref, 'result')
    self.assertTrue(flag, resp)

    # Separate assertions so a failure names the exact condition that broke,
    # instead of one opaque boolean.
    gelu_nodes = graph.get_op_nodes(op='Gelu')
    self.assertTrue(gelu_nodes, 'no Gelu node was created')
    self.assertEqual(gelu_nodes[0].approximation_mode, 'erf')

    final_mul_nodes = graph.get_op_nodes(name='final_mul')
    self.assertEqual(len(final_mul_nodes), 1)
    self.assertEqual(final_mul_nodes[0].op, 'Gelu')
def test_case6_dest(self):
    """Re-target Op1's input connection to Op2's input with mode ``"dest"``.

    In the reference, input_data feeds Op2 directly and carries
    ``fw_tensor_debug_info`` = [('Op1', 'Op1')]. Op1 -> Op1_data remains.
    """
    graph = build_graph(nodes, [
        ('input', 'input_data'),
        ('input_data', 'Op1'),
        ('Op1', 'Op1_data'),
        ('Op1_data', 'Op2'),
    ])
    graph_ref = build_graph(nodes, [
        ('input', 'input_data'),
        ('input_data', 'Op2'),
        ('Op1', 'Op1_data'),
    ])
    Node(graph_ref, 'input_data')['fw_tensor_debug_info'] = [('Op1', 'Op1')]

    op1 = Node(graph, 'Op1')
    connection = op1.in_port(0).get_connection()
    new_destination = op1.out_port(0).get_destination()
    connection.set_destination(new_destination, "dest")

    flag, resp = compare_graphs(graph, graph_ref, 'Op2', check_op_attrs=True)
    self.assertTrue(flag, resp)
    self.check_graph_attrs_middle(graph, graph_ref)
def test_ifft_replacement(self, input_shape):
    """MXFFTToDFT converts an ``fft`` node with ``is_inverse=True``.

    ``input_shape`` comes from the test parametrisation (not visible here).
    The expected result is the ``ref_converted_ifft_graph_*`` reference.
    """
    graph = build_graph(nodes_attrs=fft_graph_node_attrs,
                        edges=fft_graph_edges,
                        update_attributes={'placeholder': {'shape': input_shape},
                                           'fft': {'is_inverse': True}})
    graph.stage = 'front'

    MXFFTToDFT().find_and_replace_pattern(graph)

    ref_graph = build_graph(nodes_attrs=ref_converted_ifft_graph_node_attrs,
                            edges=ref_converted_ifft_graph_edges,
                            update_attributes={'placeholder': {'shape': input_shape}})

    flag, resp = compare_graphs(graph, ref_graph, 'output')
    self.assertTrue(flag, resp)
def test_pool_v2_to_attributed_pool(self):
    """PoolV2ToAttributedPool turns PoolingV2 into an attributed Pooling op.

    The reference is input -> pool_v1 (type Pooling) -> output. The
    ``windows``/``strides`` constant inputs are no longer connected.
    """
    def shared_pool_attrs():
        # Fresh dict/lists per call so the two pool nodes share no objects.
        return {'pad': [2, 2], 'spatial_dims': [1, 2], 'auto_pad': 'same_upper',
                'output_spatial_shape': [2, 3], 'pad_spatial_shape': [1, 2],
                'pool_method': 'max'}

    nodes = {
        **shaped_const_with_data('input', int64_array([200, 200])),
        **valued_const_with_data('windows', int64_array([4, 4])),
        **valued_const_with_data('strides', int64_array([4, 4])),
        **regular_op_with_empty_data('pool_v2', {'op': 'PoolingV2', **shared_pool_attrs(),
                                                 'permute_attrs': None}),
        **regular_op_with_empty_data('pool_v1', {'type': 'Pooling', **shared_pool_attrs()}),
        **result('output'),
    }

    graph = build_graph(nodes, [
        *connect('input', 'pool_v2:0'),
        *connect('windows', 'pool_v2:1'),
        *connect('strides', 'pool_v2:2'),
        *connect('pool_v2', 'output'),
    ], nodes_with_edges_only=True)

    PoolV2ToAttributedPool().find_and_replace_pattern(graph)

    ref_graph = build_graph(nodes, [*connect('input', 'pool_v1'), *connect('pool_v1', 'output')],
                            nodes_with_edges_only=True)
    flag, resp = compare_graphs(graph, ref_graph, 'output')
    self.assertTrue(flag, resp)