def test_max_pool_same_padding_with_inferred_shapes(self):
    """SAME max-pooling with scalar window (3) and stride (1) keeps the shape.

    The window and stride are passed as plain ints together with
    ``channel_axis=None``; only the output shape is compared to the input's.
    """
    row = np.arange(6, dtype=jnp.float32)
    inputs = np.broadcast_to(row, (2, 3, 6))
    pooled = pool.max_pool(inputs, 3, 1, padding="SAME", channel_axis=None)
    np.testing.assert_equal(pooled.shape, inputs.shape)
def test_max_pool_same_padding(self):
    """SAME max-pooling with an explicit per-axis window keeps the input shape."""
    inputs = np.broadcast_to(np.arange(6, dtype=jnp.float32), (2, 3, 6))
    pooled = pool.max_pool(
        inputs,
        window_shape=[1, 3, 3],
        strides=[1, 1, 1],
        padding="SAME",
    )
    np.testing.assert_equal(pooled.shape, inputs.shape)
def test_max_pool_basic_with_inferred_shapes(self):
    """VALID max-pooling with scalar window/stride 2 on a (2, 10, 6, 2) input.

    The input only varies along axis 2 (values 0..5), so the 2-wide,
    stride-2 windows there reduce to 1, 3 and 5. The expected output has
    shape (2, 5, 3, 2).
    """
    column = np.arange(6, dtype=jnp.float32).reshape([6, 1])
    inputs = np.broadcast_to(column, (2, 10, 6, 2))
    pooled = pool.max_pool(inputs, 2, 2, padding="VALID")
    expected = np.broadcast_to(
        np.asarray([1., 3., 5.]).reshape([3, 1]), (2, 5, 3, 2))
    np.testing.assert_equal(pooled, expected)
def test_max_pool_basic(self):
    """VALID max-pooling with an explicit [1, 2, 2, 1] window used as stride too.

    Along axis 2 the input holds 0..5, so the non-overlapping 2-wide windows
    yield 1, 3 and 5; the expected output has shape (2, 5, 3, 2).
    """
    column = np.arange(6, dtype=jnp.float32).reshape([6, 1])
    inputs = np.broadcast_to(column, (2, 10, 6, 2))
    window = [1, 2, 2, 1]
    pooled = pool.max_pool(
        inputs, window_shape=window, strides=window, padding="VALID")
    expected = np.broadcast_to(
        np.asarray([1., 3., 5.]).reshape([3, 1]), (2, 5, 3, 2))
    np.testing.assert_equal(pooled, expected)
def test_avg_pool_same_padding(self):
    """SAME average-pooling of a constant input returns the input unchanged.

    Checks both that the output shape matches the input shape and that the
    values are preserved. This must call ``pool.avg_pool``; calling
    ``max_pool`` here would leave the average-pooling path untested.
    """
    x = np.ones((2, 3, 6))
    window_shape = [1, 3, 3]
    strides = [1, 1, 1]
    result = pool.avg_pool(
        x, window_shape=window_shape, strides=strides, padding="SAME")
    np.testing.assert_equal(result.shape, x.shape)
    # Since x is constant, its avg value should be itself.
    # NOTE(review): this assumes avg_pool averages only over non-padded
    # positions at the borders; otherwise edge values would fall below 1.
    np.testing.assert_equal(result, x)
def test_max_pool_overlapping_windows(self):
    """VALID max-pooling where windows overlap along axis 1 (size 5, stride 1).

    On the 6x2 ``arange(12)`` grid, the 3-wide stride-3 windows on axis 2
    cover rows 0-2 and 3-5. The 2-wide window on the last axis spans both
    columns, so the maxima are 5 and 11. The output shape is (2, 6, 2, 1).
    """
    grid = np.arange(12, dtype=jnp.float32).reshape([6, 2])
    inputs = np.broadcast_to(grid, (2, 10, 6, 2))
    pooled = pool.max_pool(
        inputs,
        window_shape=[1, 5, 3, 2],
        strides=[1, 1, 3, 2],
        padding="VALID",
    )
    expected = np.broadcast_to(
        np.asarray([5., 11.]).reshape([2, 1]), (2, 6, 2, 1))
    np.testing.assert_equal(pooled, expected)
def test_max_pool_unbatched(self):
    """A 3-element window on a rank-5 input pools only the trailing axes.

    The expected output leaves the two leading dims (2, 3) untouched, while
    the trailing (10, 6, 2) axes are pooled to (5, 3, 2).
    """
    batch_dims = (2, 3)
    column = np.arange(6, dtype=jnp.float32).reshape([6, 1])
    inputs = np.broadcast_to(column, batch_dims + (10, 6, 2))
    window = [2, 2, 1]
    pooled = pool.max_pool(
        inputs, window_shape=window, strides=window, padding="VALID")
    expected = np.broadcast_to(
        np.asarray([1., 3., 5.]).reshape([3, 1]), batch_dims + (5, 3, 2))
    np.testing.assert_equal(pooled, expected)
def __call__(self, inputs, is_training):
    """Run the network on ``inputs`` and return the logits.

    When ``self._resnet_v2`` is false, batch norm + ReLU are applied right
    after the initial convolution. When true, they are applied instead after
    the last block group.
    """
    out = self._initial_conv(inputs)
    if not self._resnet_v2:
        out = jax.nn.relu(self._initial_batchnorm(out, is_training=is_training))

    # 3x3 window with stride 2 on axes 1 and 2 (presumably NHWC — confirm).
    out = pool.max_pool(
        out, window_shape=(1, 3, 3, 1), strides=(1, 2, 2, 1), padding="SAME")

    for group in self._block_groups:
        out = group(out, is_training)

    if self._resnet_v2:
        out = jax.nn.relu(self._final_batchnorm(out, is_training=is_training))

    # Average over axes 1 and 2 before the final logits layer.
    out = jnp.mean(out, axis=[1, 2])
    return self._logits(out)
def __call__(self, inputs, is_training):
    """Forward pass: stem, max-pool, block groups, mean over axes 1-2, logits.

    The stem's batch norm + ReLU run only when ``self._resnet_v2`` is false.
    Otherwise the final batch norm + ReLU run after the block groups.
    """
    use_v2 = self._resnet_v2
    features = self._initial_conv(inputs)
    if not use_v2:
        normed = self._initial_batchnorm(features, is_training=is_training)
        features = jax.nn.relu(normed)

    features = pool.max_pool(
        features,
        window_shape=(1, 3, 3, 1),
        strides=(1, 2, 2, 1),
        padding="SAME",
    )

    for block_group in self._block_groups:
        features = block_group(features, is_training)

    if use_v2:
        normed = self._final_batchnorm(features, is_training=is_training)
        features = jax.nn.relu(normed)

    pooled = jnp.mean(features, axis=[1, 2])
    return self._logits(pooled)