def test_4D_multiple_consumers(self):
    """Check the L2-norm pattern is replaced when ``input_data`` has an extra consumer.

    An additional ``input_data -> result_2`` edge gives the input tensor a second
    consumer. After ``L2NormToNorm`` runs, the graph must match the reference built
    from ``edges_after_replacement`` (plus the same extra edge). The resulting
    NormalizeL2 node must be named ``l2_norm_name``.
    """
    input_shape = int64_array([1, 300, 300, 3])
    axes = int64_array([1, 2, 3])
    # The reference expects the reduction axes in sorted order (the intent of the
    # former ``axes.sort()``). ``ndarray.sort()`` sorts in place and returns None,
    # which left the reference value as None. NOTE(review): compare_graphs appears
    # to skip None-valued reference attributes, so the axes were never checked.
    # ``np.sort`` returns a sorted copy and leaves ``axes`` untouched.
    expected_axes = np.sort(axes)

    graph = build_graph_with_attrs(
        nodes + [
            ('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
            ('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
            ('square_data', dict(kind='data', shape=input_shape)),
            ('sum_axes_data', dict(kind='data', value=axes, shape=None)),
            ('result_2', dict(kind='op', op='Result')),
        ],
        edges + [('input_data', 'result_2')],
        nodes_with_edges_only=True,
    )
    graph.stage = 'middle'

    L2NormToNorm().find_and_replace_pattern(graph)

    graph_ref = build_graph_with_attrs(
        nodes + [
            ('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
            ('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
            ('weights_node_data', dict(kind='data', value=expected_axes)),
            ('result_2', dict(kind='op', op='Result')),
        ],
        edges_after_replacement + [('input_data', 'result_2')],
        nodes_with_edges_only=True,
    )

    (flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
    norm_node_id = graph.get_nodes_with_attributes(type='NormalizeL2')[0]
    self.assertEqual(graph.node[norm_node_id]['name'], 'l2_norm_name')
    self.assertTrue(flag, resp)
def test_positive(self, input_shape, axes, layout):
    """Check a supported L2-norm pattern is replaced by a NormalizeL2 node.

    :param input_shape: shape of the ``input`` Parameter / ``input_data`` tensor.
    :param axes: reduction axes fed to the sum through ``sum_axes_data``.
    :param layout: value stored in ``graph.graph['layout']`` before the pass runs.
    """
    graph = build_graph_with_attrs(
        nodes + [
            ('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
            ('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
            ('square_data', dict(kind='data', shape=input_shape)),
            ('sum_axes_data', dict(kind='data', value=axes, shape=None)),
        ],
        edges,
        nodes_with_edges_only=True,
    )
    graph.stage = 'middle'
    graph.graph['layout'] = layout

    L2NormToNorm().find_and_replace_pattern(graph)

    # Expected axes constant: sorted reduction axes. The former ``axes.sort()``
    # returned None (so the reference value was never compared). It also sorted
    # the parametrized ``axes`` array in place.
    # NOTE(review): assumes the pass maps negative axes to ``axis + rank`` before
    # sorting — confirm against L2NormToNorm.
    rank = len(input_shape)
    expected_axes = np.sort(np.where(axes < 0, axes + rank, axes))

    graph_ref = build_graph_with_attrs(
        nodes + [
            ('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
            ('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
            ('weights_node_data', dict(kind='data', value=expected_axes)),
        ],
        edges_after_replacement,
        nodes_with_edges_only=True,
    )

    (flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
    norm_node_id = graph.get_nodes_with_attributes(type='NormalizeL2')[0]
    self.assertEqual(graph.node[norm_node_id]['name'], 'l2_norm_name')
    self.assertTrue(flag, resp)
def test_negative(self, input_shape, axes, layout):
    """Check an unsupported L2-norm pattern is left unchanged by ``L2NormToNorm``.

    The same graph is built twice: once to run the pass on and once as the
    untouched reference. The two must compare equal afterwards.

    :param input_shape: shape of the ``input`` Parameter / ``input_data`` tensor.
    :param axes: reduction axes fed to the sum through ``sum_axes_data``.
    :param layout: value stored in ``graph.graph['layout']`` before the pass runs.
    """
    def make_graph():
        # Fresh attribute dicts on every call, so the two graphs never share them.
        extra_nodes = [
            ('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
            ('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
            ('square_data', dict(kind='data', shape=input_shape)),
            ('sum_axes_data', dict(kind='data', value=axes, shape=None)),
        ]
        return build_graph_with_attrs(nodes + extra_nodes, edges, nodes_with_edges_only=True)

    transformed = make_graph()
    transformed.stage = 'middle'
    transformed.graph['layout'] = layout
    L2NormToNorm().find_and_replace_pattern(transformed)

    reference = make_graph()

    ok, message = compare_graphs(transformed, reference, 'result', check_op_attrs=True)
    self.assertTrue(ok, message)
def test_4D_negative_4(self):
    """Check a 4D input reduced over axes [2, 0] is not rewritten by ``L2NormToNorm``.

    No layout is stored in ``graph.graph`` for this case. The graph after the
    pass must equal a freshly built, untransformed copy.
    """
    shape = int64_array([1, 300, 300, 3])
    reduce_axes = int64_array([2, 0])

    def make_graph():
        # Fresh attribute dicts on every call, so the two graphs never share them.
        extra_nodes = [
            ('input', dict(kind='op', shape=shape, op='Parameter', data_type=np.float32)),
            ('input_data', dict(kind='data', shape=shape, data_type=np.float32)),
            ('square_data', dict(kind='data', shape=shape)),
            ('sum_axes_data', dict(kind='data', value=reduce_axes, shape=None)),
        ]
        return build_graph_with_attrs(nodes + extra_nodes, edges, nodes_with_edges_only=True)

    transformed = make_graph()
    transformed.stage = 'middle'
    L2NormToNorm().find_and_replace_pattern(transformed)

    reference = make_graph()

    ok, message = compare_graphs(transformed, reference, 'result', check_op_attrs=True)
    self.assertTrue(ok, message)