def test_two_nodes_with_bin(self):
    """Data node with two consumers, both reached through edges carrying 'bin'.

    Every outgoing edge of ``data_node`` has a 'bin' attribute, so nothing
    should happen: the transformed graph must match an independently built
    copy of the input graph. Comparing ``graph`` with itself would be a
    vacuous check that cannot detect an unwanted modification.
    """
    shape = np.array([2, 3, 4])
    data = np.zeros(shape)

    def build_input_graph():
        # Called twice so the reference is a separate object from the graph under test.
        return build_graph_with_attrs(
            nodes_with_attrs=self.nodes + [('next_node_2', {'kind': 'op'})],
            edges_with_attrs=self.edges + [('data_node', 'next_node_2')],
            update_nodes_attributes=[('data_node', {'shape': shape, 'value': data})],
            update_edge_attrs={
                ('data_node', 'next_node', 0): {'bin': 0},
                ('data_node', 'next_node_2', 0): {'bin': 0},
            },
        )

    graph = build_input_graph()
    graph_ref = build_input_graph()

    tested_pattern = CreateConstNodesReplacement()
    tested_pattern.find_and_replace_pattern(graph)

    (flag, resp) = compare_graphs(graph, graph_ref, last_node='next_node')
    self.assertTrue(flag, resp)
    # The comparison above starts from 'next_node' only; also make sure no
    # node was added anywhere else (e.g. on the 'next_node_2' branch).
    self.assertEqual(set(graph.nodes()), set(graph_ref.nodes()))
def test_two_nodes_one_bin(self):
    """Data node feeding two consumers, where only the edge to 'next_node' has 'bin'.

    The expected result is the reference graph extended with ``self.new_nodes``
    and ``self.new_edges``, where 'const_data' carries the same shape and value
    as 'data_node'.
    """
    dims = np.array([2, 3, 4])
    zeros = np.zeros(dims)

    def extra_consumer():
        # Fresh node tuple/dict per call, mirroring the separate literals of each build.
        return ('next_node_2', {'kind': 'op'})

    extra_edge = ('data_node', 'next_node_2')

    graph = build_graph_with_attrs(
        nodes_with_attrs=self.nodes + [extra_consumer()],
        edges_with_attrs=self.edges + [extra_edge],
        update_nodes_attributes=[('data_node', {'shape': dims, 'value': zeros})],
        update_edge_attrs={('data_node', 'next_node', 0): {'bin': 0}},
    )
    expected = build_graph_with_attrs(
        nodes_with_attrs=self.nodes + self.new_nodes + [extra_consumer()],
        edges_with_attrs=self.edges + self.new_edges + [extra_edge],
        update_nodes_attributes=[
            ('data_node', {'shape': dims, 'value': zeros}),
            ('const_data', {'shape': dims, 'value': zeros}),
        ],
    )

    CreateConstNodesReplacement().find_and_replace_pattern(graph)

    is_equal, message = compare_graphs(graph, expected, last_node='next_node')
    self.assertTrue(is_equal, message)
def test_one_bin_node(self):
    """Single consumer reached through an edge carrying 'bin': nothing should happen.

    The transformed graph is compared with an independently built copy of the
    input graph. Comparing ``graph`` with itself would be a vacuous check that
    cannot detect an unwanted modification.
    """
    shape = np.array([2, 3, 4])
    data = np.zeros(shape)

    def build_input_graph():
        # Called twice so the reference is a separate object from the graph under test.
        return build_graph_with_attrs(
            nodes_with_attrs=self.nodes,
            edges_with_attrs=self.edges,
            update_nodes_attributes=[('data_node', {'shape': shape, 'value': data})],
            update_edge_attrs={('data_node', 'next_node', 0): {'bin': 0}},
        )

    graph = build_input_graph()
    graph_ref = build_input_graph()

    tested_pattern = CreateConstNodesReplacement()
    tested_pattern.find_and_replace_pattern(graph)

    (flag, resp) = compare_graphs(graph, graph_ref, last_node='next_node')
    self.assertTrue(flag, resp)
    # Also guard against any node being added outside the compared path.
    self.assertEqual(set(graph.nodes()), set(graph_ref.nodes()))
def test_one_node(self):
    """Data node with a value and no 'bin' edge: a Const node and a data node are added.

    The expected graph is the input extended with ``self.new_nodes`` and
    ``self.new_edges``, where 'const_data' mirrors the shape and value of
    'data_node'.
    """
    dims = np.array([2, 3, 4])
    zeros = np.zeros(dims)

    graph = build_graph_with_attrs(
        nodes_with_attrs=self.nodes,
        edges_with_attrs=self.edges,
        update_nodes_attributes=[('data_node', {'shape': dims, 'value': zeros})],
    )
    expected = build_graph_with_attrs(
        nodes_with_attrs=self.nodes + self.new_nodes,
        edges_with_attrs=self.edges + self.new_edges,
        update_nodes_attributes=[
            ('data_node', {'shape': dims, 'value': zeros}),
            ('const_data', {'shape': dims, 'value': zeros}),
        ],
    )

    CreateConstNodesReplacement().find_and_replace_pattern(graph)

    is_equal, message = compare_graphs(graph, expected, last_node='next_node')
    self.assertTrue(is_equal, message)
def test_force_precision_parameter(self):
    """A 'force_precision' attribute on the data node is propagated to the new nodes.

    After the replacement, both the created Const node ('data_node_const') and
    the created data node ('data_node_copy_') must carry the same
    'force_precision' value as the original data node.
    """
    precision = 'FP16'
    dims = np.array([2, 3, 4])
    zeros = np.zeros(dims)

    graph = build_graph_with_attrs(
        nodes_with_attrs=self.nodes,
        edges_with_attrs=self.edges,
        update_nodes_attributes=[
            ('data_node', {'shape': dims, 'value': zeros, 'force_precision': precision}),
        ],
    )
    expected = build_graph_with_attrs(
        nodes_with_attrs=self.nodes + self.new_nodes,
        edges_with_attrs=self.edges + self.new_edges,
        update_nodes_attributes=[
            ('data_node', {'shape': dims, 'value': zeros}),
            ('const_data', {'shape': dims, 'value': zeros, 'force_precision': precision}),
            ('const', {'force_precision': precision}),
        ],
    )

    CreateConstNodesReplacement().find_and_replace_pattern(graph)

    is_equal, message = compare_graphs(graph, expected, last_node='next_node')
    self.assertTrue(is_equal, message)

    # Read both attributes first, then check them, in the same order as before.
    propagated = {
        node_name: graph.nodes[node_name]['force_precision']
        for node_name in ('data_node_const', 'data_node_copy_')
    }
    for value in propagated.values():
        self.assertEqual(value, precision)