def apply(self,
            x,
            channels,
            strides=(1, 1),
            dropout_rate=0.0,
            norm_layer='group_norm',
            train=True):
    norm_layer_name = ''
    if norm_layer == 'batch_norm':
      norm_layer = nn.BatchNorm.partial(use_running_average=not train)
      norm_layer_name = 'bn'
    elif norm_layer == 'group_norm':
      norm_layer = nn.GroupNorm.partial(num_groups=16)
      norm_layer_name = 'gn'

    y = norm_layer(x, name=f'{norm_layer_name}1')
    y = jax.nn.relu(y)
    y = nn.Conv(y, channels, (3, 3), strides, padding='SAME', name='conv1')
    y = norm_layer(y, name=f'{norm_layer_name}2')
    y = jax.nn.relu(y)
    if dropout_rate > 0.0:
      y = nn.dropout(y, dropout_rate, deterministic=not train)
    y = nn.Conv(y, channels, (3, 3), padding='SAME', name='conv2')

    # Project the shortcut if the channel count or spatial resolution changes.
    if (x.shape[-1] != channels) or strides != (1, 1):
      x = nn.Conv(x, channels, (3, 3), strides, padding='SAME')
    return x + y
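
The blocks on this page are written against the pre-Linen Flax API, where a network is an `nn.Module` subclass whose `apply` method both defines and runs the computation, and hyper-parameters are bound with `Module.partial`. A minimal usage sketch under that assumption (`TinyBlock` below is a hypothetical stand-in, not one of the modules shown here):

# Minimal sketch of defining and running an `apply`-style module with the
# pre-Linen Flax API. `TinyBlock` is a hypothetical example block.
import jax
import jax.numpy as jnp
from flax import nn  # on newer Flax releases: from flax.deprecated import nn


class TinyBlock(nn.Module):

  def apply(self, x, channels):
    y = nn.GroupNorm(x, num_groups=4)
    y = jax.nn.relu(y)
    y = nn.Conv(y, channels, (3, 3), padding='SAME', name='conv')
    if x.shape[-1] != channels:
      x = nn.Conv(x, channels, (1, 1), padding='SAME', name='proj')
    return x + y


# Bind hyper-parameters, create parameters from an input shape, then apply.
block_def = TinyBlock.partial(channels=8)
_, params = block_def.init_by_shape(
    jax.random.PRNGKey(0), [((1, 16, 16, 4), jnp.float32)])
model = nn.Model(block_def, params)
y = model(jnp.ones((1, 16, 16, 4)))  # shape (1, 16, 16, 8)
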
Example #2
    def apply(self,
              x,
              num_filters=64,
              block_sizes=(3, 4, 6, 3),
              train=True,
              block=BottleneckBlock,
              small_inputs=False):
        if small_inputs:
            x = nn.Conv(x,
                        num_filters,
                        kernel_size=(3, 3),
                        strides=(1, 1),
                        bias=False,
                        name="init_conv")
        else:
            x = nn.Conv(x,
                        num_filters,
                        kernel_size=(7, 7),
                        strides=(2, 2),
                        bias=False,
                        name="init_conv")
        x = nn.BatchNorm(x,
                         use_running_average=not train,
                         epsilon=1e-5,
                         name="init_bn")
        if not small_inputs:
            x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")
        for i, block_size in enumerate(block_sizes):
            for j in range(block_size):
                strides = (2, 2) if i > 0 and j == 0 else (1, 1)
                x = block(x, num_filters * 2**i, strides=strides, train=train)

        return x
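Example #3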
 def apply(self,
           x,
           filters,
           strides=(1, 1),
           groups=1,
           base_width=64,
           train=True):
     needs_projection = x.shape[-1] != filters * 4 or strides != (1, 1)
     width = int(filters * (base_width / 64.)) * groups
     batch_norm = nn.BatchNorm.partial(use_running_average=not train,
                                       momentum=0.9,
                                       epsilon=1e-5)
     y = nn.Conv(x, width, (1, 1), (1, 1), bias=False, name='conv1')
     y = batch_norm(y, name='bn1')
     y = jax.nn.relu(y)
     y = nn.Conv(y,
                 width, (3, 3),
                 strides,
                 bias=False,
                 feature_group_count=groups,
                 name='conv2')
     y = batch_norm(y, name='bn2')
     y = jax.nn.relu(y)
     y = nn.Conv(y, filters * 4, (1, 1), (1, 1), bias=False, name='conv3')
     y = batch_norm(y, name='bn3', scale_init=initializers.zeros)
     if needs_projection:
         x = nn.Conv(x,
                     filters * 4, (1, 1),
                     strides,
                     bias=False,
                     name='proj_conv')
         x = batch_norm(x, name='proj_bn')
     return jax.nn.relu(x + y)
Example #4
 def apply(self, x, inner_channels=8):
   x = nn.Conv(x, features=inner_channels, kernel_size=(3, 3), bias=False,
                    padding='SAME')
   x = nn.Conv(x, features=inner_channels, kernel_size=(3, 3), bias=False,
                    padding='SAME')
   x = nn.Conv(x, features=inner_channels, kernel_size=(3, 3), bias=False,
                    padding='SAME')
   return x
Example #5
 def apply(self, x, num_cycles=3, inner_channels=8):
   x = Relaxation(x, inner_channels=inner_channels)
   if num_cycles > 0:
     x1 = nn.Conv(x, features=inner_channels, kernel_size=(3, 3),
                       bias=False,
                       strides=(2, 2),
                       padding='VALID')
     x1 = Cycle(x1, num_cycles=num_cycles-1, inner_channels=inner_channels)
     x1 = nn.Conv(x1, features=1, kernel_size=(3, 3), bias=False,
                       input_dilation=(2, 2), padding=[(2, 2), (2, 2)])
     x = x + x1
     x = Relaxation(x, inner_channels=inner_channels)
   return x
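Example #6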
 def apply(self, x, use_squeeze_excite=False):
   x = nn.Conv(x, features=8, kernel_size=(3, 3), padding="VALID")
   x = nn.relu(x)
   x = nn.Conv(x, features=16, kernel_size=(3, 3), padding="VALID")
   x = nn.relu(x)
   if use_squeeze_excite:
     x = SqueezeExciteLayer(x)
   x = nn.Conv(x, features=32, kernel_size=(3, 3), padding="VALID")
   x = nn.relu(x)
   if use_squeeze_excite:
     x = SqueezeExciteLayer(x)
   x = nn.Conv(x, features=1, kernel_size=(3, 3), padding="VALID")
   scores = nn.max_pool(x, window_shape=(8, 8), strides=(8, 8))[Ellipsis, 0]
   return scores
Example #7
 def apply(self, x, num_cycles=3, inner_channels=8):
   x = NonLinearRelaxation(x, inner_channels=inner_channels)
   if num_cycles > 0:
     x1 = nn.Conv(x, features=inner_channels, kernel_size=(3, 3),
                       bias=False,
                       strides=(2, 2),
                       padding='VALID')
     x1 = NonLinearCycle(x1, num_cycles=num_cycles-1, inner_channels=inner_channels)
     x1 = nn.Conv(x1, features=1, kernel_size=(3, 3), bias=False,
                       input_dilation=(2, 2), padding=[(2, 2), (2, 2)])
     x = x + x1
     #x = np.concatenate((x,x1), axis=3)
     #print(x.shape)
     x = NonLinearRelaxation(x, inner_channels=inner_channels)
   return x
Example #8
 def apply(self, x):
     x = nn.Conv(x, features=32, kernel_size=(3, 3), name="conv")
     x = nn.relu(x)
     x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
     x = x.reshape((x.shape[0], -1))  # flatten.
     x = nn.Dense(x, 128, name="fc")
     return x
Example #9
    def apply(self,
              x,
              blocks_per_group,
              channel_multiplier,
              num_outputs,
              dropout_rate=0.0,
              train=True):

        x = nn.Conv(x, 16, (3, 3), padding='SAME', name='init_conv')
        x = WideResnetGroup(x,
                            blocks_per_group,
                            16 * channel_multiplier,
                            dropout_rate=dropout_rate,
                            train=train)
        x = WideResnetGroup(x,
                            blocks_per_group,
                            32 * channel_multiplier, (2, 2),
                            dropout_rate=dropout_rate,
                            train=train)
        x = WideResnetGroup(x,
                            blocks_per_group,
                            64 * channel_multiplier, (2, 2),
                            dropout_rate=dropout_rate,
                            train=train)
        x = nn.BatchNorm(x,
                         use_running_average=not train,
                         momentum=0.9,
                         epsilon=1e-5)
        x = jax.nn.relu(x)
        x = nn.avg_pool(x, (8, 8))
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(x, num_outputs)
        return x
Example #10
    def apply(self,
              x,
              blocks_per_group,
              channel_multiplier,
              num_outputs,
              train=True):

        x = nn.Conv(x, 16, (3, 3), padding='SAME', name='init_conv')
        x = WideResnetShakeShakeGroup(x,
                                      blocks_per_group,
                                      16 * channel_multiplier,
                                      train=train)
        x = WideResnetShakeShakeGroup(x,
                                      blocks_per_group,
                                      32 * channel_multiplier, (2, 2),
                                      train=train)
        x = WideResnetShakeShakeGroup(x,
                                      blocks_per_group,
                                      64 * channel_multiplier, (2, 2),
                                      train=train)
        x = jax.nn.relu(x)
        x = nn.avg_pool(x, (8, 8))
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(x, num_outputs)
        return x
Example #11
  def apply(self, x, channels, strides=(1, 1), dropout_rate=0.0, train=True):
    batch_norm = nn.BatchNorm.partial(use_running_average=not train,
                                      momentum=0.9, epsilon=1e-5)

    y = batch_norm(x, name='bn1')
    y = jax.nn.relu(y)
    y = nn.Conv(y, channels, (3, 3), strides, padding='SAME', name='conv1')
    y = batch_norm(y, name='bn2')
    y = jax.nn.relu(y)
    if dropout_rate > 0.0:
      y = nn.dropout(y, dropout_rate, deterministic=not train)
    y = nn.Conv(y, channels, (3, 3), padding='SAME', name='conv2')

    # Project the shortcut if the channel count or spatial resolution changes.
    if (x.shape[-1] != channels) or strides != (1, 1):
      x = nn.Conv(x, channels, (3, 3), strides, padding='SAME')
    return x + y
Example #12
 def apply(self, x, inner_channels=8):
     x = NonLinearCycle(x, 4, inner_channels)
     x = nn.Conv(x,
                 features=1,
                 kernel_size=(3, 3),
                 bias=False,
                 padding='SAME')
     x = nn.relu(x)
     return x
Example #13
    def apply(self, x, channels, strides=(1, 1), train=True):
        batch_norm = nn.BatchNorm.partial(use_running_average=not train,
                                          momentum=0.9,
                                          epsilon=1e-5)

        a = b = residual = x

        a = jax.nn.relu(a)
        a = nn.Conv(a,
                    channels, (3, 3),
                    strides,
                    padding='SAME',
                    name='conv_a_1')
        a = batch_norm(a, name='bn_a_1')
        a = jax.nn.relu(a)
        a = nn.Conv(a, channels, (3, 3), padding='SAME', name='conv_a_2')
        a = batch_norm(a, name='bn_a_2')

        b = jax.nn.relu(b)
        b = nn.Conv(b,
                    channels, (3, 3),
                    strides,
                    padding='SAME',
                    name='conv_b_1')
        b = batch_norm(b, name='bn_b_1')
        b = jax.nn.relu(b)
        b = nn.Conv(b, channels, (3, 3), padding='SAME', name='conv_b_2')
        b = batch_norm(b, name='bn_b_2')

        if train and not self.is_initializing():
            ab = shake.shake_shake_train(a, b)
        else:
            ab = shake.shake_shake_eval(a, b)

        # Project the shortcut if the channel count or spatial resolution changes.
        if (residual.shape[-1] != channels) or strides != (1, 1):
            residual = nn.Conv(residual,
                               channels, (3, 3),
                               strides,
                               padding='SAME',
                               name='conv_residual')
            residual = batch_norm(residual, name='bn_residual')

        return residual + ab
Example #14
 def apply(self, inputs, output_dim, kernel_size=3, biases=True):
     output = nn.Conv(inputs,
                      features=output_dim,
                      kernel_size=(kernel_size, kernel_size),
                      strides=(1, 1),
                      padding='SAME',
                      bias=biases)
     output = sum([
         output[:, ::2, ::2, :], output[:, 1::2, ::2, :],
         output[:, ::2, 1::2, :], output[:, 1::2, 1::2, :]
     ]) / 4.
     return output
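
The downsampling in the snippet above sums the four stride-2 slices of the convolution output and divides by four, which (for even spatial sizes) equals a 2x2 average pool with stride 2. A small sketch checking that equivalence, assuming the same pre-Linen `flax.nn.avg_pool` helper used elsewhere on this page:

# Sketch: the slice-and-average trick above matches a 2x2, stride-2 average pool.
import jax
import jax.numpy as jnp
from flax import nn  # on newer Flax releases: from flax.deprecated import nn

x = jax.random.normal(jax.random.PRNGKey(0), (2, 8, 8, 3))
by_slicing = sum([
    x[:, ::2, ::2, :], x[:, 1::2, ::2, :],
    x[:, ::2, 1::2, :], x[:, 1::2, 1::2, :]
]) / 4.
by_avg_pool = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
assert jnp.allclose(by_slicing, by_avg_pool, atol=1e-6)
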
Example #15
    def apply(self,
              x,
              channels,
              strides,
              prob,
              alpha_min,
              alpha_max,
              beta_min,
              beta_max,
              train=True):
        batch_norm = nn.BatchNorm.partial(use_running_average=not train,
                                          momentum=0.9,
                                          epsilon=1e-5)

        y = batch_norm(x, name='bn_1_pre')
        y = nn.Conv(y,
                    channels, (1, 1),
                    padding='SAME',
                    name='1x1_conv_contract')
        y = batch_norm(y, name='bn_1_post')
        y = jax.nn.relu(y)
        y = nn.Conv(y, channels, (3, 3), strides, padding='SAME', name='3x3')
        y = batch_norm(y, name='bn_2')
        y = jax.nn.relu(y)
        y = nn.Conv(y,
                    channels * 4, (1, 1),
                    padding='SAME',
                    name='1x1_conv_expand')
        y = batch_norm(y, name='bn_3')

        if train:
            y = shake.shake_drop_train(y, prob, alpha_min, alpha_max, beta_min,
                                       beta_max)
        else:
            y = shake.shake_drop_eval(y, prob, alpha_min, alpha_max)

        x = shortcut(x, channels * 4, strides)
        return x + y
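
The block above calls `shortcut(x, channels * 4, strides)`, which is not shown on this page. In PyramidNet-style networks the shortcut is typically parameter-free: average-pool when the block is strided, then zero-pad the extra channels. A hedged sketch under that assumption (a reconstruction, not the page's original definition):

# Hedged sketch of a PyramidNet-style parameter-free shortcut: average-pool
# when strided, then zero-pad the channel dimension instead of projecting.
import jax.numpy as jnp
from flax import nn  # on newer Flax releases: from flax.deprecated import nn


def shortcut(x, channels_out, strides):
  channels_in = x.shape[-1]
  if strides != (1, 1):
    x = nn.avg_pool(x, strides, strides)
  if channels_out != channels_in:
    x = jnp.pad(x, [(0, 0), (0, 0), (0, 0), (0, channels_out - channels_in)])
  return x
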
Example #16
def ddpm_conv1x1(x,
                 out_planes,
                 stride=1,
                 bias=True,
                 dilation=1,
                 init_scale=1.):
    """1x1 convolution with DDPM initialization."""
    bias_init = jnn.initializers.zeros
    output = nn.Conv(x,
                     out_planes,
                     kernel_size=(1, 1),
                     strides=(stride, stride),
                     padding='SAME',
                     bias=bias,
                     kernel_dilation=(dilation, dilation),
                     kernel_init=default_init(init_scale),
                     bias_init=bias_init)
    return output
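
`ddpm_conv1x1` above relies on `default_init(init_scale)`, which is defined elsewhere in its codebase. In the DDPM/score-SDE JAX code this is usually a variance-scaling initializer with the scale clamped away from zero; a hedged sketch of that assumption (not the definition from this page):

# Hedged sketch of a typical `default_init`: uniform variance scaling over
# fan_avg, with a zero scale replaced by a tiny positive value.
from jax import nn as jnn


def default_init(scale=1.):
  scale = 1e-10 if scale == 0 else scale
  return jnn.initializers.variance_scaling(scale, 'fan_avg', 'uniform')
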
Example #17
def ncsn_conv3x3(x,
                 out_planes,
                 stride=1,
                 bias=True,
                 dilation=1,
                 init_scale=1.):
    """3x3 convolution with PyTorch initialization. Same as NCSNv1/NCSNv2."""
    init_scale = 1e-10 if init_scale == 0 else init_scale
    kernel_init = jnn.initializers.variance_scaling(1 / 3 * init_scale,
                                                    'fan_in', 'uniform')
    kernel_shape = (3, 3) + (x.shape[-1], out_planes)
    bias_init = lambda key, shape: kernel_init(key, kernel_shape)[0, 0, 0, :]
    output = nn.Conv(x,
                     out_planes,
                     kernel_size=(3, 3),
                     strides=(stride, stride),
                     padding='SAME',
                     bias=bias,
                     kernel_dilation=(dilation, dilation),
                     kernel_init=kernel_init,
                     bias_init=bias_init)
    return output
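Example #18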
 def apply(self, x, num_outputs, train=True):
     x = nn.Conv(x,
                 self.NUM_FILTERS, (7, 7), (2, 2),
                 bias=False,
                 name='init_conv')
     x = nn.BatchNorm(x,
                      use_running_average=not train,
                      momentum=0.9,
                      epsilon=1e-5,
                      name='init_bn')
     x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
     for i, block_size in enumerate(self.BLOCK_SIZES):
         for j in range(block_size):
             strides = (2, 2) if i > 0 and j == 0 else (1, 1)
             x = BottleneckBlock(x,
                                 self.NUM_FILTERS * 2**i,
                                 strides=strides,
                                 groups=self.GROUPS,
                                 base_width=self.WIDTH_PER_GROUP,
                                 train=train)
     x = jnp.mean(x, axis=(1, 2))
     x = nn.Dense(x, num_outputs, name='clf')
     return x
Example #19
    def apply(self, x, communication=Communication.NONE, train=True):
        """Forward pass."""
        batch_size = x.shape[0]

        if communication is Communication.SQUEEZE_EXCITE_X:
            x = sample_patches.SqueezeExciteLayer(x)
        # end if squeeze excite x

        d1 = nn.relu(
            nn.Conv(x,
                    128,
                    kernel_size=(3, 3),
                    strides=(1, 1),
                    bias=True,
                    name="down1"))
        d2 = nn.relu(
            nn.Conv(d1,
                    128,
                    kernel_size=(3, 3),
                    strides=(2, 2),
                    bias=True,
                    name="down2"))
        d3 = nn.relu(
            nn.Conv(d2,
                    128,
                    kernel_size=(3, 3),
                    strides=(2, 2),
                    bias=True,
                    name="down3"))

        if communication is Communication.SQUEEZE_EXCITE_D:
            d1_flatten = einops.rearrange(d1, "b h w c -> b (h w) c")
            d2_flatten = einops.rearrange(d2, "b h w c -> b (h w) c")
            d3_flatten = einops.rearrange(d3, "b h w c -> b (h w) c")

            nd1 = d1_flatten.shape[1]
            nd2 = d2_flatten.shape[1]

            d_together = jnp.concatenate([d1_flatten, d2_flatten, d3_flatten],
                                         axis=1)

            num_channels = d_together.shape[-1]
            y = d_together.mean(axis=1)
            y = nn.Dense(y, features=num_channels // 4, bias=False)
            y = nn.relu(y)
            y = nn.Dense(y, features=num_channels, bias=False)
            y = nn.sigmoid(y)

            d_together = d_together * y[:, None, :]

            # split and reshape
            d1 = d_together[:, :nd1].reshape(d1.shape)
            d2 = d_together[:, nd1:nd1 + nd2].reshape(d2.shape)
            d3 = d_together[:, nd1 + nd2:].reshape(d3.shape)

        elif communication is Communication.TRANSFORMER:
            d1_flatten = einops.rearrange(d1, "b h w c -> b (h w) c")
            d2_flatten = einops.rearrange(d2, "b h w c -> b (h w) c")
            d3_flatten = einops.rearrange(d3, "b h w c -> b (h w) c")

            nd1 = d1_flatten.shape[1]
            nd2 = d2_flatten.shape[1]

            d_together = jnp.concatenate([d1_flatten, d2_flatten, d3_flatten],
                                         axis=1)

            positional_encodings = self.param(
                "scale_ratio_position_encodings",
                shape=(1, ) + d_together.shape[1:],
                initializer=jax.nn.initializers.normal(1. /
                                                       d_together.shape[-1]))
            d_together = transformer.Transformer(d_together +
                                                 positional_encodings,
                                                 num_layers=2,
                                                 num_heads=8,
                                                 is_training=train)

            # split and reshape
            d1 = d_together[:, :nd1].reshape(d1.shape)
            d2 = d_together[:, nd1:nd1 + nd2].reshape(d2.shape)
            d3 = d_together[:, nd1 + nd2:].reshape(d3.shape)

        t1 = nn.Conv(d1,
                     6,
                     kernel_size=(1, 1),
                     strides=(1, 1),
                     bias=True,
                     name="tidy1")
        t2 = nn.Conv(d2,
                     6,
                     kernel_size=(1, 1),
                     strides=(1, 1),
                     bias=True,
                     name="tidy2")
        t3 = nn.Conv(d3,
                     9,
                     kernel_size=(1, 1),
                     strides=(1, 1),
                     bias=True,
                     name="tidy3")

        raw_scores = (jnp.split(t1, 6, axis=-1) + jnp.split(t2, 6, axis=-1) +
                      jnp.split(t3, 9, axis=-1))

        # The following is for normalization.
        t = jnp.concatenate((jnp.reshape(
            t1, [batch_size, -1]), jnp.reshape(
                t2, [batch_size, -1]), jnp.reshape(t3, [batch_size, -1])),
                            axis=1)
        t_min = jnp.reshape(jnp.min(t, axis=-1), [batch_size, 1, 1, 1])
        t_max = jnp.reshape(jnp.max(t, axis=-1), [batch_size, 1, 1, 1])
        normalized_scores = zeroone(raw_scores, t_min, t_max)

        stats = {
            "scores": normalized_scores,
            "raw_scores": t,
        }
        # removes the split dimension. scores are now b x h' x w' shaped
        normalized_scores = [s.squeeze(-1) for s in normalized_scores]

        return normalized_scores, stats
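
The scorer above normalizes with `zeroone(raw_scores, t_min, t_max)`, a helper not shown on this page; given how `t_min` and `t_max` are computed, it is presumably a per-example min-max rescaling of each score map into [0, 1]. A hedged sketch of such a helper (a reconstruction, not the original):

# Hedged sketch of `zeroone`: min-max normalize each score map into [0, 1]
# using the per-example minimum and maximum computed above.
def zeroone(score_maps, t_min, t_max):
  return [(s - t_min) / (t_max - t_min) for s in score_maps]

Example #20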
  def apply(
      self,
      inputs,
      blocks_per_group,
      channel_multiplier,
      num_outputs,
      kernel_size=(3, 3),
      strides=None,
      maxpool=False,
      dropout_rate=0.0,
      dtype=jnp.float32,
      norm_layer='group_norm',
      train=True,
      return_activations=False,
      input_layer_key='input',
      has_discriminator=False,
      discriminator=False,
  ):

    norm_layer_name = ''
    if norm_layer == 'batch_norm':
      norm_layer = nn.BatchNorm.partial(use_running_average=not train)
      norm_layer_name = 'bn'
    elif norm_layer == 'group_norm':
      norm_layer = nn.GroupNorm.partial(num_groups=16)
      norm_layer_name = 'gn'

    layer_activations = collections.OrderedDict()
    input_is_set = False
    current_rep_key = 'input'
    if input_layer_key == current_rep_key:
      x = inputs
      input_is_set = True
      layer_activations[current_rep_key] = x
      rep_key = current_rep_key

    current_rep_key = 'init_conv'
    if input_is_set:
      x = nn.Conv(
          x,
          16,
          kernel_size=kernel_size,
          strides=strides,
          padding='SAME',
          name='init_conv')
      if maxpool:
        x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

      layer_activations[current_rep_key] = x
      rep_key = current_rep_key
    elif input_layer_key == current_rep_key:
      x = inputs
      input_is_set = True
      layer_activations[current_rep_key] = x
      rep_key = current_rep_key

    current_rep_key = 'l1'
    if input_is_set:
      x = WideResnetGroup(
          x,
          blocks_per_group,
          16 * channel_multiplier,
          dropout_rate=dropout_rate,
          norm_layer=norm_layer,
          train=train,
          name='l1')
      layer_activations[current_rep_key] = x
      rep_key = current_rep_key
    elif input_layer_key == current_rep_key:
      x = inputs
      input_is_set = True
      layer_activations[current_rep_key] = x
      rep_key = current_rep_key

    current_rep_key = 'l2'
    if input_is_set:
      x = WideResnetGroup(
          x,
          blocks_per_group,
          32 * channel_multiplier, (2, 2),
          dropout_rate=dropout_rate,
          norm_layer=norm_layer,
          train=train,
          name='l2')
      layer_activations[current_rep_key] = x
      rep_key = current_rep_key
    elif input_layer_key == current_rep_key:
      x = inputs
      input_is_set = True
      layer_activations[current_rep_key] = x
      rep_key = current_rep_key

    current_rep_key = 'l3'
    if input_is_set:
      x = WideResnetGroup(
          x,
          blocks_per_group,
          64 * channel_multiplier, (2, 2),
          dropout_rate=dropout_rate,
          norm_layer=norm_layer,
          train=train,
          name='l3')
      layer_activations[current_rep_key] = x
      rep_key = current_rep_key
    elif input_layer_key == current_rep_key:
      x = inputs
      input_is_set = True
      layer_activations[current_rep_key] = x
      rep_key = current_rep_key

    current_rep_key = 'l4'
    if input_is_set:
      x = norm_layer(x, name=f'{norm_layer_name}')
      x = jax.nn.relu(x)
      x = nn.avg_pool(x, (8, 8))
      x = x.reshape((x.shape[0], -1))
      layer_activations[current_rep_key] = x
      rep_key = current_rep_key
    elif input_layer_key == current_rep_key:
      x = inputs
      input_is_set = True
      layer_activations[current_rep_key] = x
      rep_key = current_rep_key

    # DANN module
    if has_discriminator:
      z = dann_utils.flip_grad_identity(x)
      z = nn.Dense(z, 2, name='disc_l1', bias=True)
      z = nn.relu(z)
      z = nn.Dense(z, 2, name='disc_l2', bias=True)

    current_rep_key = 'head'
    if input_is_set:
      x = nn.Dense(x, num_outputs, dtype=dtype, name='head')
    else:
      x = inputs
      layer_activations[current_rep_key] = x
      rep_key = current_rep_key

      logging.warning('Input was never used')

    outputs = x
    if return_activations:
      outputs = (x, layer_activations, rep_key)
      if discriminator and has_discriminator:
        outputs = outputs + (z,)
    else:
      del layer_activations
      if discriminator and has_discriminator:
        outputs = (x, z)
    if discriminator and (not has_discriminator):
      raise ValueError(
          'Inconsistent values passed for discriminator and has_discriminator')
    return outputs
Example #21
    def apply(self,
              x,
              num_outputs,
              pyramid_alpha=200,
              pyramid_depth=272,
              train=True):
        assert (pyramid_depth - 2) % 9 == 0

        # Shake-drop hyper-params
        mask_prob = 0.5
        alpha_min, alpha_max = (-1.0, 1.0)
        beta_min, beta_max = (0.0, 1.0)

        # Bottleneck network size
        blocks_per_group = (pyramid_depth - 2) // 9
        # See Eqn 2 in https://arxiv.org/abs/1610.02915
        num_channels = 16
        # N in https://arxiv.org/abs/1610.02915
        total_blocks = blocks_per_group * 3
        delta_channels = pyramid_alpha / total_blocks

        x = nn.Conv(x, 16, (3, 3), padding='SAME', name='init_conv')
        x = nn.BatchNorm(x,
                         use_running_average=not train,
                         momentum=0.9,
                         epsilon=1e-5,
                         name='init_bn')

        layer_num = 1

        for block_i in range(blocks_per_group):
            num_channels += delta_channels
            layer_mask_prob = _calc_shakedrop_mask_prob(
                layer_num, total_blocks, mask_prob)
            x = BottleneckShakeDrop(x,
                                    int(num_channels), (1, 1),
                                    layer_mask_prob,
                                    alpha_min,
                                    alpha_max,
                                    beta_min,
                                    beta_max,
                                    train=train)
            layer_num += 1

        for block_i in range(blocks_per_group):
            num_channels += delta_channels
            layer_mask_prob = _calc_shakedrop_mask_prob(
                layer_num, total_blocks, mask_prob)
            x = BottleneckShakeDrop(x,
                                    int(num_channels),
                                    ((2, 2) if block_i == 0 else (1, 1)),
                                    layer_mask_prob,
                                    alpha_min,
                                    alpha_max,
                                    beta_min,
                                    beta_max,
                                    train=train)
            layer_num += 1

        for block_i in range(blocks_per_group):
            num_channels += delta_channels
            layer_mask_prob = _calc_shakedrop_mask_prob(
                layer_num, total_blocks, mask_prob)
            x = BottleneckShakeDrop(x,
                                    int(num_channels),
                                    ((2, 2) if block_i == 0 else (1, 1)),
                                    layer_mask_prob,
                                    alpha_min,
                                    alpha_max,
                                    beta_min,
                                    beta_max,
                                    train=train)
            layer_num += 1

        assert layer_num - 1 == total_blocks
        x = nn.BatchNorm(x,
                         use_running_average=not train,
                         momentum=0.9,
                         epsilon=1e-5,
                         name='final_bn')
        x = jax.nn.relu(x)
        x = nn.avg_pool(x, (8, 8))
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(x, num_outputs)
        return x
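
The PyramidNet above computes a per-block ShakeDrop probability with `_calc_shakedrop_mask_prob`, which is not shown. Following the linear decay rule from the ShakeDrop paper, deeper blocks are perturbed more often; a hedged sketch under that assumption (with `mask_prob=0.5` this interpolates from nearly 1.0 at the first block down to 0.5 at the deepest one):

# Hedged sketch of `_calc_shakedrop_mask_prob`: keep probability decays
# linearly with depth, as in the ShakeDrop linear decay rule.
def _calc_shakedrop_mask_prob(curr_layer, total_layers, mask_prob):
  return 1 - (float(curr_layer) / total_layers) * mask_prob
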
Example #22
    def apply(self,
              inputs,
              num_outputs,
              num_filters=64,
              num_layers=50,
              dropout_rate=0.0,
              input_dropout_rate=0.0,
              train=True,
              dtype=jnp.float32,
              head_bias_init=jnp.zeros,
              return_activations=False,
              input_layer_key='input',
              has_discriminator=False,
              discriminator=False):
        """Apply a ResNet network on the input.

    Args:
      inputs: jnp array; Inputs.
      num_outputs: int; Number of output units.
      num_filters: int; Determines base number of filters. Number of filters in
        block i is  num_filters * 2 ** i.
      num_layers: int; Number of layers (should be one of the predefined ones.)
      dropout_rate: float; Rate of dropping out the output of different hidden
        layers.
      input_dropout_rate: float; Rate of dropping out the input units.
      train: bool; Is train?
      dtype: jnp type; Type of the outputs.
      head_bias_init: fn(rng_key, shape)--> jnp array; Initializer for head bias
        parameters.
      return_activations: bool; If True hidden activation are also returned.
      input_layer_key: str; Determines where to plugin the input (this is to
        enable providing inputs to slices of the model). If `input_layer_key` is
        `layer_i` we assume the inputs are the activations of `layer_i` and pass
        them to `layer_{i+1}`.
      has_discriminator: bool; Whether the model should have discriminator
        layer.
      discriminator: bool; Whether we should return discriminator logits.

    Returns:
      Unnormalized logits with shape `[bs, num_outputs]`. If
      `return_activations` is True, also returns a dict of hidden activations
      and the key of the representation(s) used as ``the representation'',
      e.g. for computing losses.
    """
        if num_layers not in ResNet._block_size_options:
            raise ValueError('Please provide a valid number of layers')

        block_sizes = ResNet._block_size_options[num_layers]

        layer_activations = collections.OrderedDict()
        input_is_set = False
        current_rep_key = 'input'
        if input_layer_key == current_rep_key:
            x = inputs
            input_is_set = True

        if input_is_set:
            # Input dropout
            x = nn.dropout(x, input_dropout_rate, deterministic=not train)
            layer_activations[current_rep_key] = x
            rep_key = current_rep_key

        current_rep_key = 'init_conv'
        if input_layer_key == current_rep_key:
            x = inputs
            input_is_set = True
            layer_activations[current_rep_key] = x
            rep_key = current_rep_key
        elif input_is_set:
            # First block
            x = nn.Conv(x,
                        num_filters, (7, 7), (2, 2),
                        padding=[(3, 3), (3, 3)],
                        bias=False,
                        dtype=dtype,
                        name='init_conv')
            x = nn.BatchNorm(x,
                             use_running_average=not train,
                             momentum=0.9,
                             epsilon=1e-5,
                             dtype=dtype,
                             name='init_bn')
            x = nn.relu(x)
            x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

            layer_activations[current_rep_key] = x
            rep_key = current_rep_key

        # Residual blocks
        for i, block_size in enumerate(block_sizes):

            # Stage i (each stage contains blocks of the same size).
            for j in range(block_size):
                strides = (2, 2) if i > 0 and j == 0 else (1, 1)
                current_rep_key = f'block_{i + 1}+{j}'
                if input_layer_key == current_rep_key:
                    x = inputs
                    input_is_set = True
                    layer_activations[current_rep_key] = x
                    rep_key = current_rep_key
                elif input_is_set:
                    x = ResidualBlock(x,
                                      num_filters * 2**i,
                                      strides=strides,
                                      dropout_rate=dropout_rate,
                                      train=train,
                                      dtype=dtype,
                                      name=f'block_{i + 1}_{j}')
                    layer_activations[current_rep_key] = x
                    rep_key = current_rep_key

        current_rep_key = 'avg_pool'
        if input_layer_key == current_rep_key:
            x = inputs
            input_is_set = True
            layer_activations[current_rep_key] = x
            rep_key = current_rep_key
        elif input_is_set:
            # Global Average Pool
            x = jnp.mean(x, axis=(1, 2))
            layer_activations[current_rep_key] = x
            rep_key = current_rep_key

        # DANN module
        if has_discriminator:
            z = dann_utils.flip_grad_identity(x)
            z = nn.Dense(z, 2, name='disc_l1', bias=True)
            z = nn.relu(z)
            z = nn.Dense(z, 2, name='disc_l2', bias=True)

        current_rep_key = 'head'
        if input_layer_key == current_rep_key:
            x = inputs
            layer_activations[current_rep_key] = x
            rep_key = current_rep_key

            logging.warning('Input was never used')
        elif input_is_set:
            x = nn.Dense(x,
                         num_outputs,
                         dtype=dtype,
                         bias_init=head_bias_init,
                         name='head')

        # Make sure that the output is float32, even if our previous computations
        # are in float16, or other types.
        x = jnp.asarray(x, jnp.float32)

        outputs = x
        if return_activations:
            outputs = (x, layer_activations, rep_key)
            if discriminator and has_discriminator:
                outputs = outputs + (z, )
        else:
            del layer_activations
            if discriminator and has_discriminator:
                outputs = (x, z)
        if discriminator and (not has_discriminator):
            raise ValueError(
                'Inconsistent values passed for discriminator and has_discriminator'
            )
        return outputs
Example #23
 def apply(self, x, inner_channels=8):
   x = nn.Conv(x, features=1, kernel_size=(3, 3), bias=False,
                    strides=(2, 2),
                    padding='VALID')
   x = nn.Conv(x, features=inner_channels, kernel_size=(3, 3), bias=False,
                    padding='SAME')
   x = nn.Conv(x, features=inner_channels, kernel_size=(3, 3), bias=False,
                    padding='SAME')
   x = nn.Conv(x, features=inner_channels, kernel_size=(3, 3), bias=False,
                     padding='VALID', strides=(2, 2))
   x = nn.Conv(x, features=inner_channels, kernel_size=(3, 3), bias=False,
                    padding='SAME')
   x = nn.Conv(x, features=inner_channels, kernel_size=(3, 3), bias=False,
                    padding='SAME')
   x = nn.Conv(x, features=inner_channels, kernel_size=(3, 3), bias=False,
                    padding='SAME')
   x = nn.Conv(x, features=inner_channels, kernel_size=(3, 3), bias=False,
                     input_dilation=(2, 2), padding=[(2, 2), (2, 2)])
   x = nn.Conv(x, features=inner_channels, kernel_size=(3, 3), bias=False,
                    padding='SAME')
   x = nn.Conv(x, features=inner_channels, kernel_size=(3, 3), bias=False,
                    padding='SAME')
   x = nn.Conv(x, features=1, kernel_size=(3, 3), bias=False,
                     input_dilation=(2, 2), padding=[(2, 2), (2, 2)])
   return x