def test_two_nodes_with_bin(self):
     """Check a data node whose two consumer edges both carry 'bin'.

     Both edges leaving 'data_node' ('next_node' and 'next_node_2') are
     given the 'bin' attribute, so nothing should happen: the graph is
     expected to be unchanged after the transformation.
     """
     shape = np.array([2, 3, 4])
     data = np.zeros(shape)

     def build():
         # Invoked once for the graph under test and once for the reference,
         # so the two graphs are built from separate attribute dicts.
         return build_graph_with_attrs(
             nodes_with_attrs=self.nodes + [('next_node_2', {'kind': 'op'})],
             edges_with_attrs=self.edges + [('data_node', 'next_node_2')],
             update_nodes_attributes=[('data_node', {
                 'shape': shape,
                 'value': data
             })],
             update_edge_attrs={
                 ('data_node', 'next_node', 0): {'bin': 0},
                 ('data_node', 'next_node_2', 0): {'bin': 0},
             },
         )

     graph = build()
     # Reference is a separate, untouched graph: comparing the transformed
     # graph with itself would succeed no matter what the transformation did.
     graph_ref = build()

     tested_pattern = CreateConstNodesReplacement()
     tested_pattern.find_and_replace_pattern(graph)

     (flag, resp) = compare_graphs(graph, graph_ref, last_node='next_node')
     self.assertTrue(flag, resp)
 def test_two_nodes_one_bin(self):
     """Check two consumers of 'data_node' where only one edge has 'bin'.

     Only the edge to 'next_node' carries the 'bin' attribute; the edge to
     'next_node_2' does not. The transformed graph is compared against a
     reference that additionally contains ``self.new_nodes`` and
     ``self.new_edges``.
     """
     blob_shape = np.array([2, 3, 4])
     blob = np.zeros(blob_shape)

     # Graph under test: 'bin' is set on one of the two outgoing edges only.
     graph = build_graph_with_attrs(
         nodes_with_attrs=self.nodes + [('next_node_2', {'kind': 'op'})],
         edges_with_attrs=self.edges + [('data_node', 'next_node_2')],
         update_nodes_attributes=[
             ('data_node', {'shape': blob_shape, 'value': blob}),
         ],
         update_edge_attrs={('data_node', 'next_node', 0): {'bin': 0}},
     )

     # Expected result: the same graph extended with the new nodes/edges.
     graph_ref = build_graph_with_attrs(
         nodes_with_attrs=(
             self.nodes + self.new_nodes + [('next_node_2', {'kind': 'op'})]
         ),
         edges_with_attrs=(
             self.edges + self.new_edges + [('data_node', 'next_node_2')]
         ),
         update_nodes_attributes=[
             ('data_node', {'shape': blob_shape, 'value': blob}),
             ('const_data', {'shape': blob_shape, 'value': blob}),
         ],
     )

     replacer = CreateConstNodesReplacement()
     replacer.find_and_replace_pattern(graph)

     is_equal, message = compare_graphs(graph, graph_ref, last_node='next_node')
     self.assertTrue(is_equal, message)
 def test_one_bin_node(self):
     """Check a data node whose only consumer edge carries 'bin'.

     Nothing should happen: the graph is expected to be unchanged after
     the transformation.
     """
     shape = np.array([2, 3, 4])
     data = np.zeros(shape)

     def build():
         # Invoked once for the graph under test and once for the reference,
         # so the two graphs are built from separate attribute dicts.
         return build_graph_with_attrs(
             nodes_with_attrs=self.nodes,
             edges_with_attrs=self.edges,
             update_nodes_attributes=[('data_node', {'shape': shape, 'value': data})],
             update_edge_attrs={('data_node', 'next_node', 0): {'bin': 0}},
         )

     graph = build()
     # Reference is a separate, untouched graph: comparing the transformed
     # graph with itself would succeed no matter what the transformation did.
     graph_ref = build()

     tested_pattern = CreateConstNodesReplacement()
     tested_pattern.find_and_replace_pattern(graph)

     (flag, resp) = compare_graphs(graph, graph_ref, last_node='next_node')
     self.assertTrue(flag, resp)
 def test_one_node(self):
     """Check the case where no edge carries the 'bin' attribute.

     A Const node and a data node should be added; the reference graph is
     the input graph extended with ``self.new_nodes`` and ``self.new_edges``.
     """
     blob_shape = np.array([2, 3, 4])
     blob = np.zeros(blob_shape)

     # Graph under test: only shape and value are set on 'data_node'.
     graph = build_graph_with_attrs(
         nodes_with_attrs=self.nodes,
         edges_with_attrs=self.edges,
         update_nodes_attributes=[
             ('data_node', {'shape': blob_shape, 'value': blob}),
         ],
     )

     # Expected result, with 'const_data' holding the same shape and value.
     graph_ref = build_graph_with_attrs(
         nodes_with_attrs=self.nodes + self.new_nodes,
         edges_with_attrs=self.edges + self.new_edges,
         update_nodes_attributes=[
             ('data_node', {'shape': blob_shape, 'value': blob}),
             ('const_data', {'shape': blob_shape, 'value': blob}),
         ],
     )

     replacer = CreateConstNodesReplacement()
     replacer.find_and_replace_pattern(graph)

     is_equal, message = compare_graphs(graph, graph_ref, last_node='next_node')
     self.assertTrue(is_equal, message)
    def test_force_precision_parameter(self):
        """Check that 'force_precision' reaches the newly created nodes.

        'data_node' is given ``force_precision='FP16'``. After the
        transformation the graph must match the reference, and the nodes
        'data_node_const' and 'data_node_copy_' must both carry the same
        'force_precision' value.
        """
        precision = 'FP16'
        blob_shape = np.array([2, 3, 4])
        blob = np.zeros(blob_shape)

        # Graph under test: the precision override is set on the data node.
        graph = build_graph_with_attrs(
            nodes_with_attrs=self.nodes,
            edges_with_attrs=self.edges,
            update_nodes_attributes=[
                ('data_node', {
                    'shape': blob_shape,
                    'value': blob,
                    'force_precision': precision,
                }),
            ],
        )

        # Expected result: the override appears on 'const_data' and 'const'.
        graph_ref = build_graph_with_attrs(
            nodes_with_attrs=self.nodes + self.new_nodes,
            edges_with_attrs=self.edges + self.new_edges,
            update_nodes_attributes=[
                ('data_node', {'shape': blob_shape, 'value': blob}),
                ('const_data', {
                    'shape': blob_shape,
                    'value': blob,
                    'force_precision': precision,
                }),
                ('const', {'force_precision': precision}),
            ],
        )

        replacer = CreateConstNodesReplacement()
        replacer.find_and_replace_pattern(graph)

        is_equal, message = compare_graphs(graph, graph_ref, last_node='next_node')
        self.assertTrue(is_equal, message)

        # Check that force precision was added to the data and Const nodes.
        # Both attributes are read before either assertion runs.
        observed = [
            graph.nodes[node_id]['force_precision']
            for node_id in ('data_node_const', 'data_node_copy_')
        ]
        for value in observed:
            self.assertEqual(value, precision)