Example #1
 def forward(self, x, is_training):
     # Block 1
     x = jax.nn.relu(self.conv1_1(x))
     x = self.bn1_1(x, is_training)
     x = jax.nn.relu(self.conv1_2(x))
     x = self.bn1_2(x, is_training)
     x = hk.max_pool(x, 2, 2, "SAME")
     if is_training:
         x = hk.dropout(hk.next_rng_key(), 0.2, x)
     # Block 2
     x = jax.nn.relu(self.conv2_1(x))
     x = self.bn2_1(x, is_training)
     x = jax.nn.relu(self.conv2_2(x))
     x = self.bn2_2(x, is_training)
     x = hk.max_pool(x, 2, 2, "SAME")
     if is_training:
         x = hk.dropout(hk.next_rng_key(), 0.3, x)
     # Block 3
     x = jax.nn.relu(self.conv3_1(x))
     x = self.bn3_1(x, is_training)
     x = jax.nn.relu(self.conv3_2(x))
     x = self.bn3_2(x, is_training)
     x = hk.max_pool(x, 2, 2, "SAME")
     if is_training:
         x = hk.dropout(hk.next_rng_key(), 0.4, x)
     # Linear part
     x = hk.Flatten()(x)
     x = jax.nn.relu(self.lin1(x))
     x = self.bn4(x, is_training)
     if is_training:
         x = hk.dropout(hk.next_rng_key(), 0.5, x)
     x = self.lin2(x)
     return x  # logits
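The dropout calls in Example #1 draw randomness from hk.next_rng_key(), which only works inside a Haiku-transformed function that is given an RNG key, and the BatchNorm layers additionally carry moving-average state. Below is a minimal, self-contained sketch (not the module from the example; layer sizes and hyperparameters are illustrative) of one conv -> batchnorm -> max-pool -> dropout block wrapped with hk.transform_with_state:

import jax
import jax.numpy as jnp
import haiku as hk

def block_fn(x, is_training):
    # Mirrors the per-block pattern of Example #1: relu(conv) -> batchnorm ->
    # max-pool -> (training-only) dropout.
    x = jax.nn.relu(hk.Conv2D(16, kernel_shape=3, padding='SAME')(x))
    x = hk.BatchNorm(create_scale=True, create_offset=True,
                     decay_rate=0.99)(x, is_training)
    x = hk.max_pool(x, 2, 2, "SAME")
    if is_training:
        # hk.dropout needs an explicit RNG key, obtained via hk.next_rng_key().
        x = hk.dropout(hk.next_rng_key(), 0.2, x)
    return x

block = hk.transform_with_state(block_fn)
rng = jax.random.PRNGKey(0)
x = jnp.zeros((4, 32, 32, 3))
params, state = block.init(rng, x, is_training=True)
out, state = block.apply(params, state, rng, x, is_training=True)
print(out.shape)  # (4, 16, 16, 16)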
Example #2
 def __call__(self, x, is_training=True, return_metrics=False):
     """Return the output of the final layer without any [log-]softmax."""
     # Stem
     outputs = {}
     out = self.initial_conv(x)
     out = hk.max_pool(out,
                       window_shape=(1, 3, 3, 1),
                       strides=(1, 2, 2, 1),
                       padding='SAME')
     if return_metrics:
         outputs.update(base.signal_metrics(out, 0))
     # Blocks
     for i, block in enumerate(self.blocks):
         out, res_avg_var = block(out, is_training=is_training)
         if return_metrics:
             outputs.update(base.signal_metrics(out, i + 1))
             outputs[f'res_avg_var_{i}'] = res_avg_var
     # Final activation, pool, dropout, classify
     pool = jnp.mean(self.activation(out), [1, 2])
     outputs['pool'] = pool
     # Optionally apply dropout
     if self.drop_rate > 0.0 and is_training:
         pool = hk.dropout(hk.next_rng_key(), self.drop_rate, pool)
     outputs['logits'] = self.fc(pool)
     return outputs
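The stem pooling used here (a 3x3 window, stride 2, SAME padding over NHWC inputs) recurs in most of the examples below; with SAME padding it simply halves the spatial dimensions, rounding up, regardless of the window size. A quick standalone check (shapes illustrative):

import jax.numpy as jnp
import haiku as hk

x = jnp.zeros((1, 112, 112, 64))  # NHWC
y = hk.max_pool(x,
                window_shape=(1, 3, 3, 1),
                strides=(1, 2, 2, 1),
                padding='SAME')
print(y.shape)  # (1, 56, 56, 64): ceil(112 / 2) == 56 per spatial dim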
Example #3
    def __call__(self, inputs, is_training, final_endpoint='output'):
        self._final_endpoint = final_endpoint
        net = self._initial_conv(inputs)
        net = hk.max_pool(net,
                          window_shape=(1, 3, 3, 1),
                          strides=(1, 2, 2, 1),
                          padding='SAME')
        end_point = 'resnet_stem'
        if self._final_endpoint == end_point:
            return net

        for i_group, block_group in enumerate(self._block_groups):
            net = block_group(net, is_training=is_training)
            end_point = f'resnet_unit_{i_group}'
            if self._final_endpoint == end_point:
                return net

        end_point = 'last_conv'
        if self._final_endpoint == end_point:
            return net

        if self._normalize_fn is not None:
            net = self._normalize_fn(net, is_training=is_training)
            net = jax.nn.relu(net)

        # The actual representation
        net = jnp.mean(net, axis=[1, 2])

        assert self._final_endpoint == 'output'
        if self._num_classes is None:
            # If num_classes was None, we just return the output
            # of the last block, without fully connected layer.
            return net

        return self._logits_layer(net)
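The "actual representation" step in Example #3 is a global average pool: jnp.mean over the two spatial axes of an NHWC feature map, leaving one feature vector per image. For instance (shapes illustrative):

import jax.numpy as jnp

net = jnp.ones((2, 7, 7, 2048))   # NHWC feature map after the last block
rep = jnp.mean(net, axis=[1, 2])  # average over height and width
print(rep.shape)                  # (2, 2048)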
Example #4
    def __call__(self, x):
        torso_out = x / 255.
        for i, (num_channels, num_blocks) in enumerate([(16, 2), (32, 2),
                                                        (32, 2)]):
            conv = hk.Conv2D(num_channels,
                             kernel_shape=[3, 3],
                             stride=[1, 1],
                             padding='SAME')
            torso_out = conv(torso_out)
            torso_out = hk.max_pool(
                torso_out,
                window_shape=[1, 3, 3, 1],
                strides=[1, 2, 2, 1],
                padding='SAME',
            )
            for j in range(num_blocks):
                block = ResidualBlock(num_channels,
                                      name='residual_{}_{}'.format(i, j))
                torso_out = block(torso_out)

        torso_out = jax.nn.relu(torso_out)
        torso_out = hk.Flatten()(torso_out)
        torso_out = hk.Linear(256)(torso_out)
        torso_out = jax.nn.relu(torso_out)
        return torso_out
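ResidualBlock is referenced but not defined in Example #4. A plausible channel-preserving implementation, consistent with how it is constructed and called there (built from num_channels and a name, applied to torso_out without changing its shape), is sketched below; the original class may differ:

import jax
import haiku as hk

class ResidualBlock(hk.Module):
    """Two 3x3 convs plus a skip connection (hedged sketch, not the original)."""

    def __init__(self, num_channels, name=None):
        super().__init__(name=name)
        self._num_channels = num_channels

    def __call__(self, x):
        out = jax.nn.relu(x)
        out = hk.Conv2D(self._num_channels, kernel_shape=[3, 3],
                        stride=[1, 1], padding='SAME')(out)
        out = jax.nn.relu(out)
        out = hk.Conv2D(self._num_channels, kernel_shape=[3, 3],
                        stride=[1, 1], padding='SAME')(out)
        return out + x  # skip connection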
Example #5
 def __call__(self,
              inputs: jnp.ndarray,
              *,
              is_training: bool,
              test_local_stats: bool = False) -> jnp.ndarray:
     out = inputs
     for layer in self.layers:
         out = layer['conv'](out)
         if layer['batchnorm'] is not None:
             out = layer['batchnorm'](out, is_training, test_local_stats)
         out = jax.nn.relu(out)
         out = hk.max_pool(out,
                           window_shape=(1, 3, 3, 1),
                           strides=(1, 2, 2, 1),
                           padding='SAME')
     return out
Example #6
    def __call__(self, inputs, is_training, test_local_stats=False):
        out = inputs
        out = self.initial_conv(out)
        if not self.resnet_v2:
            out = self.initial_batchnorm(out, is_training, test_local_stats)
            out = jax.nn.relu(out)

        out = hk.max_pool(out,
                          window_shape=(1, 3, 3, 1),
                          strides=(1, 2, 2, 1),
                          padding='SAME')

        for block_group in self.block_groups:
            out = block_group(out, is_training, test_local_stats)

        if self.resnet_v2:
            out = self.final_batchnorm(out, is_training, test_local_stats)
            out = jax.nn.relu(out)
        out = jnp.mean(out, axis=[1, 2])
        return out
Example #7
 def __call__(self,
              x,
              is_training,
              test_local_stats=False,
              return_metrics=False):
     """Return the output of the final layer without any [log-]softmax."""
     outputs = {}
     # Stem
     out = self.initial_conv(x)
     if not self.preactivation:
         out = self.activation(
             self.initial_bn(out, is_training, test_local_stats))
     out = hk.max_pool(out,
                       window_shape=(1, 3, 3, 1),
                       strides=(1, 2, 2, 1),
                       padding='SAME')
     if return_metrics:
         outputs.update(base.signal_metrics(out, 0))
     # Blocks
     for i, block in enumerate(self.blocks):
         out, res_var = block(out, is_training, test_local_stats)
         if return_metrics:
             outputs.update(base.signal_metrics(out, i + 1))
             outputs[f'res_avg_var_{i}'] = res_var
     if self.preactivation:
         out = self.activation(
             self.final_bn(out, is_training, test_local_stats))
     # Pool, dropout, classify
     pool = jnp.mean(out, axis=[1, 2])
     # Return pool before dropout in case we want to regularize it separately.
     outputs['pool'] = pool
     # Optionally apply dropout
     if self.drop_rate > 0.0 and is_training:
         pool = hk.dropout(hk.next_rng_key(), self.drop_rate, pool)
     outputs['logits'] = self.fc(pool)
     return outputs
def downsample(x, factor):
    return hk.max_pool(x,
                       window_shape=(1, factor, factor, 1),
                       strides=(1, factor, factor, 1),
                       padding='VALID')
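A usage sketch for the downsample helper above (values illustrative; assumes haiku is imported as hk where downsample is defined): with VALID padding and window == stride == factor, trailing rows and columns that do not fill a complete window are dropped.

import jax.numpy as jnp

x = jnp.arange(2 * 17 * 17 * 8, dtype=jnp.float32).reshape(2, 17, 17, 8)
y = downsample(x, factor=4)
print(y.shape)  # (2, 4, 4, 8): floor(17 / 4) == 4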