# NB: these method excerpts assume `import haiku as hk`, `import jax.numpy as jnp`,
# and a local `base` module providing `signal_metrics` are in scope.
def __call__(self, x, is_training=True, return_metrics=False):
     """Return the output of the final layer without any [log-]softmax."""
     # Stem
     outputs = {}
     out = self.initial_conv(x)
     out = hk.max_pool(out,
                       window_shape=(1, 3, 3, 1),
                       strides=(1, 2, 2, 1),
                       padding='SAME')
     if return_metrics:
         outputs.update(base.signal_metrics(out, 0))
     # Blocks
     for i, block in enumerate(self.blocks):
         out, res_avg_var = block(out, is_training=is_training)
         if return_metrics:
             outputs.update(base.signal_metrics(out, i + 1))
             outputs[f'res_avg_var_{i}'] = res_avg_var
     # Final activation, pool, dropout, classify
     pool = jnp.mean(self.activation(out), axis=[1, 2])
     outputs['pool'] = pool
     # Optionally apply dropout
     if self.drop_rate > 0.0 and is_training:
         pool = hk.dropout(hk.next_rng_key(), self.drop_rate, pool)
     outputs['logits'] = self.fc(pool)
     return outputs
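
Both examples collect per-stage statistics through base.signal_metrics, which is defined elsewhere in the repository and not shown here. Purely as an assumption about what such a helper might compute, a minimal sketch could report the average squared channel mean and the average channel variance of the activations after stage i (the key names mirror the res_avg_var_{i} convention above but are otherwise illustrative):

import jax.numpy as jnp

def signal_metrics(x, i):
    # Hypothetical sketch, not the repository's actual implementation:
    # summarise how well the signal propagates after stage `i` by reporting
    # the mean squared channel mean and the mean channel variance.
    chan_mean = jnp.mean(x, axis=[0, 1, 2])  # per-channel mean over N, H, W
    chan_var = jnp.var(x, axis=[0, 1, 2])    # per-channel variance over N, H, W
    return {
        f'avg_sq_mean_{i}': jnp.mean(chan_mean ** 2),
        f'avg_var_{i}': jnp.mean(chan_var),
    }
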
Example #2
 def __call__(self,
              x,
              is_training,
              test_local_stats=False,
              return_metrics=False):
     """Return the output of the final layer without any [log-]softmax."""
     outputs = {}
     # Stem
     out = self.initial_conv(x)
     if not self.preactivation:
         out = self.activation(
             self.initial_bn(out, is_training, test_local_stats))
     out = hk.max_pool(out,
                       window_shape=(1, 3, 3, 1),
                       strides=(1, 2, 2, 1),
                       padding='SAME')
     if return_metrics:
         outputs.update(base.signal_metrics(out, 0))
     # Blocks
     for i, block in enumerate(self.blocks):
         out, res_var = block(out, is_training, test_local_stats)
         if return_metrics:
             outputs.update(base.signal_metrics(out, i + 1))
             outputs[f'res_avg_var_{i}'] = res_var
     if self.preactivation:
         out = self.activation(
             self.final_bn(out, is_training, test_local_stats))
     # Pool, dropout, classify
     pool = jnp.mean(out, axis=[1, 2])
     # Return pool before dropout in case we want to regularize it separately.
     outputs['pool'] = pool
     # Optionally apply dropout
     if self.drop_rate > 0.0 and is_training:
         pool = hk.dropout(hk.next_rng_key(), self.drop_rate, pool)
     outputs['logits'] = self.fc(pool)
     return outputs
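
For context, a minimal usage sketch of how a forward pass like this is typically driven in Haiku. The Model name, its constructor argument, and the input shape are placeholders, not taken from the snippets; hk.transform_with_state is used so the batch-norm state of Example #2 is handled, and the RNG key passed to apply feeds hk.next_rng_key() for dropout:

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x, is_training):
    model = Model(num_classes=1000)  # placeholder for the class defining __call__ above
    return model(x, is_training=is_training)

net = hk.transform_with_state(forward)
rng = jax.random.PRNGKey(0)
images = jnp.zeros((1, 224, 224, 3))  # dummy NHWC batch
params, state = net.init(rng, images, is_training=True)
outputs, state = net.apply(params, state, rng, images, is_training=True)
logits = outputs['logits']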