def __call__(self, inputs, is_training):
    """Run the MobileNetV1 forward pass and return the logits.

    Args:
      inputs: Batch fed to the stem convolution.
      is_training: Forwarded to the stem BatchNorm (when enabled) and to every
        MobileNetV1Block.

    Returns:
      The output of the final linear layer named ``"logits"`` with
      ``self._num_classes`` units.
    """
    # Stem: 3x3 convolution, 32 output channels, stride 2, VALID padding.
    stem = conv.Conv2D(32, (3, 3), stride=2, padding="VALID",
                       with_bias=self._with_bias)
    hidden = stem(inputs)
    if self._use_bn:
        stem_bn = batch_norm.BatchNorm(create_scale=True, create_offset=True)
        hidden = stem_bn(hidden, is_training)
    hidden = jax.nn.relu(hidden)

    # One block per configured stride. Channels are looked up by position, so
    # a `_channels` list shorter than `_strides` raises IndexError.
    for idx, stride in enumerate(self._strides):
        block = MobileNetV1Block(self._channels[idx], stride, self._use_bn)
        hidden = block(hidden, is_training)

    # Average over axes 1 and 2 (presumably the spatial axes -- confirm the
    # input layout), then flatten before the classifier head.
    pooled = jnp.mean(hidden, axis=(1, 2))
    features = reshape.Flatten()(pooled)
    return basic.Linear(self._num_classes, name="logits")(features)
def test_flatten_invalid_preserve_dims(self):
    """Flatten must reject an invalid ``preserve_dims`` with a ValueError.

    Negative ``preserve_dims`` is a supported setting: ``preserve_dims=-2``
    flattens the trailing two axes of a ``[5, 2, 3]`` input into ``(5, 6)``.
    The old expectation that ``-1`` raises "should be >= 1" therefore
    contradicts that behaviour. Zero is used instead, because it is invalid
    under both the old (>= 1) and the newer (negative-allowed) contract.
    """
    # NOTE(review): only the stable prefix of the error message is matched,
    # because the wording after "should be" differs between versions of the
    # check. Confirm it against reshape.Reshape's validation.
    with self.assertRaisesRegex(ValueError, "Argument preserve_dims should be"):
        reshape.Flatten(preserve_dims=0)
def test_flatten_nd(self):
    """Flatten(preserve_dims=2) leaves a rank-2 input's shape unchanged."""
    flatten = reshape.Flatten(preserve_dims=2)
    inp = jnp.zeros([2, 3])
    out = flatten(inp)
    self.assertEqual(inp.shape, out.shape)
def test_flatten_1d(self):
    """A default Flatten applied to a rank-1 input keeps its shape."""
    flatten = reshape.Flatten()
    vec = jnp.zeros([10])
    result = flatten(vec)
    self.assertEqual(vec.shape, result.shape)
def f():
    """Apply Flatten(preserve_dims=2) to a zero tensor of shape [2, 3, 4, 5]."""
    flatten = reshape.Flatten(preserve_dims=2)
    zeros = jnp.zeros([2, 3, 4, 5])
    return flatten(zeros)
def test_flatten_nd_out_negative(self):
    """preserve_dims=-2 merges the trailing two axes: [5, 2, 3] -> (5, 6)."""
    flatten = reshape.Flatten(preserve_dims=-2)
    x_in = jnp.zeros([5, 2, 3])
    x_out = flatten(x_in)
    self.assertEqual(x_out.shape, (5, 6))