Example #1
0
def test_Conv_runs(channels, filter_shape, padding, strides, input_shape,
                   dilation):
    """Smoke-test that a ``Conv`` layer can be initialized and applied.

    Builds a ``Conv`` from the given hyperparameters, initializes its
    parameters against random inputs of ``input_shape``, and runs one forward
    pass. The output is discarded: the test only checks that no exception is
    raised. (The arguments are presumably supplied by a pytest parametrize
    decorator outside this view — confirm.)
    """
    layer = Conv(
        channels,
        filter_shape,
        padding=padding,
        strides=strides,
        dilation=dilation,
    )
    batch = random_inputs(input_shape)

    # A fixed PRNG key keeps parameter initialization deterministic.
    layer_params = layer.init_parameters(PRNGKey(0), batch)
    layer.apply(layer_params, batch)
Example #2
0
def test_flatten_shape():
    """Check that ``flatten`` collapses a Conv output to ``(batch, features)``.

    The all-zero kernel and bias make the conv output identically zero. The
    test first checks that output directly (shape ``(1, 5, 5, 2)`` with
    ``'SAME'`` padding). It then checks it after chaining ``flatten`` in a
    ``Sequential``, where the expected shape is ``(1, 50)``, i.e. 5 * 5 * 2.
    ``array_equal`` returns False on a shape mismatch, so each assertion
    covers both shape and values.
    """
    zero_conv = Conv(2,
                     filter_shape=(3, 3),
                     padding='SAME',
                     kernel_init=zeros,
                     bias_init=zeros)
    batch = np.zeros((1, 5, 5, 2))

    conv_params = zero_conv.init_parameters(PRNGKey(0), batch)
    conv_out = zero_conv.apply(conv_params, batch)
    assert np.array_equal(np.zeros((1, 5, 5, 2)), conv_out)

    # Reuse the already-initialized conv parameters for the composed model;
    # apply_from takes a mapping from sub-module to its parameters.
    model = Sequential(zero_conv, flatten)
    flat_out = model.apply_from({zero_conv: conv_params}, batch)
    assert np.array_equal(np.zeros((1, 50)), flat_out)