def test_v7_group_convolution_resolver_weight_are_in_the_right_layout(self):
    """V7ConvolutionWithGroupsResolver should not alter this grouped convolution.

    The graph is Parameter -> Convolution(group=3, output=24) -> Result, with
    constant weights of shape [24, 1, 7, 7] on the convolution's port 1.
    After the pass runs, the graph is compared (including op attributes)
    against a second graph built from the identical node/edge description
    without running the pass, i.e. the expected result is "unchanged".
    """
    conv_attrs = {'type': 'Convolution', 'group': 3, 'output': 24}
    graph_nodes = {
        **regular_op_with_shaped_data('input', None, {'type': 'Parameter'}),
        **valued_const_with_data('weights', np.ones([24, 1, 7, 7])),
        **regular_op_with_shaped_data('convolution', None, conv_attrs),
        **result(),
    }
    topology = [
        *connect('input', '0:convolution'),
        *connect('weights', '1:convolution'),
        *connect('convolution', 'output'),
    ]

    graph = build_graph(graph_nodes, topology)
    V7ConvolutionWithGroupsResolver().find_and_replace_pattern(graph)

    # Reference: the very same description, built fresh and never transformed.
    expected = build_graph(graph_nodes, topology)

    flag, resp = compare_graphs(graph, expected, last_node='output', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_v7_group_convolution_resolver_depthwise_conv2d(self):
    """A DepthwiseConv2dNative convolution should get a Reshape on its weights.

    Input graph: Parameter [1, 1, 224, 224] on port 0 and constant weights
    [1, 8, 7, 7] on port 1 of a Convolution (group=1, output=8,
    op='DepthwiseConv2dNative').  The expected graph instead routes the
    weights through a Reshape whose port 1 is the constant 'dim' =
    [8, -1, 7, 7], and the Reshape output feeds the convolution's port 1.
    """
    depthwise_attrs = {
        'type': 'Convolution',
        'group': 1,
        'output': 8,
        'op': 'DepthwiseConv2dNative',
    }
    nodes = {
        **regular_op_with_shaped_data('input', [1, 1, 224, 224], {'type': 'Parameter'}),
        **valued_const_with_data('weights', np.ones([1, 8, 7, 7])),
        **valued_const_with_data('dim', int64_array([8, -1, 7, 7])),
        **regular_op_with_empty_data('reshape', {'type': 'Reshape'}),
        **regular_op_with_shaped_data('convolution', None, depthwise_attrs),
        **result(),
    }

    # 'dim' and 'reshape' are declared above but not wired here; the
    # nodes_with_edges_only flag presumably keeps them out of the input graph.
    source_edges = [
        *connect('input', '0:convolution'),
        *connect('weights', '1:convolution'),
        *connect('convolution', 'output'),
    ]
    graph = build_graph(nodes, source_edges, nodes_with_edges_only=True)

    V7ConvolutionWithGroupsResolver().find_and_replace_pattern(graph)

    # Expected topology: weights -> Reshape(dim) -> convolution port 1.
    reshaped_edges = [
        *connect('input', '0:convolution'),
        *connect('weights', '0:reshape'),
        *connect('dim', '1:reshape'),
        *connect('reshape', '1:convolution'),
        *connect('convolution', 'output'),
    ]
    expected = build_graph(nodes, reshaped_edges, nodes_with_edges_only=True)

    flag, resp = compare_graphs(graph, expected, last_node='output', check_op_attrs=True)
    self.assertTrue(flag, resp)