def test_ScatterElementsUpdate_has_axis_and_4_inputs(self):
    """Expect an AssertionError when `axis` is given as both attribute and input.

    The graph is built with `node` carrying an `axis` attribute of 1 while an
    `axis` node (value ``np.int64(1)``) is also connected to ``'3:node'``.
    """
    # NOTE(review): '3:node' presumably addresses input port 3 of `node`,
    # i.e. the fourth input -- confirm against the `connect` helper.
    edges_with_axis_input = [
        *edges,
        *connect('axis', '3:node'),
    ]
    attr_overrides = {
        'node': {'axis': 1},
        'axis': {'value': np.int64(1)},
    }
    graph = build_graph(
        nodes,
        edges_with_axis_input,
        attr_overrides,
        nodes_with_edges_only=True,
    )

    # Resolve the bound method up front so that only the transformation call
    # itself runs under the assertRaises guard.
    run_normalizer = ScatterNormalizer().find_and_replace_pattern
    with self.assertRaises(AssertionError):
        run_normalizer(graph)
def test_ScatterElementsUpdate_has_axis_and_3_inputs(self):
    """Check the normalized graph when `node` only has an `axis` attribute.

    The input graph uses the base edges with ``axis = 1`` set on `node`.
    After running ScatterNormalizer it is compared (op attributes included)
    against a reference graph in which an `axis` node holding
    ``np.int64(1)`` is connected to ``'3:node'``.
    """
    graph = build_graph(
        nodes,
        edges,
        {'node': {'axis': 1}},
        nodes_with_edges_only=True,
    )
    ScatterNormalizer().find_and_replace_pattern(graph)

    # Reference graph: base edges plus the extra `axis` -> '3:node' connection.
    reference_edges = [
        *edges,
        *connect('axis', '3:node'),
    ]
    reference_attrs = {'axis': {'value': np.int64(1)}}
    expected_graph = build_graph(
        nodes,
        reference_edges,
        reference_attrs,
        nodes_with_edges_only=True,
    )

    graphs_match, mismatch_report = compare_graphs(
        graph, expected_graph, 'output', check_op_attrs=True
    )
    self.assertTrue(graphs_match, mismatch_report)
def test_ScatterElementsUpdate_has_no_axis_and_3_inputs(self):
    """Expect an AssertionError when no `axis` override or extra input is given.

    The graph is built from the base `nodes` and `edges` only, with no
    attribute overrides.
    """
    graph = build_graph(nodes, edges, nodes_with_edges_only=True)

    # Resolve the bound method up front so that only the transformation call
    # itself runs under the assertRaises guard.
    run_normalizer = ScatterNormalizer().find_and_replace_pattern
    with self.assertRaises(AssertionError):
        run_normalizer(graph)