def test_group_convolution_backprop_data_output_shape():
    """Check group transposed convolution with an explicit output shape and SAME_UPPER padding.

    A single-group, single-channel 1x10 input holding ``0..9`` is run through a
    1x5 kernel ``[1, 2, 3, 2, 1]`` with unit strides. The requested spatial
    output ``[1, 14]`` equals the full transposed-convolution length
    (10 + 5 - 1), and the expected values are that untrimmed result.
    """
    runtime = get_runtime()

    data_shape = [1, 1, 1, 10]
    filters_shape = [1, 1, 1, 1, 5]
    strides = [1, 1]

    # Data and filters are fed at call time; the target spatial shape is a graph constant.
    data_param = ng.parameter(data_shape, name='Data', dtype=np.float32)
    filters_param = ng.parameter(filters_shape, name='Filters', dtype=np.float32)
    output_spatial_shape = ng.constant(np.array([1, 14], dtype=np.int64))

    deconv = ng.group_convolution_backprop_data(
        data_param,
        filters_param,
        strides,
        output_spatial_shape,
        auto_pad='same_upper',
    )

    # 0.0, 1.0, ..., 9.0 — exactly representable in float32.
    data_value = np.arange(10, dtype=np.float32).reshape(data_shape)
    filters_value = np.array([1.0, 2.0, 3.0, 2.0, 1.0], dtype=np.float32).reshape(filters_shape)

    computation = runtime.computation(deconv, data_param, filters_param)
    result = computation(data_value, filters_value)

    expected = np.array(
        [0.0, 1.0, 4.0, 10.0, 18.0, 27.0, 36.0, 45.0, 54.0, 63.0, 62.0, 50.0, 26.0, 9.0],
        dtype=np.float32,
    ).reshape(1, 1, 1, 14)
    assert np.allclose(result, expected)
def test_group_convolution_backprop_data():
    """Check group transposed convolution with explicit pads, stride 2 and output padding.

    A single-group, single-channel 3x3 input and 3x3 kernel are combined with
    stride 2, one-element pads on both sides of each spatial axis and
    ``output_padding=[1, 1]``. No explicit output shape is given (``None``).
    The expected result is 6x6, consistent with
    ``(3 - 1) * 2 + 3 - (1 + 1) + 1 == 6`` per spatial axis.
    """
    runtime = get_runtime()

    data_shape = [1, 1, 3, 3]
    filters_shape = [1, 1, 1, 3, 3]
    strides = [2, 2]
    pads_begin = [1, 1]
    pads_end = [1, 1]
    output_padding = [1, 1]

    data_param = ng.parameter(data_shape, name='Data', dtype=np.float32)
    filters_param = ng.parameter(filters_shape, name='Filters', dtype=np.float32)

    # The output shape argument is passed positionally as None.
    deconv = ng.group_convolution_backprop_data(
        data_param,
        filters_param,
        strides,
        None,
        pads_begin,
        pads_end,
        output_padding=output_padding,
    )

    data_rows = [
        [0.16857791, -0.15161794, 0.08540368],
        [0.1820628, -0.21746576, 0.08245695],
        [0.1431433, -0.43156421, 0.30591947],
    ]
    filter_rows = [
        [-0.06230065, 0.37932432, -0.25388849],
        [0.33878803, 0.43709868, -0.22477469],
        [0.04118127, -0.44696793, 0.06373066],
    ]
    data_value = np.array(data_rows, dtype=np.float32).reshape(data_shape)
    filters_value = np.array(filter_rows, dtype=np.float32).reshape(filters_shape)

    computation = runtime.computation(deconv, data_param, filters_param)
    result = computation(data_value, filters_value)

    expected_rows = [
        [0.07368518, -0.08925839, -0.06627201, 0.06301362, 0.03732984, -0.01919658],
        [-0.00628807, -0.02817563, -0.01472169, 0.04392925, -0.00689478, -0.01549204],
        [0.07957941, -0.11459791, -0.09505399, 0.07681622, 0.03604182, -0.01853423],
        [-0.0270785, -0.00680824, -0.06650258, 0.08004665, 0.07918708, 0.0724144],
        [0.06256775, -0.17838378, -0.18863615, 0.20064656, 0.133717, -0.06876295],
        [-0.06398046, -0.00864975, 0.19289537, -0.01490572, -0.13673618, 0.01949645],
    ]
    expected = np.array(expected_rows, dtype=np.float32).reshape(1, 1, 6, 6)
    assert np.allclose(result, expected)