def test_neg_reciprocal_1(self):
    """Negative case: Reciprocal must not be rewritten as Pow with exponent 0.

    Builds ``placeholder_1 -> reciprocal_1 -> last`` and runs
    ``ReciprocalReplacer`` on it. It then compares the result, starting from
    node ``'last'``, against a reference ``placeholder_1 -> pow -> last``
    graph whose exponent input (port 1 of ``pow``) is the constant ``0``.
    The graphs are expected to differ, so ``compare_graphs`` must report a
    mismatch (a falsy first element of its result).
    """
    graph = build_graph(
        nodes_attributes,
        [('placeholder_1', 'reciprocal_1'),
         ('reciprocal_1', 'last'),
         ],
        {'placeholder_1': {'shape': np.array([1, 227, 227, 3])},
         },
        nodes_with_edges_only=True,
    )
    # Reference: the same input fed into Pow, with a constant exponent of 0
    # wired into input port 1. This is a replacement the pass must NOT emit.
    graph_ref = build_graph(
        nodes_attributes,
        [('placeholder_1', 'pow'),
         ('const', 'pow', {'in': 1}),
         ('pow', 'last'),
         ],
        {'placeholder_1': {'shape': np.array([1, 227, 227, 3])},
         'const': {'value': np.array(0)},
         },
        nodes_with_edges_only=True,
    )
    # NOTE(review): stage is set before running the pattern, presumably
    # because ReciprocalReplacer is a front-phase transformation — confirm.
    graph.stage = 'front'

    pattern = ReciprocalReplacer()
    pattern.find_and_replace_pattern(graph)

    # Only the match flag matters here; the diagnostic string is unused
    # when the graphs are expected to differ.
    flag, _ = compare_graphs(graph, graph_ref, 'last', check_op_attrs=True)
    self.assertFalse(
        flag, 'ReciprocalReplacer unexpectedly produced Pow with exponent 0')
def test_replace_reciprocal(self):
    """Check that ReciprocalReplacer turns the input graph into the reference.

    ``_create_graphs`` supplies an ``(input, reference)`` pair. After the
    replacer runs on the input graph, it is compared with the reference,
    walking from node ``'reciprocal/power_'`` on the transformed side and
    from ``'power'`` on the reference side, with op attributes checked. The
    diagnostic string from ``compare_graphs`` is used as the failure message.
    """
    graph, graph_ref = __class__._create_graphs()

    replacer = ReciprocalReplacer()
    replacer.find_and_replace_pattern(graph)

    is_match, details = compare_graphs(
        graph,
        graph_ref,
        'reciprocal/power_',
        last_node_ref='power',
        check_op_attrs=True,
    )
    self.assertTrue(is_match, details)