Пример #1
0
    def test_softplus_fusion_test_wrong_const(self):
        """Check that SoftplusFusion leaves the graph as-is when ``const_1`` is 0.9999."""
        const_override = {'const_1': {'value': float_array([0.9999])}}
        graph = build_graph_with_edge_attrs(self.nodes, self.edges, const_override)

        # Snapshot the graph before the pass runs: the transformed graph is
        # compared against this copy, so the pass is expected to change nothing.
        expected_graph = graph.copy()
        graph.stage = 'front'

        SoftplusFusion().find_and_replace_pattern(graph)

        graphs_match, diff_message = compare_graphs(graph, expected_graph, 'result')
        self.assertTrue(graphs_match, diff_message)
Пример #2
0
    def test_softplus_fusion_test(self):
        """Check that SoftplusFusion rewrites the default pattern into the reference graph.

        After the pass runs, the graph must compare equal to the graph built
        from ``ref_nodes``/``ref_edges``, and exactly one op node named
        ``final_log`` must remain, with op type ``SoftPlus``.
        """
        graph = build_graph_with_edge_attrs(self.nodes, self.edges, {})

        graph_ref = build_graph(ref_nodes, ref_edges)
        graph.stage = 'front'

        SoftplusFusion().find_and_replace_pattern(graph)

        (flag, resp) = compare_graphs(graph, graph_ref, 'result')
        self.assertTrue(flag, resp)

        # Look the node up once and assert each property separately so that a
        # failure reports the actual count / op type instead of a bare
        # "False is not true".
        final_log_nodes = graph.get_op_nodes(name='final_log')
        self.assertEqual(len(final_log_nodes), 1)
        self.assertEqual(final_log_nodes[0].op, 'SoftPlus')