def call(self, x: jnp.ndarray):
    """Apply the three-convolution residual block to ``x``.

    Main branch: 1x1 conv -> batch norm -> ReLU -> 3x3 conv (strided by
    ``self.strides``) -> batch norm -> ReLU -> 1x1 conv producing
    ``4 * self.n_filters`` channels -> batch norm. The block input is
    added back as a shortcut and a final ReLU is applied to the sum.

    Args:
        x: Input array. The layout is not fixed by this method;
            presumably channels-last -- confirm against callers.

    Returns:
        ``jax.nn.relu(shortcut + main_branch)``.
    """
    # Settings shared by every convolution / batch norm in this block.
    conv_kwargs = dict(with_bias=False, dtype=self.dtype)
    bn_kwargs = dict(decay_rate=0.9, eps=1e-5)
    expanded_filters = self.n_filters * 4

    shortcut = x

    # NOTE(review): each layer is built and applied immediately, in a fixed
    # sequence; auto-generated submodule names presumably depend on this
    # order, so do not reorder without checking.
    reduce_conv = nn.Conv2D(self.n_filters, (1, 1), **conv_kwargs)
    out = reduce_conv(x)
    out = nn.BatchNormalization(**bn_kwargs)(out)
    out = jax.nn.relu(out)

    # The only strided convolution on the main branch.
    spatial_conv = nn.Conv2D(
        self.n_filters, (3, 3), stride=self.strides, **conv_kwargs
    )
    out = spatial_conv(out)
    out = nn.BatchNormalization(**bn_kwargs)(out)
    out = jax.nn.relu(out)

    expand_conv = nn.Conv2D(expanded_filters, (1, 1), **conv_kwargs)
    out = expand_conv(out)
    # The last batch norm uses jnp.zeros as its scale initializer
    # (presumably so the main branch contributes nothing at init -- confirm).
    out = nn.BatchNormalization(scale_init=jnp.zeros, **bn_kwargs)(out)

    # Project the shortcut with a strided 1x1 conv + batch norm only when
    # its shape differs from the main branch output.
    if shortcut.shape != out.shape:
        projection = nn.Conv2D(
            expanded_filters, (1, 1), stride=self.strides, **conv_kwargs
        )
        shortcut = projection(shortcut)
        shortcut = nn.BatchNormalization(**bn_kwargs)(shortcut)

    return jax.nn.relu(shortcut + out)
def call(self, x: jnp.ndarray):
    """Run the full network forward pass on ``x``.

    Pipeline: stem convolution -> batch norm -> ReLU -> (max pool unless
    ``self.lowres``) -> residual stages built from ``self.block_type`` ->
    mean over axes ``(1, 2)`` -> 1000-unit linear layer -> cast to float32.

    Args:
        x: Input array. The pooling window ``(1, 3, 3, 1)`` and the mean
            over axes ``(1, 2)`` suggest a 4-D channels-last batch --
            TODO confirm against callers.

    Returns:
        The linear layer's output converted to ``jnp.float32``.
    """
    # Stem configuration: a 3x3 / stride-1 conv when ``lowres`` is set,
    # otherwise a 7x7 / stride-2 conv.
    if self.lowres:
        stem_kernel, stem_stride = (3, 3), (1, 1)
    else:
        stem_kernel, stem_stride = (7, 7), (2, 2)

    stem_conv = nn.Conv2D(
        64,
        stem_kernel,
        stride=stem_stride,
        padding="SAME",
        with_bias=False,
        dtype=self.dtype,
    )
    x = stem_conv(x)
    x = nn.BatchNormalization(decay_rate=0.9, eps=1e-5)(x)

    # ReLU is wrapped as a module here rather than called directly.
    relu_cls = module.to_module(jax.nn.relu)
    x = relu_cls()(x)

    # The max pool is skipped entirely in the low-resolution variant.
    if not self.lowres:
        pool = nn.MaxPool(
            window_shape=(1, 3, 3, 1), strides=(1, 2, 2, 1), padding="SAME"
        )
        x = pool(x)

    # Stage ``s`` uses ``64 * 2**s`` filters; the first block of every
    # stage except stage 0 is built with stride (2, 2).
    for stage_idx, n_blocks in enumerate(self.stages):
        stage_filters = 64 * 2**stage_idx
        for block_idx in range(n_blocks):
            is_strided = stage_idx > 0 and block_idx == 0
            block_strides = (2, 2) if is_strided else (1, 1)
            block = self.block_type(
                stage_filters, strides=block_strides, dtype=self.dtype
            )
            x = block(x)

    # Global average pooling expressed as a named module.
    pool_cls = module.to_module(lambda t: jnp.mean(t, axis=(1, 2)))
    x = pool_cls(name="global_average_pooling")(x)

    classifier = nn.Linear(1000, dtype=self.dtype)
    x = classifier(x)

    # Always return float32, whatever ``self.dtype`` is.
    cast_cls = module.to_module(lambda t: jnp.asarray(t, jnp.float32))
    x = cast_cls(name="to_float32")(x)
    return x