Exemplo n.º 1
0
def test_group_convolution_operator():
    """Run GroupConvolution on a 1x4x2x2 input and compare with known values.

    The input holds the values 1..16 and the filter tensor (shape
    [2, 1, 2, 1, 1]) holds 1..4. Strides and dilations are 1, and there is
    no padding.

    The expected output matches this computation: output channel 0 is
    1*in[0] + 2*in[1], and output channel 1 is 3*in[2] + 4*in[3]. For
    example, 1*1 + 2*5 = 11 and 3*9 + 4*13 = 79.
    """
    runtime = get_runtime()

    input_shape = [1, 4, 2, 2]
    weights_shape = [2, 1, 2, 1, 1]

    data_node = ng.parameter(input_shape, name="Data", dtype=np.float32)
    weights_node = ng.parameter(weights_shape, name="Filters", dtype=np.float32)

    # Row-major ramps: input 1..16, weights 1..4.
    input_array = np.arange(1.0, 17.0, dtype=np.float32).reshape(input_shape)
    weights_array = np.arange(1.0, 5.0, dtype=np.float32).reshape(weights_shape)

    stride_step, dilation_step = [1, 1], [1, 1]
    pad_start, pad_stop = [0, 0], [0, 0]

    group_conv = ng.group_convolution(
        data_node,
        weights_node,
        stride_step,
        pad_start,
        pad_stop,
        dilation_step,
    )
    executable = runtime.computation(group_conv, data_node, weights_node)
    actual = executable(input_array, weights_array)

    reference = np.array(
        [[[[11, 14], [17, 20]],
          [[79, 86], [93, 100]]]],
        dtype=np.float32,
    )

    assert np.allclose(actual, reference)
Exemplo n.º 2
0
def test_group_convolution_operator():
    """Run the legacy group_convolution API with two groups and check its output.

    The input holds 1..16 (shape [1, 4, 2, 2]) and the filters hold 1..4
    (shape [2, 2, 1, 1]), with an explicit group count of 2. Movement,
    window-dilation and data-dilation strides are all 1, and there is no
    padding.

    The expected output matches this computation: output channel 0 is
    1*in[0] + 2*in[1], and output channel 1 is 3*in[2] + 4*in[3].
    """
    runtime = get_runtime()

    input_shape = [1, 4, 2, 2]
    weights_shape = [2, 2, 1, 1]

    data_node = ng.parameter(input_shape, name='Data', dtype=np.float32)
    weights_node = ng.parameter(weights_shape, name='Filters', dtype=np.float32)

    input_array = np.arange(1.0, 17.0, dtype=np.float32).reshape(input_shape)
    weights_array = np.arange(1.0, 5.0, dtype=np.float32).reshape(weights_shape)

    movement_step, window_dilation = [1, 1], [1, 1]
    pad_low, pad_high = [0, 0], [0, 0]
    input_dilation = [1, 1]
    group_count = 2
    # NOTE(review): the meaning of this trailing positional argument cannot be
    # determined from the code shown here. Confirm it against the
    # ng.group_convolution signature.
    trailing_arg = 0

    group_conv = ng.group_convolution(
        data_node,
        weights_node,
        movement_step,
        window_dilation,
        pad_low,
        pad_high,
        input_dilation,
        group_count,
        trailing_arg,
    )
    executable = runtime.computation(group_conv, data_node, weights_node)

    actual = executable(input_array, weights_array)
    reference = np.array(
        [[[[11, 14], [17, 20]],
          [[79, 86], [93, 100]]]],
        dtype=np.float32,
    )

    assert np.allclose(actual, reference)