def test_Conv_runs(channels, filter_shape, padding, strides, input_shape, dilation):
    """Smoke test: a Conv layer built from the given configuration can be
    initialized and applied to random inputs without raising.

    No output value is checked; the test only exercises construction,
    parameter initialization and the forward pass.
    """
    layer = Conv(channels, filter_shape,
                 strides=strides, padding=padding, dilation=dilation)
    batch = random_inputs(input_shape)
    # Fixed PRNG seed keeps parameter initialization deterministic.
    layer.apply(layer.init_parameters(PRNGKey(0), batch), batch)
def test_flatten_shape():
    """Check that a zero-initialized SAME-padded Conv followed by `flatten`
    maps a (1, 5, 5, 2) zero input to a (1, 50) zero output.

    The intermediate conv output is also checked to be all zeros with shape
    (1, 5, 5, 2). The conv parameters are then reused inside a Sequential via
    `apply_from`, keyed by the conv layer object.
    """
    layer = Conv(2, filter_shape=(3, 3), padding='SAME',
                 kernel_init=zeros, bias_init=zeros)
    x = np.zeros((1, 5, 5, 2))

    layer_params = layer.init_parameters(PRNGKey(0), x)
    conv_out = layer.apply(layer_params, x)
    assert np.array_equal(np.zeros((1, 5, 5, 2)), conv_out)

    model = Sequential(layer, flatten)
    # Reuse the already-initialized conv parameters inside the Sequential.
    flat_out = model.apply_from({layer: layer_params}, x)
    assert np.array_equal(np.zeros((1, 50)), flat_out)