def forward(self, x, is_training):
    # Block 1
    x = jax.nn.relu(self.conv1_1(x))
    x = self.bn1_1(x, is_training)
    x = jax.nn.relu(self.conv1_2(x))
    x = self.bn1_2(x, is_training)
    x = hk.max_pool(x, 2, 2, "SAME")
    if is_training:
        x = hk.dropout(hk.next_rng_key(), 0.2, x)

    # Block 2
    x = jax.nn.relu(self.conv2_1(x))
    x = self.bn2_1(x, is_training)
    x = jax.nn.relu(self.conv2_2(x))
    x = self.bn2_2(x, is_training)
    x = hk.max_pool(x, 2, 2, "SAME")
    if is_training:
        x = hk.dropout(hk.next_rng_key(), 0.3, x)

    # Block 3
    x = jax.nn.relu(self.conv3_1(x))
    x = self.bn3_1(x, is_training)
    x = jax.nn.relu(self.conv3_2(x))
    x = self.bn3_2(x, is_training)
    x = hk.max_pool(x, 2, 2, "SAME")
    if is_training:
        x = hk.dropout(hk.next_rng_key(), 0.4, x)

    # Linear part
    x = hk.Flatten()(x)
    x = jax.nn.relu(self.lin1(x))
    x = self.bn4(x, is_training)
    if is_training:
        x = hk.dropout(hk.next_rng_key(), 0.5, x)
    x = self.lin2(x)
    return x  # logits

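# Usage sketch (not from the original source): the forward pass above calls
# hk.next_rng_key() for dropout and threads `is_training` through batch norm,
# so it has to run inside hk.transform_with_state with an explicit RNG key.
# `MiniNet` below is a hypothetical stand-in for the full VGG-style module;
# its layer sizes and the input shape are assumptions made for illustration.
import haiku as hk
import jax
import jax.numpy as jnp

class MiniNet(hk.Module):
    def __call__(self, x, is_training):
        x = jax.nn.relu(hk.Conv2D(16, 3)(x))
        x = hk.BatchNorm(True, True, 0.99)(x, is_training)
        x = hk.max_pool(x, 2, 2, "SAME")
        if is_training:
            x = hk.dropout(hk.next_rng_key(), 0.2, x)  # consumes an RNG key
        x = hk.Flatten()(x)
        return hk.Linear(10)(x)  # logits

# Batch norm keeps moving averages, hence transform_with_state.
forward = hk.transform_with_state(
    lambda x, is_training: MiniNet()(x, is_training))
rng = jax.random.PRNGKey(0)
x = jnp.zeros((8, 32, 32, 3))  # NHWC batch
params, state = forward.init(rng, x, is_training=True)
logits, state = forward.apply(params, state, rng, x, is_training=True)
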
def __call__(self, x, is_training=True, return_metrics=False):
    """Return the output of the final layer without any [log-]softmax."""
    # Stem
    outputs = {}
    out = self.initial_conv(x)
    out = hk.max_pool(out, window_shape=(1, 3, 3, 1),
                      strides=(1, 2, 2, 1), padding='SAME')
    if return_metrics:
        outputs.update(base.signal_metrics(out, 0))

    # Blocks
    for i, block in enumerate(self.blocks):
        out, res_avg_var = block(out, is_training=is_training)
        if return_metrics:
            outputs.update(base.signal_metrics(out, i + 1))
            outputs[f'res_avg_var_{i}'] = res_avg_var

    # Final-conv->activation, pool, dropout, classify
    pool = jnp.mean(self.activation(out), [1, 2])
    outputs['pool'] = pool
    # Optionally apply dropout
    if self.drop_rate > 0.0 and is_training:
        pool = hk.dropout(hk.next_rng_key(), self.drop_rate, pool)
    outputs['logits'] = self.fc(pool)
    return outputs

def __call__(self, inputs, is_training, final_endpoint='output'):
    self._final_endpoint = final_endpoint
    net = self._initial_conv(inputs)
    net = hk.max_pool(net, window_shape=(1, 3, 3, 1),
                      strides=(1, 2, 2, 1), padding='SAME')
    end_point = 'resnet_stem'
    if self._final_endpoint == end_point:
        return net

    for i_group, block_group in enumerate(self._block_groups):
        net = block_group(net, is_training=is_training)
        end_point = f'resnet_unit_{i_group}'
        if self._final_endpoint == end_point:
            return net

    end_point = 'last_conv'
    if self._final_endpoint == end_point:
        return net

    if self._normalize_fn is not None:
        net = self._normalize_fn(net, is_training=is_training)
    net = jax.nn.relu(net)

    # The actual representation
    net = jnp.mean(net, axis=[1, 2])

    assert self._final_endpoint == 'output'
    if self._num_classes is None:
        # If num_classes was None, we just return the output
        # of the last block, without fully connected layer.
        return net
    return self._logits_layer(net)

def __call__(self, x):
    torso_out = x / 255.  # scale uint8 pixel inputs to [0, 1]
    for i, (num_channels, num_blocks) in enumerate([(16, 2), (32, 2), (32, 2)]):
        conv = hk.Conv2D(num_channels, kernel_shape=[3, 3],
                         stride=[1, 1], padding='SAME')
        torso_out = conv(torso_out)
        torso_out = hk.max_pool(
            torso_out,
            window_shape=[1, 3, 3, 1],
            strides=[1, 2, 2, 1],
            padding='SAME',
        )
        for j in range(num_blocks):
            block = ResidualBlock(num_channels, name=f'residual_{i}_{j}')
            torso_out = block(torso_out)
    # Head: activation, flatten, and a single dense projection.
    torso_out = jax.nn.relu(torso_out)
    torso_out = hk.Flatten()(torso_out)
    torso_out = hk.Linear(256)(torso_out)
    torso_out = jax.nn.relu(torso_out)
    return torso_out

def __call__(self,
             inputs: jnp.ndarray,
             *,
             is_training: bool,
             test_local_stats: bool = False) -> jnp.ndarray:
    out = inputs
    for layer in self.layers:
        out = layer['conv'](out)
        if layer['batchnorm'] is not None:
            out = layer['batchnorm'](out, is_training, test_local_stats)
        out = jax.nn.relu(out)
    out = hk.max_pool(out, window_shape=(1, 3, 3, 1),
                      strides=(1, 2, 2, 1), padding='SAME')
    return out

def __call__(self, inputs, is_training, test_local_stats=False):
    out = inputs
    out = self.initial_conv(out)
    if not self.resnet_v2:
        # ResNet v1 normalizes and activates after the initial conv;
        # v2 defers this to the pre-activation blocks.
        out = self.initial_batchnorm(out, is_training, test_local_stats)
        out = jax.nn.relu(out)

    out = hk.max_pool(out, window_shape=(1, 3, 3, 1),
                      strides=(1, 2, 2, 1), padding='SAME')

    for block_group in self.block_groups:
        out = block_group(out, is_training, test_local_stats)

    if self.resnet_v2:
        out = self.final_batchnorm(out, is_training, test_local_stats)
        out = jax.nn.relu(out)
    # Global average pool over the spatial dimensions.
    out = jnp.mean(out, axis=[1, 2])
    return out

def __call__(self, x, is_training, test_local_stats=False, return_metrics=False):
    """Return the output of the final layer without any [log-]softmax."""
    outputs = {}
    # Stem
    out = self.initial_conv(x)
    if not self.preactivation:
        out = self.activation(
            self.initial_bn(out, is_training, test_local_stats))
    out = hk.max_pool(out, window_shape=(1, 3, 3, 1),
                      strides=(1, 2, 2, 1), padding='SAME')
    if return_metrics:
        outputs.update(base.signal_metrics(out, 0))

    # Blocks
    for i, block in enumerate(self.blocks):
        out, res_var = block(out, is_training, test_local_stats)
        if return_metrics:
            outputs.update(base.signal_metrics(out, i + 1))
            outputs[f'res_avg_var_{i}'] = res_var

    if self.preactivation:
        out = self.activation(
            self.final_bn(out, is_training, test_local_stats))

    # Pool, dropout, classify
    pool = jnp.mean(out, axis=[1, 2])
    # Return pool before dropout in case we want to regularize it separately.
    outputs['pool'] = pool
    # Optionally apply dropout
    if self.drop_rate > 0.0 and is_training:
        pool = hk.dropout(hk.next_rng_key(), self.drop_rate, pool)
    outputs['logits'] = self.fc(pool)
    return outputs

def downsample(x, factor):
    """Spatially downsample an NHWC tensor by an integer factor via max pooling."""
    return hk.max_pool(x, window_shape=(1, factor, factor, 1),
                       strides=(1, factor, factor, 1), padding='VALID')
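
# Quick check (not from the original source): hk.max_pool is a stateless
# wrapper around lax.reduce_window, so downsample should work on plain arrays
# outside any Haiku transform. The input shape here is an assumption.
import jax.numpy as jnp

x = jnp.arange(64.0).reshape(1, 8, 8, 1)  # NHWC input
y = downsample(x, factor=2)               # non-overlapping 2x2 max windows
assert y.shape == (1, 4, 4, 1)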