def __call__(self, x, is_training=True, return_metrics=False):
    """Return the output of the final layer without any [log-]softmax.

    Pipeline: ``self.initial_conv`` followed by a 3x3, stride-2 'SAME'
    max-pool (the stem), then every block in ``self.blocks`` in order, then
    ``self.activation``, a mean over axes 1 and 2, optional dropout, and
    finally ``self.fc``.

    Args:
      x: Input batch. The (1, 3, 3, 1) pooling window and the mean over
        axes 1 and 2 suggest an NHWC layout (TODO confirm against callers).
      is_training: Passed to each block as a keyword argument; dropout is
        only applied when this is true.
      return_metrics: When true, also record ``base.signal_metrics`` after
        the stem (index 0) and after each block (index i + 1), plus each
        block's second return value under ``'res_avg_var_{i}'``.

    Returns:
      A dict containing ``'pool'`` (pooled features, captured before
      dropout), ``'logits'`` (output of ``self.fc``) and, if requested, the
      metric entries listed above.
    """
    outputs = {}

    # Stem: conv then max-pool; no activation is applied at this stage.
    h = hk.max_pool(
        self.initial_conv(x),
        window_shape=(1, 3, 3, 1),
        strides=(1, 2, 2, 1),
        padding='SAME',
    )
    if return_metrics:
        outputs.update(base.signal_metrics(h, 0))

    # Residual blocks. Metric index 0 belongs to the stem, hence idx + 1.
    for idx, block in enumerate(self.blocks):
        h, residual_var = block(h, is_training=is_training)
        if return_metrics:
            outputs.update(base.signal_metrics(h, idx + 1))
            outputs[f'res_avg_var_{idx}'] = residual_var

    # Head: activation, then global average over axes 1 and 2. There is no
    # final conv in this network.
    features = jnp.mean(self.activation(h), axis=(1, 2))
    # Stored before dropout so it can be regularized separately if needed.
    outputs['pool'] = features

    # Dropout only during training and only for a positive rate.
    if self.drop_rate > 0.0 and is_training:
        features = hk.dropout(hk.next_rng_key(), self.drop_rate, features)
    outputs['logits'] = self.fc(features)
    return outputs
def __call__(self, x, is_training, test_local_stats=False,
             return_metrics=False):
    """Return the output of the final layer without any [log-]softmax.

    Supports two layouts selected by ``self.preactivation``:

    * post-activation (``preactivation`` false): the stem conv is followed
      by ``self.initial_bn`` and ``self.activation``; no final norm.
    * pre-activation (``preactivation`` true): the stem is the bare conv,
      and ``self.final_bn`` + ``self.activation`` run after the last block.

    In both layouts the stem ends with a 3x3, stride-2 'SAME' max-pool, and
    the head is a mean over axes 1 and 2, optional dropout, and ``self.fc``.

    Args:
      x: Input batch; pooling over axes 1 and 2 suggests NHWC (TODO
        confirm against callers).
      is_training: Forwarded to the norm layers and blocks; also gates
        dropout.
      test_local_stats: Forwarded to the norm layers and blocks alongside
        ``is_training``.
      return_metrics: When true, record ``base.signal_metrics`` after the
        stem (index 0) and after each block (index i + 1), plus each
        block's second return value under ``'res_avg_var_{i}'``.

    Returns:
      A dict containing ``'pool'`` (pooled features, captured before
      dropout), ``'logits'`` (output of ``self.fc``) and, if requested, the
      metric entries listed above.
    """
    outputs = {}
    # Positional args shared by every norm layer and block call.
    norm_args = (is_training, test_local_stats)

    # Stem.
    h = self.initial_conv(x)
    if not self.preactivation:
        h = self.activation(self.initial_bn(h, *norm_args))
    h = hk.max_pool(
        h,
        window_shape=(1, 3, 3, 1),
        strides=(1, 2, 2, 1),
        padding='SAME',
    )
    if return_metrics:
        outputs.update(base.signal_metrics(h, 0))

    # Residual blocks. Metric index 0 belongs to the stem, hence idx + 1.
    for idx, block in enumerate(self.blocks):
        h, residual_var = block(h, *norm_args)
        if return_metrics:
            outputs.update(base.signal_metrics(h, idx + 1))
            outputs[f'res_avg_var_{idx}'] = residual_var

    # Pre-activation layout: normalize + activate once after the last block.
    if self.preactivation:
        h = self.activation(self.final_bn(h, *norm_args))

    # Head: global average over axes 1 and 2.
    features = jnp.mean(h, axis=(1, 2))
    # Stored before dropout so it can be regularized separately if needed.
    outputs['pool'] = features

    # Dropout only during training and only for a positive rate.
    if self.drop_rate > 0.0 and is_training:
        features = hk.dropout(hk.next_rng_key(), self.drop_rate, features)
    outputs['logits'] = self.fc(features)
    return outputs