def test_two_nodes_one_bin(self):
    """CreateConstNodesReplacement where data_node has two consumers but only one 'bin' edge.

    The edge data_node -> next_node carries ``bin=0``; the edge data_node -> next_node_2
    has no 'bin' attribute. The transformed graph is compared with a reference that also
    contains ``self.new_nodes`` / ``self.new_edges`` and a populated 'const_data' node.
    """
    tensor_shape = np.array([2, 3, 4])
    tensor_value = np.zeros(tensor_shape)
    second_consumer_edge = ('data_node', 'next_node_2')

    graph = build_graph_with_attrs(
        nodes_with_attrs=self.nodes + [('next_node_2', {'kind': 'op'})],
        edges_with_attrs=self.edges + [second_consumer_edge],
        update_nodes_attributes=[('data_node', {'shape': tensor_shape, 'value': tensor_value})],
        update_edge_attrs={('data_node', 'next_node', 0): {'bin': 0}},
    )

    expected_nodes = self.nodes + self.new_nodes + [('next_node_2', {'kind': 'op'})]
    expected_edges = self.edges + self.new_edges + [second_consumer_edge]
    graph_ref = build_graph_with_attrs(
        nodes_with_attrs=expected_nodes,
        edges_with_attrs=expected_edges,
        update_nodes_attributes=[
            ('data_node', {'shape': tensor_shape, 'value': tensor_value}),
            ('const_data', {'shape': tensor_shape, 'value': tensor_value}),
        ],
    )

    CreateConstNodesReplacement().find_and_replace_pattern(graph)

    flag, resp = compare_graphs(graph, graph_ref, last_node='next_node')
    self.assertTrue(flag, resp)
def test_select_infer_condition_true(self):
    """Select.infer with condition value [True]; 'select_output' is compared with a 2x2 ones reference."""
    def output_attrs():
        # Fresh arrays for every call so the tested graph and the reference never share objects.
        return {'shape': np.array([2, 2]), 'value': np.ones((2, 2))}

    graph = build_graph_with_attrs(nodes_with_attrs=self.nodes,
                                   edges_with_attrs=self.edges,
                                   update_nodes_attributes=[
                                       ('condition', {'value': np.array([True])}),
                                       ('select_output', output_attrs()),
                                   ])
    # Expected: shapes and values are propagated to the output.
    graph_ref = build_graph_with_attrs(nodes_with_attrs=self.nodes,
                                       edges_with_attrs=self.edges,
                                       update_nodes_attributes=[('select_output', output_attrs())])

    select_op = Select(graph=graph, attrs={})
    select_node = Node(graph, 'select')
    select_op.infer(select_node)

    flag, resp = compare_graphs(graph, graph_ref, 'select_output', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_select_infer_condition_with_value(self, else_data_shape, than_data_shape, select_output_shape,
                                           condition_value, else_value, than_value, output_value):
    """Select.infer with constant inputs must compute ``output_value`` on 'select_output'.

    The full graph comparison runs only when both branch values are provided; the
    computed output value is always compared with ``np.array_equal``.
    """
    def data_attrs(shape, value):
        return {'shape': np.array(shape), 'value': value}

    graph = build_graph_with_attrs(
        nodes_with_attrs=self.nodes, edges_with_attrs=self.edges,
        update_nodes_attributes=[
            ('condition_data', data_attrs(select_output_shape, condition_value)),
            ('else_data', data_attrs(else_data_shape, else_value)),
            ('than_data', data_attrs(than_data_shape, than_value)),
            ('select_output', data_attrs(select_output_shape, None)),
        ])
    # NOTE(review): the reference uses else_data_shape for 'condition_data' while the tested
    # graph uses select_output_shape -- confirm this asymmetry is intended.
    graph_ref = build_graph_with_attrs(
        nodes_with_attrs=self.nodes, edges_with_attrs=self.edges,
        update_nodes_attributes=[
            ('condition_data', data_attrs(else_data_shape, condition_value)),
            ('else_data', data_attrs(else_data_shape, else_value)),
            ('than_data', data_attrs(than_data_shape, than_value)),
            ('select_output', data_attrs(select_output_shape, output_value)),
        ])

    Select.infer(Node(graph, 'select'))

    both_branches_known = else_value is not None and than_value is not None
    if both_branches_known:
        flag, resp = compare_graphs(graph, graph_ref, 'select_output', check_op_attrs=True)
        self.assertTrue(flag, resp)
    actual_value = graph.nodes['select_output']['value']
    expected_value = graph_ref.nodes['select_output']['value']
    self.assertTrue(np.array_equal(actual_value, expected_value))
def test_two_consumers(self):
    """Const data feeds both a Result and a ReLU; RemoveConstToResult must drop only the Result."""
    def make_nodes(with_result):
        # Node order matches the original lists; fresh attribute dicts on every call.
        node_list = [('const_node', {'type': 'Const', 'kind': 'op'}),
                     ('const_data', {'kind': 'data'})]
        if with_result:
            node_list.append(('result_node', {'type': 'Result', 'kind': 'op'}))
        node_list += [('relu_1', {'type': 'ReLU', 'kind': 'op', 'op': 'ReLU'}),
                      ('relu_1_data', {'kind': 'data'})]
        return node_list

    def make_edges(with_result):
        edge_list = [('const_node', 'const_data')]
        if with_result:
            edge_list.append(('const_data', 'result_node'))
        edge_list += [('const_data', 'relu_1'), ('relu_1', 'relu_1_data')]
        return edge_list

    graph = build_graph_with_attrs(
        nodes_with_attrs=make_nodes(with_result=True),
        edges_with_attrs=make_edges(with_result=True),
    )
    graph_ref = build_graph_with_attrs(
        nodes_with_attrs=make_nodes(with_result=False),
        edges_with_attrs=make_edges(with_result=False),
    )

    RemoveConstToResult().find_and_replace_pattern(graph)

    flag, resp = compare_graphs(graph, graph_ref, last_node='relu_1_data')
    self.assertTrue(flag, resp)
    self.assertNotIn('result_node', graph.node)
def test_no_exit(self):
    """Apply BackEdgesMatching to its own pattern extended with a 'from_body_data' input.

    The reference expects a single TensorIteratorBackEdge op fed by 'enter_data' (in port 0),
    'from_body_data' (in port 1) and the TensorIteratorCondition output (in port 2),
    producing 'Identity_1_data'.
    """
    pattern_matcher = BackEdgesMatching()
    pattern = pattern_matcher.pattern()
    # The input graph is the matcher's own pattern plus an extra body input to NextIteration.
    graph = build_graph_with_attrs(nodes_with_attrs=pattern['nodes'], edges_with_attrs=pattern['edges'],
                                   update_edge_attrs=None,
                                   new_nodes_with_attrs=[('from_body_data', {'kind':'data'})],
                                   new_edges_with_attrs=[('from_body_data', 'NextIteration')])

    pattern_matcher.find_and_replace_pattern(graph)

    graph_ref = build_graph_with_attrs(nodes_with_attrs=[('condition', {'kind': 'op', 'op':'TensorIteratorCondition'}),
                                                         ('condition_data', {'kind': 'data'}),
                                                         ('back_edge', {'kind': 'op', 'op': 'TensorIteratorBackEdge'}),
                                                         ('enter_data', {'kind': 'data'}),
                                                         ('from_body_data', {'kind': 'data'}),
                                                         ('Identity_1_data', {'kind': 'data'}),],
                                       edges_with_attrs=[('condition', 'condition_data'),
                                                         ('enter_data', 'back_edge', {'in': 0}),
                                                         # Condition output goes to in port 2 of the back edge.
                                                         ('condition_data', 'back_edge', {'in': 2}),
                                                         ('from_body_data', 'back_edge', {'in': 1}),
                                                         ('back_edge', 'Identity_1_data')],
                                       update_edge_attrs=None,
                                       new_nodes_with_attrs=[],
                                       new_edges_with_attrs=[],
                                       )
    (flag, resp) = compare_graphs(graph, graph_ref, 'Identity_1_data', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_select_infer_condition_shapes_broadcast(self, else_data_shape, than_data_shape, select_output_shape):
    """Select.infer with broadcastable branch shapes must yield ``select_output_shape``.

    All inputs are zero-filled float64 tensors, so the expected output value is zeros too.
    """
    # ``np.float`` was merely an alias of the builtin ``float``. It was deprecated in NumPy 1.20
    # and removed in 1.24, where it raises AttributeError. ``np.float64`` is the dtype it always
    # resolved to.
    graph = build_graph_with_attrs(nodes_with_attrs=self.nodes, edges_with_attrs=self.edges,
                                   update_nodes_attributes=[
                                       ('else_data', {'shape': np.array(else_data_shape),
                                                      'value': np.zeros(else_data_shape, dtype=np.float64)}),
                                       ('than_data', {'shape': np.array(than_data_shape),
                                                      'value': np.zeros(than_data_shape, dtype=np.float64)}),
                                       ('select_output', {'shape': np.array(select_output_shape),
                                                          'value': np.zeros(select_output_shape, dtype=np.float64)}),
                                   ])

    # We should propagate shapes and values
    graph_ref = build_graph_with_attrs(nodes_with_attrs=self.nodes, edges_with_attrs=self.edges,
                                       update_nodes_attributes=[
                                           ('else_data', {'shape': np.array(else_data_shape),
                                                          'value': np.zeros(else_data_shape, dtype=np.float64)}),
                                           ('than_data', {'shape': np.array(than_data_shape),
                                                          'value': np.zeros(than_data_shape, dtype=np.float64)}),
                                           ('select_output', {'shape': np.array(select_output_shape),
                                                              'value': np.zeros(select_output_shape)}),
                                       ])

    tested_class = Select(graph=graph, attrs={})
    node = Node(graph, 'select')
    tested_class.infer(node)

    (flag, resp) = compare_graphs(graph, graph_ref, 'select_output', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_in_port_with_data(self):
    """add_input_op(data=True) on in-port 1 of 'op_node' must match a Parameter + data-node reference."""
    new_input_shape = np.array([1, 2, 3, 4])
    graph = build_graph_with_attrs(nodes_with_attrs=self.nodes, edges_with_attrs=self.edges)

    parameter_nodes = [
        ('input_node', {'kind': 'op', 'op': 'Parameter', 'shape': new_input_shape}),
        ('input_data', {'kind': 'data'}),
    ]
    parameter_edges = [
        ('input_node', 'input_data', {'in': 0, 'out': 0}),
        ('input_data', 'op_node', {'in': 1, 'out': 0}),
    ]
    # The reference drops the first original edge and wires in the new Parameter instead.
    graph_ref = build_graph_with_attrs(nodes_with_attrs=self.nodes,
                                       edges_with_attrs=self.edges[1:],
                                       new_nodes_with_attrs=parameter_nodes,
                                       new_edges_with_attrs=parameter_edges)

    add_input_op(graph, 'op_node', 1, data=True, shape=new_input_shape)
    graph.remove_edge('future_input', 'op_node')

    flag, resp = compare_graphs(graph, graph_ref, last_node='op_node')
    self.assertTrue(flag, resp)
def test_4D_multiple_consumers(self):
    """L2NormToNorm on a 4D input whose data node also feeds a second Result.

    The fused NormalizeL2 node must keep the name 'l2_norm_name', and the rest of the
    graph (including the extra Result consumer) must match the reference.
    """
    input_shape = int64_array([1, 300, 300, 3])
    axes = int64_array([1, 2, 3])
    graph = build_graph_with_attrs(nodes + [
        ('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
        ('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
        ('square_data', dict(kind='data', shape=input_shape)),
        ('sum_axes_data', dict(kind='data', value=axes, shape=None)),
        ('result_2', dict(kind='op', op='Result'))
    ], edges + [('input_data', 'result_2')], nodes_with_edges_only=True)
    graph.stage = 'middle'

    L2NormToNorm().find_and_replace_pattern(graph)

    # ``axes.sort()`` sorts in place and returns None, so the reference used to hold value=None
    # for the axes constant. ``np.sort`` returns the sorted copy that was intended and leaves
    # ``axes`` untouched.
    graph_ref = build_graph_with_attrs(nodes + [
        ('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
        ('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
        ('weights_node_data', dict(kind='data', value=np.sort(axes))),
        ('result_2', dict(kind='op', op='Result'))
    ], edges_after_replacement + [('input_data', 'result_2')], nodes_with_edges_only=True)

    (flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
    # Check the structure first: if no NormalizeL2 node exists, the lookup below would
    # otherwise hide the graph diff behind an IndexError.
    self.assertTrue(flag, resp)
    normalize_nodes = graph.get_nodes_with_attributes(type='NormalizeL2')
    self.assertEqual(graph.nodes[normalize_nodes[0]]['name'], 'l2_norm_name')
def test_out_port_no_data(self):
    """add_input_op on out-port 1 of 'op_node' without a data node; compare from both graph ends."""
    new_input_shape = np.array([1, 2, 3, 4])
    graph = build_graph_with_attrs(nodes_with_attrs=self.nodes_out, edges_with_attrs=self.edges_out)

    graph_ref = build_graph_with_attrs(
        nodes_with_attrs=self.nodes_out,
        edges_with_attrs=self.edges_out[1:],
        new_nodes_with_attrs=[('input_node', {'kind': 'op', 'op': 'Parameter', 'shape': new_input_shape})],
        new_edges_with_attrs=[('input_node', 'future_input', {'in': 0, 'out': 0})],
    )

    add_input_op(graph, 'op_node', 1, data=False, shape=new_input_shape, is_out_port=True)
    graph.remove_edge('op_node', 'future_input')

    # Traverse from each end so both halves of the split graph are compared.
    for last_node in ('another_node', 'future_input'):
        flag, resp = compare_graphs(graph, graph_ref, last_node=last_node)
        self.assertTrue(flag, resp)
def test_merge_infer_simple_case_one_executable(self):
    """Merge.merge_infer with a single executable input must propagate that input's value."""
    graph = build_graph_with_attrs(nodes_with_attrs=self.nodes, edges_with_attrs=self.edges)

    # Only the first input is executable, so its shape and value are expected on the output
    # and the Merge node is marked as fully inferred.
    expected_updates = [
        ('merge_output', {'shape': np.array([2, 2]), 'value': np.ones((2, 2))}),
        ('merge', {'is_not_fully_inferred': False}),
    ]
    graph_ref = build_graph_with_attrs(nodes_with_attrs=self.nodes,
                                       edges_with_attrs=self.edges,
                                       update_nodes_attributes=expected_updates)

    merge_op = Merge(graph=graph, attrs={})
    merge_node = Node(graph, 'merge')
    merge_op.merge_infer(merge_node)

    flag, resp = compare_graphs(graph, graph_ref, 'merge_output', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_force_precision_parameter(self):
    """'force_precision' on data_node must reach the created Const op and its output data node."""
    precision = 'FP16'
    shape = np.array([2, 3, 4])
    data = np.zeros(shape)

    graph = build_graph_with_attrs(
        nodes_with_attrs=self.nodes,
        edges_with_attrs=self.edges,
        update_nodes_attributes=[('data_node', {'shape': shape, 'value': data, 'force_precision': precision})],
    )
    graph_ref = build_graph_with_attrs(
        nodes_with_attrs=self.nodes + self.new_nodes,
        edges_with_attrs=self.edges + self.new_edges,
        update_nodes_attributes=[
            ('data_node', {'shape': shape, 'value': data}),
            ('const_data', {'shape': shape, 'value': data, 'force_precision': precision}),
            ('const', {'force_precision': precision}),
        ],
    )

    CreateConstNodesReplacement().find_and_replace_pattern(graph)

    flag, resp = compare_graphs(graph, graph_ref, last_node='next_node')
    self.assertTrue(flag, resp)

    # Check that force_precision was added to the new Const node and to its data node.
    const_precision, copy_precision = (graph.nodes[name]['force_precision']
                                       for name in ('data_node_const', 'data_node_copy_'))
    self.assertEqual(const_precision, precision)
    self.assertEqual(copy_precision, precision)
def test_switch_infer_no_condition(self):
    """Switch.infer without a known predicate value must propagate only the input shape."""
    def data_node(**extra):
        # Executable data node with no value; extra attributes (e.g. shape) are appended.
        attrs = {'value': None, 'kind': 'data', 'executable': True}
        attrs.update(extra)
        return attrs

    nodes = [
        ('tensor', data_node(shape=np.array([1, 2, 1]))),
        ('pred_id', data_node()),
        ('switch', {'type': 'Switch', 'kind': 'op', 'op': 'Switch'}),
        ('switch_data_0', data_node()),
        ('switch_data_1', data_node()),
    ]
    edges = ([(src, 'switch', {'in': port}) for port, src in enumerate(('tensor', 'pred_id'))] +
             [('switch', dst, {'out': port}) for port, dst in enumerate(('switch_data_0', 'switch_data_1'))])

    graph = build_graph_with_attrs(nodes_with_attrs=nodes, edges_with_attrs=edges)

    # Expected: both outputs receive the input shape; values stay unset.
    graph_ref = build_graph_with_attrs(nodes_with_attrs=nodes, edges_with_attrs=edges,
                                       update_nodes_attributes=[
                                           ('switch_data_0', {'shape': np.array([1, 2, 1])}),
                                           ('switch_data_1', {'shape': np.array([1, 2, 1])}),
                                       ])

    switch_op = Switch(graph=graph, attrs={})
    switch_node = Node(graph, 'switch')
    switch_op.infer(switch_node)

    flag, resp = compare_graphs(graph, graph_ref, 'switch_data_0', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_positive_matmul_infer(self, A_shape, B_shape, C_shape, transpose_a, transpose_b):
    """MatMul.infer must set the output shape 'mat_mul_d' to ``C_shape``.

    Parameters are the input shapes, the expected output shape and the transpose flags
    written to the 'mat_mul' node before inference.
    """
    graph = build_graph_with_attrs(nodes_with_attrs=self.nodes, edges_with_attrs=self.edges,
                                   update_nodes_attributes=[
                                       ('A_d', {'shape': int64_array(A_shape)}),
                                       ('B_d', {'shape': int64_array(B_shape)}),
                                       ('mat_mul', {'transpose_a': transpose_a, 'transpose_b': transpose_b}),
                                   ])
    node = Node(graph, 'mat_mul')
    MatMul.infer(node)

    # The two literals were previously joined without a separator
    # ("transpose_b=Falseexpexted_shape=..."), and "expected" was misspelled.
    msg = "MatMul infer failed for case: A_shape={}, B_shape={}, transpose_a={}, transpose_b={}, " \
          "expected_shape={}, actual_shape={}"

    actual_shape = graph.nodes['mat_mul_d']['shape']
    self.assertTrue(
        np.array_equal(actual_shape, int64_array(C_shape)),
        msg.format(A_shape, B_shape, transpose_a, transpose_b, C_shape, actual_shape))
def test_two_nodes_with_bin(self):
    """Both consumers read data_node through 'bin' edges: the replacement must change nothing.

    Previously the graph was compared with itself (``compare_graphs(graph, graph, ...)``),
    which can never fail. It is now compared with an independently built copy that the
    transformation never touches.
    """
    shape = np.array([2, 3, 4])

    def build():
        # Fresh data array per graph so the reference cannot observe in-place changes.
        return build_graph_with_attrs(
            nodes_with_attrs=self.nodes + [('next_node_2', {'kind': 'op'})],
            edges_with_attrs=self.edges + [('data_node', 'next_node_2')],
            update_nodes_attributes=[('data_node', {'shape': shape, 'value': np.zeros(shape)})],
            update_edge_attrs={
                ('data_node', 'next_node', 0): {'bin': 0},
                ('data_node', 'next_node_2', 0): {'bin': 0},
            },
        )

    graph = build()
    graph_ref = build()

    tested_pattern = CreateConstNodesReplacement()
    tested_pattern.find_and_replace_pattern(graph)

    (flag, resp) = compare_graphs(graph, graph_ref, last_node='next_node')
    self.assertTrue(flag, resp)
def test_only_consumer(self):
    """Result is the only consumer of the Const data: the whole Const -> Result chain is removed."""
    def relu_chain_nodes():
        # The Parameter -> ReLU chain is identical before and after the transformation.
        return [
            ('placeholder_1', {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'}),
            ('placeholder_1_data', {'kind': 'data'}),
            ('relu_1', {'type': 'ReLU', 'kind': 'op', 'op': 'ReLU'}),
            ('relu_1_data', {'kind': 'data'}),
        ]

    relu_chain_edges = [
        ('placeholder_1', 'placeholder_1_data'),
        ('placeholder_1_data', 'relu_1'),
        ('relu_1', 'relu_1_data'),
    ]

    const_nodes = [
        ('const_node', {'type': 'Const', 'kind': 'op'}),
        ('const_data', {'kind': 'data', 'value': np.array(5)}),
        ('result_node', {'type': 'Result', 'kind': 'op'}),
    ]
    const_edges = [('const_node', 'const_data'), ('const_data', 'result_node')]

    graph = build_graph_with_attrs(
        nodes_with_attrs=const_nodes + relu_chain_nodes(),
        edges_with_attrs=const_edges + relu_chain_edges,
    )
    graph_ref = build_graph_with_attrs(
        nodes_with_attrs=relu_chain_nodes(),
        edges_with_attrs=list(relu_chain_edges),
    )

    RemoveConstToResult().find_and_replace_pattern(graph)

    flag, resp = compare_graphs(graph, graph_ref, last_node='relu_1_data')
    self.assertTrue(flag, resp)
    for removed in ('const_node', 'const_data', 'result_node'):
        self.assertNotIn(removed, graph.node)
def test_single_consumer(self):
    """L2NormToNorm on the base pattern: the fused Normalize node keeps the name 'l2_norm_name'."""
    graph = build_graph_with_attrs(nodes, edges, nodes_with_edges_only=True)
    graph.stage = 'middle'

    L2NormToNorm().find_and_replace_pattern(graph)

    graph_ref = build_graph_with_attrs(nodes, edges_after_replacement, nodes_with_edges_only=True)

    (flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
    # Check the structure first: if no Normalize node exists, the lookup below would
    # otherwise hide the graph diff behind an IndexError.
    self.assertTrue(flag, resp)
    normalize_nodes = graph.get_nodes_with_attributes(type='Normalize')
    self.assertEqual(graph.nodes[normalize_nodes[0]]['name'], 'l2_norm_name')
def test_not_dynamic(self):
    """Apply LoopConditionMatcher to its own pattern with constant init/step values.

    The init and add-step data nodes get constant values, while the loop condition, the
    second Identity and both Enter-less data nodes have value None. The reference expects
    the loop to collapse into a TensorIteratorCondition fed by StridedSlice_data (in port 0)
    and minimum_data (in port 1), with its identity output feeding TensorIteratorInput.
    """
    pattern_matcher = LoopConditionMatcher()
    pattern = pattern_matcher.pattern()

    # Extra nodes outside the pattern: a Maximum op and a TensorIteratorInput consumer of Identity_1_data.
    graph = build_graph_with_attrs(nodes_with_attrs=pattern['nodes'], edges_with_attrs=pattern['edges'],
                                   new_nodes_with_attrs=[('maximum', {'kind': 'op', 'op': 'Maximum'}),
                                                         ('maximum_data', {'kind': 'data'}),
                                                         ('TensorIteratorInput', {'kind': 'op', 'op': 'TensorIteratorInput'})],
                                   new_edges_with_attrs=[('maximum', 'maximum_data'),
                                                         ('Identity_1_data', 'TensorIteratorInput')],
                                   update_nodes_attributes=[('init_1_data', {'value': np.array([0])}),
                                                            ('init_2_data', {'value': np.array([0])}),
                                                            ('add_1_y_data', {'value': np.array(1)}),
                                                            ('add_2_y_data', {'value': np.array(1)}),
                                                            ('loop_cond_data', {'value': None}),
                                                            ('Identity_2_data', {'value': None}, ),
                                                            ('Enter_1_less_data', {'value': None},),
                                                            ('Enter_2_less_data', {'value': None},),
                                                            ])

    pattern_matcher.find_and_replace_pattern(graph)

    graph_ref = build_graph_with_attrs(
        nodes_with_attrs=[('TensorIteratorCondition', {'kind': 'op', 'op': 'TensorIteratorCondition'}),
                          ('loop_cond_data', {'kind': 'data'}),
                          ('identity_data', {'kind': 'data'}),
                          ('StridedSlice', {'kind': 'op', 'op':'StridedSlice'}),
                          ('StridedSlice_data', {'kind': 'data'}),
                          ('Maximum', {'kind': 'op', 'op': 'Maximum'}),
                          ('Maximum_data', {'kind': 'data'}),
                          ('minimum_data', {'kind': 'data'}),
                          ('TensorIteratorInput', {'kind': 'op', 'op': 'TensorIteratorInput'})
                          ],
        edges_with_attrs=[('Maximum', 'Maximum_data'),
                          ('StridedSlice', 'StridedSlice_data'),
                          ('StridedSlice_data', 'TensorIteratorCondition', {'in':0}),
                          ('minimum_data', 'TensorIteratorCondition', {'in':1}),
                          ('TensorIteratorCondition', 'loop_cond_data'),
                          ('TensorIteratorCondition', 'identity_data'),
                          ('identity_data', 'TensorIteratorInput'),
                          ],
        update_edge_attrs=None,
        new_nodes_with_attrs=[],
        new_edges_with_attrs=[],
    )
    # Traverse the comparison from the loop condition output.
    (flag, resp) = compare_graphs(graph, graph_ref, 'loop_cond_data', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_from4D_to3D(self):
    """PermuteForReshape on an NHWC graph reshaping [1, 2, 3, 4] into 3D [3, 4, 2].

    Expects the reference structure (``self.permute_nodes`` / ``self.permute_edges``) and a
    permutation node 'reshape/Permute_' whose order is [0, 2, 3, 1] (NCHW -> NHWC).
    """
    input_shape = np.array([1, 2, 3, 4])
    new_shape = np.array([3, 4, 2])
    nhwc_shape = np.array([1, 3, 4, 2])

    def shape_updates(*extra):
        # Shape/dim attributes shared by the tested graph and the reference.
        return [('input_data', {'shape': input_shape}),
                ('reshape', {'dim': new_shape}),
                ('reshape_data', {'shape': new_shape})] + list(extra)

    graph = build_graph_with_attrs(nodes_with_attrs=self.nodes,
                                   edges_with_attrs=self.edges,
                                   update_nodes_attributes=shape_updates())
    graph.graph['layout'] = 'NHWC'

    # Attach permute attributes for 'dim' to the reshape node before running the pass.
    PermuteAttrs.create_permute_attrs(Node(graph, 'reshape'), attrs=[('dim', 'output:0')])

    PermuteForReshape().find_and_replace_pattern(graph)

    graph_ref = build_graph_with_attrs(
        nodes_with_attrs=self.nodes + self.permute_nodes,
        edges_with_attrs=self.edges[1:] + self.permute_edges,
        update_nodes_attributes=shape_updates(('permute_data', {'shape': nhwc_shape})))

    flag, resp = compare_graphs(graph, graph_ref, last_node='reshape_data')
    self.assertTrue(flag, resp)

    # Check the right order in the new permutation node: from NCHW to NHWC.
    permute_order = graph.node['reshape/Permute_']['order']
    self.assertTrue(np.all(permute_order == np.array([0, 2, 3, 1])))
def test_select_infer_assert_shapes(self):
    """Select.infer must raise AssertionError when the 'else_data' shape does not broadcast."""
    non_broadcastable_else = ('else_data', {'shape': np.array([3, 3]), 'value': np.zeros((3, 3))})
    graph = build_graph_with_attrs(nodes_with_attrs=self.nodes,
                                   edges_with_attrs=self.edges,
                                   update_nodes_attributes=[non_broadcastable_else])

    select_op = Select(graph=graph, attrs={})
    select_node = Node(graph, 'select')

    with self.assertRaisesRegex(AssertionError, "Input shape do not broadcast"):
        select_op.infer(select_node)
def test_merge_infer_complex_case(self):
    """Merge.merge_infer inside a cycle: first visit sees one inferred input, second sees both.

    On the first call only the output shape is propagated and Merge stays marked as not
    fully inferred. After 'first' is flagged as partially inferred, a second call must
    clear that flag.
    """
    def reference(first_inferred, not_fully_inferred):
        return build_graph_with_attrs(nodes_with_attrs=self.nodes, edges_with_attrs=self.edges,
                                      update_nodes_attributes=[
                                          ('second', {'executable': True}),
                                          ('first', {'is_partial_inferred': first_inferred, 'value': None}),
                                          ('merge_output', {'shape': np.array([2, 2]), 'value': None}),
                                          ('merge', {'is_not_fully_inferred': not_fully_inferred}),
                                      ])

    graph = build_graph_with_attrs(nodes_with_attrs=self.nodes, edges_with_attrs=self.edges,
                                   update_nodes_attributes=[
                                       ('first', {'is_partial_inferred': False, 'value': None}),
                                       ('second', {'executable': True}),
                                   ])

    def run_merge_infer():
        merge_op = Merge(graph=graph, attrs={})
        merge_op.merge_infer(Node(graph, 'merge'))

    # First visit: only shapes can be propagated.
    graph_ref = reference(first_inferred=False, not_fully_inferred=True)
    run_merge_infer()
    flag, resp = compare_graphs(graph, graph_ref, 'merge_output', check_op_attrs=True)
    self.assertTrue(flag, resp)

    # Pretend the 'first' input has been inferred in the meantime, then visit Merge again.
    graph.node['first']['is_partial_inferred'] = True
    run_merge_infer()
    graph_ref = reference(first_inferred=True, not_fully_inferred=False)
    flag, resp = compare_graphs(graph, graph_ref, 'merge_output', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_1(self):
    """Acyclic case: AddIsCyclicAttribute must set graph.graph['is_cyclic'] to False."""
    graph = build_graph_with_attrs(nodes_with_attrs=self.nodes, edges_with_attrs=self.edges)
    tested_pass = AddIsCyclicAttribute()
    tested_pass.find_and_replace_pattern(graph)

    # A bare ``assert`` is stripped under ``python -O`` and gives no unittest diagnostics.
    self.assertIs(graph.graph['is_cyclic'], False)
def test_one_node(self):
    """A single constant data node must receive a Const op and a 'const_data' data node."""
    tensor_shape = np.array([2, 3, 4])
    tensor_value = np.zeros(tensor_shape)

    def data_attrs():
        return {'shape': tensor_shape, 'value': tensor_value}

    graph = build_graph_with_attrs(
        nodes_with_attrs=self.nodes,
        edges_with_attrs=self.edges,
        update_nodes_attributes=[('data_node', data_attrs())],
    )
    graph_ref = build_graph_with_attrs(
        nodes_with_attrs=self.nodes + self.new_nodes,
        edges_with_attrs=self.edges + self.new_edges,
        update_nodes_attributes=[('data_node', data_attrs()), ('const_data', data_attrs())],
    )

    CreateConstNodesReplacement().find_and_replace_pattern(graph)

    flag, resp = compare_graphs(graph, graph_ref, last_node='next_node')
    self.assertTrue(flag, resp)
def test_select_infer_assert_shapes(self):
    """TF Select.infer must raise AssertionError when 'then' and 'else' shapes differ."""
    mismatched_else = ('else', {'shape': np.array([3, 3]), 'value': np.zeros((3, 3))})
    graph = build_graph_with_attrs(nodes_with_attrs=self.nodes,
                                   edges_with_attrs=self.edges,
                                   update_nodes_attributes=[mismatched_else])

    select_op = Select(graph=graph, attrs={})
    select_node = Node(graph, 'select')

    # Pattern text must match the implementation's assertion message exactly.
    expected_message = ("TensorFlow \'Select\' operation has 3 inputs: \'condition\',"
                        " \'then\' and \'else\' tensors.\'then\' and \'else\' tensors"
                        " must have the same shape by TensorFlow reference")
    with self.assertRaisesRegex(AssertionError, expected_message):
        select_op.infer(select_node)
def test_select_infer_assert_condition_bool(self):
    """TF Select.infer must raise AssertionError for a non-boolean 'condition' value."""
    integer_condition = ('condition', {'value': np.array([3])})
    graph = build_graph_with_attrs(nodes_with_attrs=self.nodes,
                                   edges_with_attrs=self.edges,
                                   update_nodes_attributes=[integer_condition])

    select_op = Select(graph=graph, attrs={})
    select_node = Node(graph, 'select')

    # NOTE(review): "boolen" is kept as-is; presumably it mirrors the implementation's message.
    expected_message = ("TensorFlow \'Select\' operation has 3 inputs: \'condition\',"
                        " \'then\' and \'else\' tensors. Value of \'condition\' tensor"
                        " must be boolen by TensorFlow reference")
    with self.assertRaisesRegex(AssertionError, expected_message):
        select_op.infer(select_node)
def test(self):
    """Apply MVNUnrolled to its own pattern (NHWC layout) and expect an inserted MVN op.

    The reference feeds 'mvn' from conv2d (in 0), reduction_indicies (in 1),
    variance_reduction (in 2), pow2 (in 3) and eps (in 4), and routes 'mvn' into 'next_op'.
    """
    pattern_matcher = MVNUnrolled()
    pattern = pattern_matcher.pattern()
    # Extra inputs/consumers around the matcher's own pattern.
    graph = build_graph_with_attrs(nodes_with_attrs=pattern['nodes'], edges_with_attrs=pattern['edges'],
                                   update_edge_attrs=None,
                                   new_nodes_with_attrs=[('reduction_indicies', {'kind': 'data'}),
                                                         ('conv2d', {'kind': 'op'}),
                                                         ('variance_reduction', {'kind': 'data'}),
                                                         ('pow2', {'kind': 'data'}),
                                                         ('eps', {'kind': 'data'}),
                                                         ('next_op', {'kind': 'op'})],
                                   new_edges_with_attrs=[('reduction_indicies', 'mean', {'in': 1}),
                                                         ('conv2d', 'mean',{'in': 0, 'out': 1}),
                                                         ('variance_reduction', 'variance', {'in': 1}),
                                                         ('pow2', 'pow', {'in': 1}),
                                                         ('eps', 'add'),
                                                         ('truediv', 'next_op')])
    graph.graph['layout'] = 'NHWC'
    pattern_matcher.find_and_replace_pattern(graph)

    # NOTE(review): the reference drops the last pattern node and the last two pattern edges;
    # presumably this removes the final division that MVN replaces -- confirm against pattern().
    # The reference conv2d -> mean edge also omits the {'out': 1} attribute used above.
    graph_ref = build_graph_with_attrs(nodes_with_attrs=pattern['nodes'][:-1],
                                       edges_with_attrs=pattern['edges'][:-2], update_edge_attrs=None,
                                       new_nodes_with_attrs=[('reduction_indicies', {'kind':'data'}),
                                                             ('conv2d', {'kind':'op'}),
                                                             ('variance_reduction', {'kind':'data'}),
                                                             ('pow2', {'kind': 'data'}),
                                                             ('eps', {'kind': 'data'}),
                                                             ('mvn', {'kind': 'op', 'op': 'MVN'}),
                                                             ('next_op', {'kind': 'op'})],
                                       new_edges_with_attrs=[('reduction_indicies', 'mean', {'in':1}),
                                                             ('conv2d', 'mean', {'in': 0}),
                                                             ('variance_reduction', 'variance',{'in': 1}),
                                                             ('pow2', 'pow', {'in': 1}),
                                                             ('eps', 'add'),
                                                             ('conv2d', 'mvn',{'in': 0}),
                                                             ('reduction_indicies', 'mvn', {'in': 1}),
                                                             ('variance_reduction', 'mvn',{'in': 2}),
                                                             ('pow2', 'mvn', {'in': 3}),
                                                             ('eps', 'mvn',{'in': 4}),
                                                             ('mvn', 'next_op')])

    (flag, resp) = compare_graphs(graph, graph_ref, 'next_op', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_multiple_consumers(self):
    """L2NormToNorm when 'input_data' also feeds a second Result; the fused node keeps its name."""
    graph = build_graph_with_attrs(
        nodes + [('result_2', dict(kind='op', op='Result'))],
        edges + [('input_data', 'result_2')],
        nodes_with_edges_only=True)
    graph.stage = 'middle'

    L2NormToNorm().find_and_replace_pattern(graph)

    graph_ref = build_graph_with_attrs(
        nodes + [('result_2', dict(kind='op', op='Result'))],
        edges_after_replacement + [('input_data', 'result_2')],
        nodes_with_edges_only=True)

    (flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
    # Check the structure first: if no Normalize node exists, the lookup below would
    # otherwise hide the graph diff behind an IndexError.
    self.assertTrue(flag, resp)
    normalize_nodes = graph.get_nodes_with_attributes(type='Normalize')
    self.assertEqual(graph.nodes[normalize_nodes[0]]['name'], 'l2_norm_name')
def test_2(self):
    """Cyclic case: an extra node_2 -> node_1 edge must make graph.graph['is_cyclic'] True."""
    graph = build_graph_with_attrs(nodes_with_attrs=self.nodes, edges_with_attrs=self.edges,
                                   new_edges_with_attrs=[('node_2', 'node_1')])
    tested_pass = AddIsCyclicAttribute()
    tested_pass.find_and_replace_pattern(graph)

    # A bare ``assert`` is stripped under ``python -O`` and gives no unittest diagnostics.
    self.assertIs(graph.graph['is_cyclic'], True)
def test_positive(self, input_shape, axes, layout):
    """L2NormToNorm must fuse the pattern into NormalizeL2 for the given axes/layout combination.

    The fused node must keep the name 'l2_norm_name'.
    """
    graph = build_graph_with_attrs(nodes + [
        ('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
        ('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
        ('square_data', dict(kind='data', shape=input_shape)),
        ('sum_axes_data', dict(kind='data', value=axes, shape=None)),
    ], edges, nodes_with_edges_only=True)
    graph.stage = 'middle'
    graph.graph['layout'] = layout

    L2NormToNorm().find_and_replace_pattern(graph)

    # ``axes.sort()`` sorts in place and returns None. The reference therefore held value=None
    # for the axes constant, and the parametrized ``axes`` argument was mutated.
    # NOTE(review): assumes the transformation emits non-negative, sorted axes
    # (negative axes counted from the end of ``input_shape``) -- confirm against L2NormToNorm.
    axes_array = np.asarray(axes)
    expected_axes = np.sort(np.where(axes_array < 0, axes_array + len(input_shape), axes_array))

    graph_ref = build_graph_with_attrs(nodes + [
        ('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
        ('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
        ('weights_node_data', dict(kind='data', value=expected_axes)),
    ], edges_after_replacement, nodes_with_edges_only=True)

    (flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
    # Check the structure first: if no NormalizeL2 node exists, the lookup below would
    # otherwise hide the graph diff behind an IndexError.
    self.assertTrue(flag, resp)
    normalize_nodes = graph.get_nodes_with_attributes(type='NormalizeL2')
    self.assertEqual(graph.nodes[normalize_nodes[0]]['name'], 'l2_norm_name')
def test_negative(self, input_shape, axes, layout):
    """L2NormToNorm must leave the graph unchanged for unsupported axes/layout combinations."""
    def pattern_nodes():
        # Fresh node list/dicts per call; the tested graph and the reference are built identically.
        return nodes + [
            ('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
            ('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
            ('square_data', dict(kind='data', shape=input_shape)),
            ('sum_axes_data', dict(kind='data', value=axes, shape=None)),
        ]

    graph = build_graph_with_attrs(pattern_nodes(), edges, nodes_with_edges_only=True)
    graph.stage = 'middle'
    graph.graph['layout'] = layout

    L2NormToNorm().find_and_replace_pattern(graph)

    graph_ref = build_graph_with_attrs(pattern_nodes(), edges, nodes_with_edges_only=True)

    flag, resp = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_4D_negative_4(self):
    """4D input reduced over axes [2, 0]: L2NormToNorm must leave the graph unchanged."""
    input_shape = int64_array([1, 300, 300, 3])
    axes = int64_array([2, 0])

    def pattern_nodes():
        # Fresh node list/dicts per call; the tested graph and the reference are built identically.
        return nodes + [
            ('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
            ('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
            ('square_data', dict(kind='data', shape=input_shape)),
            ('sum_axes_data', dict(kind='data', value=axes, shape=None)),
        ]

    graph = build_graph_with_attrs(pattern_nodes(), edges, nodes_with_edges_only=True)
    graph.stage = 'middle'

    L2NormToNorm().find_and_replace_pattern(graph)

    graph_ref = build_graph_with_attrs(pattern_nodes(), edges, nodes_with_edges_only=True)

    flag, resp = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
    self.assertTrue(flag, resp)