def test_avg_pool_same_padding_with_inferred_shapes(self):
    """SAME-padded average pooling of a constant tensor is the identity.

    The window size (3) and stride (1) are passed as plain ints together with
    ``channel_axis=None``, so ``pool.avg_pool`` has to infer the full window
    and stride shapes itself.
    """
    ones = np.ones((2, 3, 6))
    pooled = pool.avg_pool(ones, 3, 1, padding="SAME", channel_axis=None)

    # Stride 1 with SAME padding: the output keeps the input's shape.
    np.testing.assert_equal(pooled.shape, ones.shape)
    # Every window of a constant input averages back to that constant.
    np.testing.assert_equal(pooled, ones)
def test_avg_pool_basic_with_inferred_shapes(self):
    """VALID average pooling with an int window/stride of 2 over a ramp.

    The input varies only along its length-6 axis (values 0..5), so each
    non-overlapping pair along that axis averages to 0.5, 2.5 or 4.5.
    """
    ramp = np.arange(6, dtype=jnp.float32).reshape([6, 1])
    inputs = np.broadcast_to(ramp, (2, 10, 6, 2))

    pooled = pool.avg_pool(inputs, 2, 2, padding="VALID")

    pair_means = np.asarray([0.5, 2.5, 4.5]).reshape([3, 1])
    expected = np.broadcast_to(pair_means, (2, 5, 3, 2))
    np.testing.assert_equal(pooled, expected)
def test_avg_pool_same_padding(self):
    """SAME-padded pooling with explicit per-axis shapes preserves a constant.

    Window ``[1, 3, 3]`` and unit strides are spelled out for every axis of the
    rank-3 input.
    """
    ones = np.ones((2, 3, 6))
    pooled = pool.avg_pool(
        ones,
        window_shape=[1, 3, 3],
        strides=[1, 1, 1],
        padding="SAME",
    )

    # Unit strides with SAME padding: shape is unchanged.
    np.testing.assert_equal(pooled.shape, ones.shape)
    # Averaging a constant yields that constant.
    np.testing.assert_equal(pooled, ones)
def test_avg_pool_basic(self):
    """VALID pooling with explicit window ``[1, 2, 2, 1]`` used as its own stride.

    The ramp 0..5 along axis 2 is split into non-overlapping pairs whose means
    are 0.5, 2.5 and 4.5; axis 1 (length 10) is halved to 5.
    """
    ramp = np.arange(6, dtype=jnp.float32).reshape([6, 1])
    inputs = np.broadcast_to(ramp, (2, 10, 6, 2))
    window = [1, 2, 2, 1]

    pooled = pool.avg_pool(
        inputs, window_shape=window, strides=window, padding="VALID")

    pair_means = np.asarray([0.5, 2.5, 4.5]).reshape([3, 1])
    expected = np.broadcast_to(pair_means, (2, 5, 3, 2))
    np.testing.assert_equal(pooled, expected)
def test_avg_pool_unbatched(self):
    """A window shorter than the input rank pools only the trailing axes.

    The input has two extra leading dimensions ``(2, 3)`` beyond the three
    covered by ``window_shape``; the test expects those leading dimensions to
    pass through unchanged.
    """
    extra_dims = (2, 3)
    ramp = np.arange(6, dtype=jnp.float32).reshape([6, 1])
    inputs = np.broadcast_to(ramp, extra_dims + (10, 6, 2))
    window = [2, 2, 1]

    pooled = pool.avg_pool(
        inputs, window_shape=window, strides=window, padding="VALID")

    pair_means = np.asarray([0.5, 2.5, 4.5]).reshape([3, 1])
    expected = np.broadcast_to(pair_means, extra_dims + (5, 3, 2))
    np.testing.assert_equal(pooled, expected)
def test_avg_pool_overlapping_windows(self):
    """Windows overlap along axis 1 (size 5, stride 1) but not along axes 2-3.

    Input values 0..11 are laid out as a (6, 2) grid in the last two axes. The
    window ``[1, 5, 3, 2]`` with strides ``[1, 1, 3, 2]`` covers the first three
    rows (values 0..5, mean 2.5) and the last three rows (values 6..11, mean
    8.5), giving 6 overlapping positions along axis 1.
    """
    grid = np.arange(12, dtype=jnp.float32).reshape([6, 2])
    inputs = np.broadcast_to(grid, (2, 10, 6, 2))

    pooled = pool.avg_pool(
        inputs,
        window_shape=[1, 5, 3, 2],
        strides=[1, 1, 3, 2],
        padding="VALID",
    )

    block_means = np.asarray([2.5, 8.5]).reshape([2, 1])
    expected = np.broadcast_to(block_means, (2, 6, 2, 1))
    # Float averaging of six values: compare approximately.
    np.testing.assert_almost_equal(pooled, expected, decimal=5)