def test_v7_group_convolution_resolver_weight_are_in_the_right_layout(self):
    """Grouped convolution whose weights already match 'output' is left unchanged.

    The constant weights have shape (24, 1, 7, 7): their leading dimension
    equals the convolution's ``output`` attribute (24), while ``group`` is 3.
    After running V7ConvolutionWithGroupsResolver the graph must still equal
    one built from the very same nodes and edges, i.e. no Reshape may appear
    on the weights path.
    """
    graph_nodes = {
        **regular_op_with_shaped_data('input', None, {'type': 'Parameter'}),
        **valued_const_with_data('weights', np.ones([24, 1, 7, 7])),
        **regular_op_with_shaped_data(
            'convolution', None,
            {'type': 'Convolution', 'group': 3, 'output': 24}),
        **result(),
    }
    # One edge list describes both the graph under test and the expected one,
    # because the pass is expected to leave the topology untouched.
    conv_edges = [
        *connect('input', '0:convolution'),
        *connect('weights', '1:convolution'),
        *connect('convolution', 'output'),
    ]

    graph = build_graph(graph_nodes, conv_edges)
    V7ConvolutionWithGroupsResolver().find_and_replace_pattern(graph)

    expected_graph = build_graph(graph_nodes, conv_edges)
    is_equal, message = compare_graphs(
        graph, expected_graph, last_node='output', check_op_attrs=True)
    self.assertTrue(is_equal, message)
    def test_v7_group_convolution_resolver_depthwise_conv2d(self):
        """Depthwise convolution weights get a Reshape inserted in front of them.

        The Convolution carries ``op='DepthwiseConv2dNative'`` with ``group=1``
        and ``output=8``, and its constant weights have shape (1, 8, 7, 7).
        After V7ConvolutionWithGroupsResolver runs, the weights are expected to
        reach port 1 of the convolution through a Reshape whose target shape is
        the constant [8, -1, 7, 7].
        """
        graph_nodes = {
            **regular_op_with_shaped_data('input', [1, 1, 224, 224],
                                          {'type': 'Parameter'}),
            **valued_const_with_data('weights', np.ones([1, 8, 7, 7])),
            # 'dim' and 'reshape' are wired up only in the reference graph.
            **valued_const_with_data('dim', int64_array([8, -1, 7, 7])),
            **regular_op_with_empty_data('reshape', {'type': 'Reshape'}),
            **regular_op_with_shaped_data('convolution', None, {
                'type': 'Convolution',
                'group': 1,
                'output': 8,
                'op': 'DepthwiseConv2dNative',
            }),
            **result(),
        }

        # Graph under test: weights feed the convolution directly.
        input_edges = [
            *connect('input', '0:convolution'),
            *connect('weights', '1:convolution'),
            *connect('convolution', 'output'),
        ]
        # No edge touches 'dim' or 'reshape' here; nodes_with_edges_only=True
        # presumably keeps them out of the graph under test.
        graph = build_graph(graph_nodes, input_edges,
                            nodes_with_edges_only=True)

        V7ConvolutionWithGroupsResolver().find_and_replace_pattern(graph)

        # Expected graph: weights -> Reshape(dim) -> convolution port 1.
        expected_edges = [
            *connect('input', '0:convolution'),
            *connect('weights', '0:reshape'),
            *connect('dim', '1:reshape'),
            *connect('reshape', '1:convolution'),
            *connect('convolution', 'output'),
        ]
        expected_graph = build_graph(graph_nodes, expected_edges,
                                     nodes_with_edges_only=True)

        is_equal, message = compare_graphs(
            graph, expected_graph, last_node='output', check_op_attrs=True)
        self.assertTrue(is_equal, message)