def test_ScatterElementsUpdate_has_axis_and_4_inputs(self):
    """An ``axis`` attribute together with a wired axis input is rejected.

    The scatter node gets ``axis=1`` as an attribute, and the ``axis`` node
    (value ``np.int64(1)``) is also connected to input port 3 of ``node``.
    ScatterNormalizer is expected to raise ``AssertionError`` on this graph.
    """
    edge_list = [*edges, *connect('axis', '3:node')]
    attrs = {
        'node': {'axis': 1},
        'axis': {'value': np.int64(1)},
    }
    graph = build_graph(nodes, edge_list, attrs, nodes_with_edges_only=True)

    # Construct the transformation outside the ``with`` so that only the
    # find_and_replace_pattern call is covered by the expected-exception check.
    normalizer = ScatterNormalizer()
    with self.assertRaises(AssertionError):
        normalizer.find_and_replace_pattern(graph)
    def test_ScatterElementsUpdate_has_axis_and_3_inputs(self):
        """An ``axis`` attribute is moved onto input port 3.

        The input graph uses only the shared ``edges`` and sets ``axis=1`` on
        ``node``. After ScatterNormalizer runs, the graph must match a
        reference in which the ``axis`` node (value ``np.int64(1)``) is
        connected to input port 3 of ``node``. Op attributes are compared too
        (``check_op_attrs=True``).
        """
        actual = build_graph(nodes, edges, {'node': {'axis': 1}},
                             nodes_with_edges_only=True)
        ScatterNormalizer().find_and_replace_pattern(actual)

        # Reference: the axis value is supplied through an extra edge into
        # port 3 rather than through an attribute override on 'node'.
        expected = build_graph(nodes,
                               [*edges, *connect('axis', '3:node')],
                               {'axis': {'value': np.int64(1)}},
                               nodes_with_edges_only=True)

        flag, resp = compare_graphs(actual, expected, 'output',
                                    check_op_attrs=True)
        self.assertTrue(flag, resp)
 def test_ScatterElementsUpdate_has_no_axis_and_3_inputs(self):
     """A scatter node with no axis source is rejected.

     The graph is built from the shared ``nodes``/``edges`` only. There is no
     attribute override and no connection into input port 3 (presumably the
     shared ``nodes`` carry no ``axis`` on ``node``, per the test name).
     ScatterNormalizer is expected to raise ``AssertionError``.
     """
     graph = build_graph(nodes, edges, nodes_with_edges_only=True)

     # Keep construction outside the ``with`` so only the transformation call
     # is expected to raise.
     normalizer = ScatterNormalizer()
     with self.assertRaises(AssertionError):
         normalizer.find_and_replace_pattern(graph)