Example #1
0
    def test_max_pool_same_padding_with_inferred_shapes(self):
        # Every length-6 row holds 0..5; replicate it into a (2, 3, 6) array.
        row = np.arange(6, dtype=jnp.float32)
        inputs = np.broadcast_to(row, (2, 3, 6))

        # Scalar window (3) and stride (1) leave shape inference to `pool`;
        # with SAME padding and unit stride the shape must be preserved.
        pooled = pool.max_pool(
            inputs, 3, 1, padding="SAME", channel_axis=None)
        np.testing.assert_equal(pooled.shape, inputs.shape)
Example #2
0
  def test_max_pool_same_padding(self):
    # Replicate the row 0..5 into a (2, 3, 6) input.
    inputs = np.broadcast_to(np.arange(6, dtype=jnp.float32), (2, 3, 6))

    # 3x3 window over the last two axes, stride 1 on every axis.
    pooled = pool.max_pool(
        inputs,
        window_shape=[1, 3, 3],
        strides=[1, 1, 1],
        padding="SAME")

    # SAME padding with unit strides keeps the output shape equal to the input.
    np.testing.assert_equal(pooled.shape, inputs.shape)
Example #3
0
    def test_max_pool_basic_with_inferred_shapes(self):
        # A (6, 1) column 0..5 broadcast to (2, 10, 6, 2), so that
        # inputs[b, h, w, c] == w for every b, h and c.
        column = np.arange(6, dtype=jnp.float32).reshape([6, 1])
        inputs = np.broadcast_to(column, (2, 10, 6, 2))

        # Scalar window/stride of 2: `pool` infers the full shapes.
        pooled = pool.max_pool(inputs, 2, 2, padding="VALID")

        # Pairs (0, 1), (2, 3), (4, 5) along w reduce to 1, 3, 5, and the
        # expected result has the second axis halved from 10 to 5.
        expected = np.broadcast_to(
            np.asarray([1., 3., 5.]).reshape([3, 1]), (2, 5, 3, 2))
        np.testing.assert_equal(pooled, expected)
Example #4
0
  def test_max_pool_basic(self):
    # inputs[b, h, w, c] == w, with w running over 0..5.
    inputs = np.broadcast_to(
        np.arange(6, dtype=jnp.float32).reshape([6, 1]), (2, 10, 6, 2))

    # Window and stride are identical, so the 2x2 windows never overlap.
    window = [1, 2, 2, 1]
    pooled = pool.max_pool(
        inputs, window_shape=window, strides=window, padding="VALID")

    # Each pair along w collapses to its larger element.
    expected = np.broadcast_to(
        np.asarray([1., 3., 5.]).reshape([3, 1]), (2, 5, 3, 2))
    np.testing.assert_equal(pooled, expected)
Example #5
0
    def test_avg_pool_same_padding(self):
        # A constant input: the mean over any window is the constant itself.
        x = np.ones((2, 3, 6))

        window_shape = [1, 3, 3]
        strides = [1, 1, 1]
        # Must exercise avg_pool: max_pool also maps a constant input to
        # itself, so it would pass this test without testing averaging.
        result = pool.avg_pool(x,
                               window_shape=window_shape,
                               strides=strides,
                               padding="SAME")

        np.testing.assert_equal(result.shape, x.shape)
        # Since x is constant, its avg value should be itself. With SAME
        # padding this also fails if border windows were divided by the full
        # window size instead of by the number of in-bounds elements.
        np.testing.assert_equal(result, x)
Example #6
0
  def test_max_pool_overlapping_windows(self):
    # inputs[b, h, w, c] == 2 * w + c for w in 0..5 and c in 0..1.
    grid = np.arange(12, dtype=jnp.float32).reshape([6, 2])
    inputs = np.broadcast_to(grid, (2, 10, 6, 2))

    # Along axis 1 the 5-wide windows overlap (stride 1); along axes 2 and 3
    # the stride equals the window size, so those windows are disjoint.
    pooled = pool.max_pool(
        inputs,
        window_shape=[1, 5, 3, 2],
        strides=[1, 1, 3, 2],
        padding="VALID")

    # w-blocks {0, 1, 2} and {3, 4, 5} peak at 2*2+1 = 5 and 2*5+1 = 11;
    # VALID output is (2, (10-5)//1+1, (6-3)//3+1, (2-2)//2+1) = (2, 6, 2, 1).
    expected = np.broadcast_to(
        np.asarray([5., 11.]).reshape([2, 1]), (2, 6, 2, 1))
    np.testing.assert_equal(pooled, expected)
Example #7
0
    def test_max_pool_unbatched(self):
        # Two extra leading dimensions in front of a (10, 6, 2) block, with
        # x[..., w, c] == w.
        leading_dims = (2, 3)
        column = np.arange(6, dtype=jnp.float32).reshape([6, 1])
        x = np.broadcast_to(column, leading_dims + (10, 6, 2))

        # The window has fewer entries than x has axes; the expected shape
        # below shows the leading dims passing through unpooled.
        window = [2, 2, 1]
        result = pool.max_pool(
            x, window_shape=window, strides=window, padding="VALID")

        expected = np.broadcast_to(
            np.asarray([1., 3., 5.]).reshape([3, 1]),
            leading_dims + (5, 3, 2))
        np.testing.assert_equal(result, expected)
Example #8
0
    def __call__(self, inputs, is_training):
        # Stem convolution; v1 normalises and activates immediately, while
        # v2 defers its batchnorm + ReLU until after the block groups.
        out = self._initial_conv(inputs)
        if not self._resnet_v2:
            out = jax.nn.relu(
                self._initial_batchnorm(out, is_training=is_training))

        # 3x3 max pool, stride 2 on axes 1 and 2 (presumably NHWC input).
        out = pool.max_pool(out,
                            window_shape=(1, 3, 3, 1),
                            strides=(1, 2, 2, 1),
                            padding="SAME")

        for group in self._block_groups:
            out = group(out, is_training)

        if self._resnet_v2:
            out = jax.nn.relu(
                self._final_batchnorm(out, is_training=is_training))

        # Average over axes 1 and 2 before the classification head.
        pooled = jnp.mean(out, axis=[1, 2])
        return self._logits(pooled)
Example #9
0
File: resnet.py Project: ibab/haiku
  def __call__(self, inputs, is_training):
    net = self._initial_conv(inputs)
    # ResNet v1 normalises right after the stem convolution.
    if not self._resnet_v2:
      net = jax.nn.relu(
          self._initial_batchnorm(net, is_training=is_training))

    # Downsample with a 3x3, stride-2 max pool spanning axes 1 and 2.
    net = pool.max_pool(net, window_shape=(1, 3, 3, 1),
                        strides=(1, 2, 2, 1), padding="SAME")

    for group in self._block_groups:
      net = group(net, is_training)

    # ResNet v2 instead normalises once, after the last block group.
    if self._resnet_v2:
      net = self._final_batchnorm(net, is_training=is_training)
      net = jax.nn.relu(net)

    # Mean over axes 1 and 2, then the logits layer.
    return self._logits(jnp.mean(net, axis=[1, 2]))