示例#1
0
    def test_neg_reciprocal_1(self):
        """Negative test: the replaced graph must NOT match a Pow with exponent 0.

        The input graph is ``placeholder_1 -> reciprocal_1 -> last``. The
        reference graph routes ``placeholder_1`` into ``pow`` together with a
        ``const`` node of value 0 attached to input port 1 of ``pow``
        (presumably the exponent input, i.e. x ** 0). Since a reciprocal is
        x ** -1, the graph produced by ``ReciprocalReplacer`` is expected to
        differ from this reference, so the comparison flag must be False.
        """
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'reciprocal_1'),
                             ('reciprocal_1', 'last'),
                             ],
                            {'placeholder_1': {'shape': np.array([1, 227, 227, 3])},
                             }, nodes_with_edges_only=True)

        # Deliberately wrong reference: exponent constant is 0 instead of -1.
        # A separate shape array is used so the two graphs share no mutable state.
        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1', 'pow'),
                                 ('const', 'pow', {'in': 1}),
                                 ('pow', 'last'),
                                 ],
                                {'placeholder_1': {'shape': np.array([1, 227, 227, 3])},
                                 'const': {'value': np.array(0)},
                                 }, nodes_with_edges_only=True)

        # The transformed graph is marked as front stage before the replacer runs;
        # NOTE(review): presumably the replacer relies on this — confirm.
        graph.stage = 'front'

        pattern = ReciprocalReplacer()
        pattern.find_and_replace_pattern(graph)

        # Only the match flag matters here; the mismatch description is not used.
        flag, _ = compare_graphs(graph, graph_ref, 'last', check_op_attrs=True)
        self.assertFalse(
            flag,
            'ReciprocalReplacer output unexpectedly matches a Pow with exponent 0')
示例#2
0
    def test_replace_reciprocal(self):
        """Check that ReciprocalReplacer turns the input graph into the reference.

        Both graphs come from the class-level helper ``_create_graphs``. After
        the replacer transforms the first graph, it is compared with the
        reference graph. ``'reciprocal/power_'`` is passed as the last node of
        the transformed graph and ``'power'`` as the last node of the reference.
        Operation attributes are compared too. On mismatch, the message
        returned by ``compare_graphs`` is reported.
        """
        # Resolved through the defining class on purpose (not ``self``), so a
        # subclass override or a non-static helper cannot change the call.
        actual_graph, expected_graph = __class__._create_graphs()

        ReciprocalReplacer().find_and_replace_pattern(actual_graph)

        matched, mismatch_msg = compare_graphs(
            actual_graph,
            expected_graph,
            'reciprocal/power_',
            last_node_ref='power',
            check_op_attrs=True,
        )
        self.assertTrue(matched, mismatch_msg)