Beispiel #1
0
 def apply(self, x, reduction=16):
     """Channel-wise gating (squeeze-and-excitation style).

     Averages `x` over axes 1 and 2 and passes that summary through a
     bias-free Dense bottleneck (`num_channels // reduction` features, ReLU).
     A bias-free Dense back to `num_channels` features with a sigmoid follows.
     `x` is then scaled per channel by the resulting gates.

     Args:
       x: input array. Channels are read from the last axis; the gate is
         broadcast with `[:, None, None, :]`, so this presumably expects a
         rank-4 channels-last tensor (TODO confirm with callers).
       reduction: divisor for the hidden (bottleneck) width.

     Returns:
       `x` multiplied by the per-channel gates.
     """
     channels = x.shape[-1]
     squeezed = x.mean(axis=(1, 2))
     hidden = nn.relu(
         nn.Dense(squeezed, features=channels // reduction, bias=False))
     gates = nn.sigmoid(nn.Dense(hidden, features=channels, bias=False))
     return x * gates[:, None, None, :]
Beispiel #2
0
    def get_logits(code_embeddings, length):
      """Encodes a variable-length sequence and decodes fixed-length logits.

      Relies on names from the enclosing scope: `encoder`, `encoder_cells`,
      `emb_dim`, `decoder`, `output_length`, `output_token_vocabulary_size`
      and `config`.

      Args:
        code_embeddings: embedded input sequence, scanned over axis 0 (the
          comment below gives its shape as (length, emb_dim)).
        length: number of valid leading steps. Steps whose index is
          >= `length` leave the encoder carry unchanged.

      Returns:
        Output of the 'output_layer' Dense applied to the decoder outputs:
        `output_length` steps with `output_token_vocabulary_size` features.
      """
      # code_embeddings.shape: length, emb_dim

      initial_carry_e = encoder.initialize_carry(
          jax.random.PRNGKey(0), encoder_cells, (), emb_dim)

      def apply_encoder(carry, inp):
        # carry is (encoder_carry, step_index).
        i = carry[1]
        c1, o1 = encoder(carry[0], inp)
        # Masked update: once i >= length, keep the old carry and step index
        # and pass the input through, so padding cannot change the final
        # encoder state. `jax.tree_util.tree_map` maps over several trees;
        # it replaces the removed `jax.tree_multimap`.
        return jax.tree_util.tree_map(
            lambda x_new, x_old: jnp.where(i < length, x_new, x_old),
            ((c1, i+1), o1),
            (carry, inp)
        )

      # Run the encoder over the sequence; only the final carry is used.
      (encoder_state, unused_i), unused_outputs = (
          jax.lax.scan(
              apply_encoder,
              (initial_carry_e, 0),
              code_embeddings
          )
      )

      # Drive the decoder with zero inputs for `output_length` steps,
      # starting from the final encoder state.
      decoder_inputs = jnp.zeros((output_length, emb_dim))
      unused_carry, decoder_outputs = jax.lax.scan(
          decoder, encoder_state, decoder_inputs)

      logits = nn.Dense(
          decoder_outputs,
          output_token_vocabulary_size,
          kernel_init=nn.initializers.normal(
              stddev=config.initialization.maxval),
          bias_init=nn.initializers.zeros,
          name='output_layer')
      return logits
Beispiel #3
0
  def apply(self,
            x,
            act,
            normalize,
            temb=None,
            out_ch=None,
            conv_shortcut=False,
            dropout=0.1,
            train=True,
            skip_rescale=False,
            init_scale=0.):
    """Residual block: two (normalize -> act -> conv3x3) stages plus a shortcut.

    Args:
      x: input; unpacked as four axes, with channels on the last one.
      act: activation function.
      normalize: normalization callable, called with a `num_groups` argument.
      temb: optional embedding. When given, Dense(act(temb)) is added to the
        first conv output as a per-channel bias.
      out_ch: output channels; falls back to the input channel count.
      conv_shortcut: when channels change, use conv3x3 (else NIN) on `x`.
      dropout: dropout rate between the two convs (disabled when not `train`).
      train: training-mode flag.
      skip_rescale: divide the residual sum by sqrt(2).
      init_scale: init scale for the second conv.

    Returns:
      `x + h`, or `(x + h) / sqrt(2)` when `skip_rescale` is set.
    """
    _, _, _, in_ch = x.shape
    out_ch = out_ch or in_ch

    def _norm_act(t):
      # Group count: a quarter of the channels, capped at 32.
      return act(normalize(t, num_groups=min(t.shape[-1] // 4, 32)))

    h = conv3x3(_norm_act(x), out_ch)
    if temb is not None:
      # Per-channel bias broadcast over the two spatial axes.
      time_bias = nn.Dense(act(temb), out_ch, kernel_init=default_init())
      h += time_bias[:, None, None, :]

    h = nn.dropout(_norm_act(h), dropout, deterministic=not train)
    h = conv3x3(h, out_ch, init_scale=init_scale)
    if in_ch != out_ch:
      x = conv3x3(x, out_ch) if conv_shortcut else NIN(x, out_ch)

    summed = x + h
    return summed / np.sqrt(2.) if skip_rescale else summed
Beispiel #4
0
    def apply(self,
              x,
              blocks_per_group,
              channel_multiplier,
              num_outputs,
              dropout_rate=0.0,
              train=True):
        """Wide ResNet classifier.

        Layers: a 16-filter 3x3 stem conv, then three WideResnetGroup stages
        with 16x/32x/64x `channel_multiplier` filters (the last two strided
        (2, 2)). After that come BatchNorm, ReLU, an 8x8 average pool and a
        Dense layer with `num_outputs` units.
        """
        x = nn.Conv(x, 16, (3, 3), padding='SAME', name='init_conv')
        # (base width, extra positional args). The first group passes no
        # strides, so it keeps WideResnetGroup's default.
        group_specs = ((16, ()), (32, ((2, 2),)), (64, ((2, 2),)))
        for base_width, stride_args in group_specs:
            x = WideResnetGroup(x,
                                blocks_per_group,
                                base_width * channel_multiplier,
                                *stride_args,
                                dropout_rate=dropout_rate,
                                train=train)
        x = nn.BatchNorm(x,
                         use_running_average=not train,
                         momentum=0.9,
                         epsilon=1e-5)
        pooled = nn.avg_pool(jax.nn.relu(x), (8, 8))
        flat = pooled.reshape((pooled.shape[0], -1))
        return nn.Dense(flat, num_outputs)
Beispiel #5
0
    def apply(self,
              x,
              blocks_per_group,
              channel_multiplier,
              num_outputs,
              train=True):
        """Wide ResNet built from shake-shake residual groups.

        Layers: a 16-filter 3x3 stem conv, then three
        WideResnetShakeShakeGroup stages (16x/32x/64x `channel_multiplier`
        filters, the last two strided (2, 2)). After that come ReLU, an 8x8
        average pool and a Dense layer with `num_outputs` units.
        """
        x = nn.Conv(x, 16, (3, 3), padding='SAME', name='init_conv')
        # The first group passes no strides and keeps the group's default.
        for base_width, stride_args in ((16, ()), (32, ((2, 2),)),
                                        (64, ((2, 2),))):
            x = WideResnetShakeShakeGroup(x,
                                          blocks_per_group,
                                          base_width * channel_multiplier,
                                          *stride_args,
                                          train=train)
        pooled = nn.avg_pool(jax.nn.relu(x), (8, 8))
        return nn.Dense(pooled.reshape((pooled.shape[0], -1)), num_outputs)
Beispiel #6
0
 def apply(self, x):
     """Conv(32, 3x3) -> ReLU -> 2x2/2 average pool -> flatten -> Dense(128).

     Returns the 128-feature output of the "fc" layer. No activation is
     applied after it.
     """
     activated = nn.relu(
         nn.Conv(x, features=32, kernel_size=(3, 3), name="conv"))
     pooled = nn.avg_pool(activated, window_shape=(2, 2), strides=(2, 2))
     flat = pooled.reshape((pooled.shape[0], -1))  # flatten.
     return nn.Dense(flat, 128, name="fc")
Beispiel #7
0
 def apply(self,
           x,
           act,
           normalize,
           temb=None,
           out_ch=None,
           conv_shortcut=False,
           dropout=0.5,
           train=True):
     """DDPM residual block.

     Each of the two stages applies act(normalize(.)) followed by
     ddpm_conv3x3. The second conv uses init_scale=0. An optional
     time-embedding bias is added after the first conv. The shortcut is
     projected (ddpm_conv3x3 or NIN) only when the channel count changes.

     Returns:
       The residual sum `shortcut + h`.
     """
     _, _, _, in_ch = x.shape
     out_ch = out_ch or in_ch

     h = ddpm_conv3x3(act(normalize(x)), out_ch)
     if temb is not None:
         # Add bias to each feature map conditioned on the time embedding.
         time_bias = nn.Dense(act(temb), out_ch, kernel_init=default_init())
         h += time_bias[:, None, None, :]
     h = nn.dropout(act(normalize(h)), dropout, deterministic=not train)
     h = ddpm_conv3x3(h, out_ch, init_scale=0.)

     if in_ch == out_ch:
         return x + h
     shortcut = ddpm_conv3x3(x, out_ch) if conv_shortcut else NIN(x, out_ch)
     return shortcut + h
Beispiel #8
0
 def apply(self,
           inputs,
           mlp_dim,
           out_dim=None,
           dropout_rate=0.1,
           deterministic=False,
           kernel_init=nn.initializers.xavier_uniform(),
           bias_init=nn.initializers.normal(stddev=1e-6)):
   """Transformer MLP block: Dense -> GELU -> dropout -> Dense -> dropout.

   Args:
     inputs: input array. The feature count is read from the last axis.
     mlp_dim: hidden width.
     out_dim: output width; defaults to the input's last-axis size.
     dropout_rate: rate for both dropout layers.
     deterministic: disables dropout when true.
     kernel_init, bias_init: initializers shared by both Dense layers.

   Returns:
     The output of the second Dense layer after dropout.
   """
   if out_dim is None:
     resolved_out_dim = inputs.shape[-1]
   else:
     resolved_out_dim = out_dim
   dense_kwargs = dict(kernel_init=kernel_init, bias_init=bias_init)

   hidden = nn.gelu(nn.Dense(inputs, mlp_dim, **dense_kwargs))
   hidden = nn.dropout(hidden, rate=dropout_rate, deterministic=deterministic)
   projected = nn.Dense(hidden, resolved_out_dim, **dense_kwargs)
   return nn.dropout(projected, rate=dropout_rate, deterministic=deterministic)
Beispiel #9
0
  def apply(self,
            x,
            act,
            normalize,
            up=False,
            down=False,
            temb=None,
            out_ch=None,
            dropout=0.1,
            fir=False,
            fir_kernel=(1, 3, 3, 1),
            train=True,
            skip_rescale=True,
            init_scale=0.):
    """BigGAN-style residual block with optional 2x up- or down-sampling.

    Args:
      x: input; unpacked as four axes, with channels on the last one.
      act: activation function.
      normalize: normalization callable, called with `num_groups`.
      up: upsample both the residual path and the shortcut by 2.
      down: downsample both paths by 2 (ignored when `up` is set).
      temb: optional embedding, added as a per-channel bias after the first
        conv.
      out_ch: output channels; falls back to the input channel count.
      dropout: dropout rate (disabled when not `train`).
      fir: use the FIR-kernel resamplers instead of the naive ones.
      fir_kernel: FIR taps passed to `up_or_down_sampling` when `fir` is set.
        The default is an immutable tuple, so one shared default object
        cannot be mutated across calls (it used to be a list).
      train: training-mode flag.
      skip_rescale: divide the residual sum by sqrt(2).
      init_scale: init scale for the second conv.

    Returns:
      `x + h`, or `(x + h) / sqrt(2)` when `skip_rescale` is set.
    """
    B, H, W, C = x.shape
    out_ch = out_ch if out_ch else C
    h = act(normalize(x, num_groups=min(x.shape[-1] // 4, 32)))

    # Resample the residual path and the shortcut identically so their
    # spatial shapes still match.
    if up:
      if fir:
        h = up_or_down_sampling.upsample_2d(h, fir_kernel, factor=2)
        x = up_or_down_sampling.upsample_2d(x, fir_kernel, factor=2)
      else:
        h = up_or_down_sampling.naive_upsample_2d(h, factor=2)
        x = up_or_down_sampling.naive_upsample_2d(x, factor=2)
    elif down:
      if fir:
        h = up_or_down_sampling.downsample_2d(h, fir_kernel, factor=2)
        x = up_or_down_sampling.downsample_2d(x, fir_kernel, factor=2)
      else:
        h = up_or_down_sampling.naive_downsample_2d(h, factor=2)
        x = up_or_down_sampling.naive_downsample_2d(x, factor=2)

    h = conv3x3(h, out_ch)
    # Add bias to each feature map conditioned on the time embedding
    if temb is not None:
      h += nn.Dense(
          act(temb), out_ch, kernel_init=default_init())[:, None, None, :]

    h = act(normalize(h, num_groups=min(h.shape[-1] // 4, 32)))
    h = nn.dropout(h, dropout, deterministic=not train)
    h = conv3x3(h, out_ch, init_scale=init_scale)
    # Project the shortcut whenever channels changed or it was resampled.
    if C != out_ch or up or down:
      x = conv1x1(x, out_ch)

    if not skip_rescale:
      return x + h
    else:
      return (x + h) / np.sqrt(2.)
 def apply(self, x, num_outputs, train=True):
     """ResNet/ResNeXt-style classifier made of BottleneckBlock stages.

     Reads class attributes NUM_FILTERS, BLOCK_SIZES, GROUPS and
     WIDTH_PER_GROUP. Stage `i` uses NUM_FILTERS * 2**i filters, and the
     first block of every stage after the first is strided (2, 2). The head
     is a spatial mean over axes 1 and 2 followed by a Dense layer named 'clf'.
     """
     x = nn.Conv(x,
                 self.NUM_FILTERS, (7, 7), (2, 2),
                 bias=False,
                 name='init_conv')
     x = nn.BatchNorm(x,
                      use_running_average=not train,
                      momentum=0.9,
                      epsilon=1e-5,
                      name='init_bn')
     # NOTE(review): no activation between init_bn and the max-pool; confirm
     # this is intentional.
     x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
     for stage, num_blocks in enumerate(self.BLOCK_SIZES):
         stage_filters = self.NUM_FILTERS * 2**stage
         for block in range(num_blocks):
             downsample = stage > 0 and block == 0
             x = BottleneckBlock(x,
                                 stage_filters,
                                 strides=(2, 2) if downsample else (1, 1),
                                 groups=self.GROUPS,
                                 base_width=self.WIDTH_PER_GROUP,
                                 train=train)
     pooled = jnp.mean(x, axis=(1, 2))
     return nn.Dense(pooled, num_outputs, name='clf')
Beispiel #11
0
    def apply(self,
              x,
              *,
              self_attention_module,
              dim_intermediate,
              is_training,
              dropout_rate=0.1,
              use_pre_layernorm=False,
              layernorm_epsilon=1e-6,
              with_aux_outputs=True):
        """Compute self-attention with a feed-forward network on top.

        Args:
          x: Input representations.
          self_attention_module: Self-Attention layer. When
            `with_aux_outputs` is true, it must return an `(output, aux)` pair.
          dim_intermediate: Size of the intermediate layer of the feed forward.
          is_training: Whether to enable dropout.
          dropout_rate: Dropout probability.
          use_pre_layernorm: Use pre layer norm from
            https://arxiv.org/abs/2002.04745. Otherwise layer norm is applied
            after each residual addition (post-LN).
          layernorm_epsilon: Epsilon parameter for all the layer norms.
          with_aux_outputs: Whether the self_attention_module has an aux output.

        Returns:
          New representations in a jnp.array of same shape as `x`. When
          `with_aux_outputs` is true, returns the tuple
          `(representations, aux)` instead.
        """
        dim_hidden = x.shape[-1]
        use_pre_ln = use_pre_layernorm
        use_post_ln = not use_pre_ln

        def apply_ln_if(pred, x, name):
            # Exactly one of the pre-/post-LN variants is active, so each
            # sub-layer gets a single LayerNorm.
            if pred:
                return nn.LayerNorm(x, epsilon=layernorm_epsilon, name=name)
            else:
                return x

        # attention
        x = apply_ln_if(use_pre_ln, x, "ln_pre_att")
        x_att = self_attention_module(x)
        if with_aux_outputs:
            x_att, output_aux = x_att

        # dropout norm and add
        x_att = nn.dropout(x_att, dropout_rate, deterministic=not is_training)
        x = x + x_att
        x = apply_ln_if(use_post_ln, x, "ln_post_att")

        # feed forward
        x_ffn = apply_ln_if(use_pre_ln, x, "ln_pre_ffn")
        x_ffn = nn.Dense(x_ffn, dim_intermediate, name="ff_1")
        x_ffn = jax.nn.relu(x_ffn)
        x_ffn = nn.Dense(x_ffn, dim_hidden, name="ff_2")

        # dropout norm and add
        x_ffn = nn.dropout(x_ffn, dropout_rate, deterministic=not is_training)
        x = x + x_ffn
        x = apply_ln_if(use_post_ln, x, "ln_post_ffn")

        if with_aux_outputs:
            output = x, output_aux
        else:
            output = x
        return output
Beispiel #12
0
    def apply(self, x, labels, y=None, config=None, train=True):
        """Score-network forward pass: a U-Net with optional progressive paths.

        Args:
          x: input images. The channel count is read from `x.shape[-1]`, and
            attention placement compares `h.shape[1]` against
            `config.model.attn_resolutions`. Presumably channels-last
            (TODO(review): confirm layout with callers).
          labels: indexes `sigmas` for the 'gaussian' embedding (and for
            `scale_by_sigma`). Used directly as timesteps for the
            'positional' embedding.
          y: optional class labels. When given, a class embedding is added to
            the noise-level embedding.
          config: config object. Many `config.model.*` fields plus
            `config.data.centered` and `config.data.num_classes` are read.
            Although it defaults to None, it is dereferenced unconditionally.
          train: forwarded to the residual blocks.

        Returns:
          `h` with `x.shape[-1]` channels. It is divided by the per-example
          sigma when `config.model.scale_by_sigma` is set.

        Raises:
          ValueError: unknown embedding type, resblock type or progressive
            mode.
        """
        # config parsing
        nf = config.model.nf
        act = get_act(config)
        normalize = get_normalization(config)
        sigmas = utils.get_sigmas(config)

        # NOTE(review): duplicate of the `nf` assignment above; redundant.
        nf = config.model.nf
        ch_mult = config.model.ch_mult
        num_res_blocks = config.model.num_res_blocks
        attn_resolutions = config.model.attn_resolutions
        dropout = config.model.dropout
        resamp_with_conv = config.model.resamp_with_conv
        num_resolutions = len(ch_mult)

        conditional = config.model.conditional  # noise-conditional
        fir = config.model.fir
        fir_kernel = config.model.fir_kernel
        skip_rescale = config.model.skip_rescale
        resblock_type = config.model.resblock_type
        progressive = config.model.progressive
        progressive_input = config.model.progressive_input
        init_scale = config.model.init_scale
        # NOTE(review): these checks lowercase the value, but the comparisons
        # below are case-sensitive. A mixed-case value passes here and is then
        # rejected by a later ValueError branch. Asserts are also stripped
        # under `python -O`.
        assert progressive.lower() in ['none', 'output_skip', 'residual']
        assert config.model.embedding_type.lower() in [
            'gaussian', 'positional'
        ]
        combine_method = config.model.progressive_combine
        combiner = Combine.partial(method=combine_method)

        # timestep/noise_level embedding
        if config.model.embedding_type == 'gaussian':
            # Gaussian Fourier features embeddings.
            used_sigmas = sigmas[labels]
            temb = layersv3.GaussianFourierProjection(
                jnp.log(used_sigmas),
                embedding_size=nf,
                scale=config.model.fourier_scale)
        elif config.model.embedding_type == 'positional':
            # Sinusoidal positional embeddings.
            timesteps = labels
            temb = layers.get_timestep_embedding(timesteps, nf)
        else:
            raise ValueError(
                f'embedding type {config.model.embedding_type} unknown.')

        # Two Dense layers (width nf * 4) on top of the raw embedding.
        temb = nn.Dense(temb, nf * 4, kernel_init=default_initializer())
        temb = nn.Dense(act(temb), nf * 4, kernel_init=default_initializer())

        if y is not None:  # class-conditional image generation
            class_embed = nn.Embed(y, config.data.num_classes, nf * 4)
            class_embed = nn.Dense(class_embed,
                                   nf * 4,
                                   kernel_init=default_initializer())
            class_embed = nn.Dense(act(class_embed),
                                   nf * 4,
                                   kernel_init=default_initializer())

            temb += class_embed

        AttnBlock = layersv3.AttnBlockv3.partial(normalize=normalize,
                                                 init_scale=init_scale,
                                                 skip_rescale=skip_rescale)

        Upsample = layersv3.Upsample.partial(with_conv=resamp_with_conv,
                                             fir=fir,
                                             fir_kernel=fir_kernel)
        # pyramid_upsample is bound only for the two modes that use it. It is
        # referenced only in the matching branches of the upsampling loop.
        if progressive == 'output_skip':
            pyramid_upsample = layersv3.Upsample.partial(fir=fir,
                                                         fir_kernel=fir_kernel,
                                                         with_conv=False)
        elif progressive == 'residual':
            pyramid_upsample = layersv3.Upsample.partial(fir=fir,
                                                         fir_kernel=fir_kernel,
                                                         with_conv=True)

        Downsample = layersv3.Downsample.partial(with_conv=resamp_with_conv,
                                                 fir=fir,
                                                 fir_kernel=fir_kernel)

        # Same pattern for the input pyramid. NOTE(review): progressive_input
        # is not validated; an unrecognized value silently behaves like 'none'.
        if progressive_input == 'input_skip':
            pyramid_downsample = layersv3.Downsample.partial(
                fir=fir, fir_kernel=fir_kernel, with_conv=False)
        elif progressive_input == 'residual':
            pyramid_downsample = layersv3.Downsample.partial(
                fir=fir, fir_kernel=fir_kernel, with_conv=True)

        # The noise embedding reaches the residual blocks only when the model
        # is noise-conditional.
        if resblock_type == 'ddpm':
            ResnetBlock = ResnetBlockDDPM.partial(
                act=act,
                normalize=normalize,
                dropout=dropout,
                temb=temb if conditional else None,
                train=train,
                init_scale=init_scale,
                skip_rescale=skip_rescale)

        elif resblock_type == 'biggan':
            ResnetBlock = ResnetBlockBigGAN.partial(
                act=act,
                normalize=normalize,
                temb=temb if conditional else None,
                train=train,
                dropout=dropout,
                fir=fir,
                fir_kernel=fir_kernel,
                init_scale=init_scale,
                skip_rescale=skip_rescale)

        else:
            raise ValueError(f'resblock_type {resblock_type} unrecognized.')

        if not config.data.centered:
            # If input data is in [0, 1], map it to [-1, 1].
            x = 2 * x - 1.

        # Downsampling block

        input_pyramid = None
        if progressive_input != 'none':
            input_pyramid = x

        # `hs` collects every skip activation. The upsampling path pops them
        # in reverse order.
        hs = [conv3x3(x, nf)]
        for i_level in range(num_resolutions):
            # Residual blocks for this resolution
            for i_block in range(num_res_blocks):
                h = ResnetBlock(hs[-1], out_ch=nf * ch_mult[i_level])
                if h.shape[1] in attn_resolutions:
                    h = AttnBlock(h)
                hs.append(h)

            # Reduce resolution after every level except the last.
            if i_level != num_resolutions - 1:
                if resblock_type == 'ddpm':
                    h = Downsample(hs[-1])
                else:
                    h = ResnetBlock(hs[-1], down=True)

                if progressive_input == 'input_skip':
                    input_pyramid = pyramid_downsample(input_pyramid)
                    h = combiner(input_pyramid, h)

                elif progressive_input == 'residual':
                    input_pyramid = pyramid_downsample(input_pyramid,
                                                       out_ch=h.shape[-1])
                    if skip_rescale:
                        input_pyramid = (input_pyramid + h) / np.sqrt(2.)
                    else:
                        input_pyramid = input_pyramid + h
                    h = input_pyramid

                hs.append(h)

        # Bottleneck: ResnetBlock -> AttnBlock -> ResnetBlock.
        h = hs[-1]
        h = ResnetBlock(h)
        h = AttnBlock(h)
        h = ResnetBlock(h)

        pyramid = None

        # Upsampling block
        for i_level in reversed(range(num_resolutions)):
            # num_res_blocks + 1 pops per level. The extra pop consumes the
            # skip pushed by the initial conv or by a downsampling step, so
            # `hs` ends up empty (asserted below).
            for i_block in range(num_res_blocks + 1):
                h = ResnetBlock(jnp.concatenate([h, hs.pop()], axis=-1),
                                out_ch=nf * ch_mult[i_level])

            if h.shape[1] in attn_resolutions:
                h = AttnBlock(h)

            if progressive != 'none':
                if i_level == num_resolutions - 1:
                    # First upsampling level: start the pyramid from h.
                    if progressive == 'output_skip':
                        pyramid = conv3x3(act(
                            normalize(h, num_groups=min(h.shape[-1] // 4,
                                                        32))),
                                          x.shape[-1],
                                          bias=True,
                                          init_scale=init_scale)
                    elif progressive == 'residual':
                        pyramid = conv3x3(act(
                            normalize(h, num_groups=min(h.shape[-1] // 4,
                                                        32))),
                                          h.shape[-1],
                                          bias=True)
                    else:
                        raise ValueError(f'{progressive} is not a valid name.')
                else:
                    if progressive == 'output_skip':
                        pyramid = pyramid_upsample(pyramid)
                        pyramid = pyramid + conv3x3(act(
                            normalize(h, num_groups=min(h.shape[-1] // 4,
                                                        32))),
                                                    x.shape[-1],
                                                    bias=True,
                                                    init_scale=init_scale)
                    elif progressive == 'residual':
                        pyramid = pyramid_upsample(pyramid, out_ch=h.shape[-1])
                        if skip_rescale:
                            pyramid = (pyramid + h) / np.sqrt(2.)
                        else:
                            pyramid = pyramid + h
                        h = pyramid
                    else:
                        raise ValueError(f'{progressive} is not a valid name')

            # Increase resolution after every level except the top one.
            if i_level != 0:
                if resblock_type == 'ddpm':
                    h = Upsample(h)
                else:
                    h = ResnetBlock(h, up=True)

        assert not hs

        if progressive == 'output_skip':
            # Each pyramid update adds a conv3x3 with x.shape[-1] outputs.
            h = pyramid
        else:
            h = act(normalize(h, num_groups=min(h.shape[-1] // 4, 32)))
            h = conv3x3(h, x.shape[-1], init_scale=init_scale)

        if config.model.scale_by_sigma:
            # Reshape to (batch, 1, ..., 1) so sigmas broadcast over the
            # non-batch axes of x.
            used_sigmas = sigmas[labels].reshape(
                (x.shape[0], *([1] * len(x.shape[1:]))))
            h = h / used_sigmas

        return h
Beispiel #13
0
    def apply(self,
              x,
              *,
              train,
              num_classes,
              block_class=BottleneckResNetImageNetBlock,
              stage_sizes,
              width_factor=1,
              normalization='bn',
              activation_f=None,
              std_penalty_mult=0,
              use_residual=1,
              bias_scale=0.0,
              weight_norm='none',
              compensate_padding=True,
              softplus_scale=None,
              no_head=False,
              zero_inits=True):
        """Construct ResNet V1 with `num_classes` outputs.

        The stem is a 7x7/2 conv and a norm, followed by an avg-pool when
        `compensate_padding` is set and a max-pool otherwise. One ResNetStage
        follows per entry of `stage_sizes`; stage i has 64 * width_factor *
        2**i filters and every stage after the first starts strided.

        Returns:
          `(x, 0, {})`. Unless `no_head`, `x` holds logits from a Dense
          'head' applied after a spatial mean.

        Raises:
          NotImplementedError: if `std_penalty_mult > 0`.
        """
        self._stage_sizes = stage_sizes
        if std_penalty_mult > 0:
            raise NotImplementedError(
                'std_penalty_mult not supported for ResNetImageNet')

        base_width = 64 * width_factor
        act_fn = get_activation_f(activation_f, train, softplus_scale,
                                  bias_scale)
        norm = get_norm(act_fn, normalization, train)
        conv = get_conv(act_fn, bias_scale, weight_norm, compensate_padding,
                        normalization)

        # Root block.
        x = norm(conv(x,
                      base_width,
                      kernel_size=(7, 7),
                      strides=(2, 2),
                      name='init_conv'),
                 name='init_bn')
        if not compensate_padding:
            x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
        else:
            # NOTE: this leads to lower performance.
            x = nn.avg_pool(x, (2, 2), strides=(2, 2), padding='SAME')

        # Stages.
        for stage_idx, num_blocks in enumerate(stage_sizes):
            x = ResNetStage(
                x,
                num_blocks,
                filters=base_width * 2**stage_idx,
                block_class=block_class,
                first_block_strides=(2, 2) if stage_idx else (1, 1),
                train=train,
                name=f'stage{stage_idx + 1}',
                conv=conv,
                norm=norm,
                activation_f=act_fn,
                use_residual=use_residual,
                zero_inits=zero_inits,
            )

        if no_head:
            return x, 0, {}

        # Head.
        if zero_inits:
            head_init = nn.initializers.zeros
        else:
            head_init = nn.initializers.lecun_normal()
        pooled = jnp.mean(x, axis=(1, 2))
        return nn.Dense(pooled, num_classes, kernel_init=head_init,
                        name='head'), 0, {}
Beispiel #14
0
    def apply(self,
              x,
              blocks_per_group,
              channel_multiplier,
              num_outputs,
              dropout_rate=0.0,
              normalization='bn',
              activation_f=None,
              std_penalty_mult=0,
              use_residual=1,
              train=True,
              bias_scale=0.0,
              weight_norm='none',
              no_head=False,
              compensate_padding=True,
              softplus_scale=None):
        """Wide ResNet that also accumulates an auxiliary penalty.

        Builds a 3x3 stem conv and three WideResnetGroup stages
        (16x/32x/64x `channel_multiplier` filters, the last two strided
        (2, 2)), followed by a final norm. Unless `no_head`, an activation,
        an 8x8 average pool and a Dense classifier follow.

        Returns:
          `(x, penalty, {})`. `penalty` is the sum of the per-group
          penalties, plus `std_penalty` of the final normalized features when
          `std_penalty_mult > 0`.
        """
        act_fn = get_activation_f(activation_f, train, softplus_scale,
                                  bias_scale)
        norm = get_norm(act_fn, normalization, train)
        conv = get_conv(act_fn, bias_scale, weight_norm, compensate_padding,
                        normalization)
        x = conv(x,
                 16 * channel_multiplier, (3, 3),
                 padding='SAME',
                 name='init_conv')

        group_kwargs = dict(dropout_rate=dropout_rate,
                            normalization=normalization,
                            activation_f=act_fn,
                            std_penalty_mult=std_penalty_mult,
                            use_residual=use_residual,
                            train=train,
                            bias_scale=bias_scale,
                            weight_norm=weight_norm)
        penalty = 0
        # The first group passes no strides and keeps the group's default.
        for base_width, stride_args in ((16, ()), (32, ((2, 2),)),
                                        (64, ((2, 2),))):
            x, group_penalty = WideResnetGroup(x,
                                               blocks_per_group,
                                               base_width * channel_multiplier,
                                               *stride_args,
                                               **group_kwargs)
            penalty += group_penalty

        x = norm(x, name='final_norm')
        if std_penalty_mult > 0:
            penalty += std_penalty(x)
        if no_head:
            return x, penalty, {}
        x = act_fn(x, features=x.shape[-1])
        pooled = nn.avg_pool(x, (8, 8))
        logits = nn.Dense(pooled.reshape((pooled.shape[0], -1)), num_outputs)
        return logits, penalty, {}
Beispiel #15
0
    def apply(self,
              x,
              depth,
              num_outputs,
              dropout_rate=0.0,
              normalization='bn',
              activation_f=None,
              std_penalty_mult=0,
              use_residual=1,
              train=True,
              bias_scale=0.0,
              weight_norm='none',
              filters=16,
              no_head=False,
              report_metrics=False,
              benchmark='cifar10',
              compensate_padding=True,
              softplus_scale=None):
        """ResNet v1 (depth = 6n + 2) with optional activation summaries.

        Args:
          x: input batch. Summaries of intermediate maps assume 4D,
            channels-last, square inputs.
          depth: network depth; must satisfy depth = 6n + 2.
          num_outputs: width of the classification head.
          dropout_rate: accepted for interface compatibility; not used here.
          normalization, activation_f, bias_scale, weight_norm,
          compensate_padding, softplus_scale: forwarded to the
            `get_activation_f` / `get_norm` / `get_conv` factories.
          std_penalty_mult: when > 0, `std_penalty` of every normalized
            activation is added to the returned penalty.
          use_residual: 1 -> x + y; 2 -> (x + y) / sqrt(2); otherwise the
            shortcut is dropped (x = y).
          train: forwarded to the factories.
          filters: base filter count, doubled after each of the 3 stacks.
          no_head: return the final feature map without pooling/Dense.
          report_metrics: collect mean-abs / mean-std summaries.
          benchmark: 'cifar10'/'cifar100' (3x3 stem) or 'imagenet' (7x7
            stride-2 stem + max-pool, zero-initialized head).

        Returns:
          `(output, penalty, summaries)`.

        Raises:
          ValueError: if depth is not 6n + 2 or the benchmark is unknown.
        """

        bn_index = iter(range(1000))
        conv_index = iter(range(1000))
        summaries = {}
        summary_ind = [0]

        def add_summary(name, val):
            """Summarize statistics of tensor.

            Records the mean |per-channel mean| and the mean per-channel std,
            reducing over every axis except the last (channel) one. Entries
            are keyed by `summary_ind[0] // 2`, so consecutive calls pair up.
            """
            if report_metrics:
                # Feature maps are 4D (square, channels last). The pooled
                # 'postpool' features are 2D (N, C), which a 4D-only check
                # would reject and crash on.
                assert val.ndim in (2, 4), (
                    'Assuming 2D or 4D inputs with channels last, got %s' %
                    str(val.shape))
                if val.ndim == 4:
                    assert val.shape[1] == val.shape[
                        2], 'Assuming 4D inputs with channels last'
                reduce_axes = tuple(range(val.ndim - 1))
                summaries['%s_%d_mean_abs' %
                          (name, summary_ind[0] // 2)] = jnp.mean(
                              jnp.abs(jnp.mean(val, axis=reduce_axes)))
                summaries['%s_%d_mean_std' %
                          (name, summary_ind[0] // 2)] = jnp.mean(
                              jnp.std(val, axis=reduce_axes))
                summary_ind[0] += 1

        penalty = 0

        activation_f = get_activation_f(activation_f, train, softplus_scale,
                                        bias_scale)
        norm = get_norm(activation_f, normalization, train)

        conv = get_conv(activation_f, bias_scale, weight_norm,
                        compensate_padding, normalization)

        def resnet_layer(
            inputs,
            penalty,
            filters,
            kernel_size=3,
            strides=1,
            activation=None,
        ):
            """2D Convolution-Batch Normalization-Activation stack builder.

            Returns the output and the updated penalty.
            """
            x = inputs
            x = conv(x,
                     filters, (kernel_size, kernel_size),
                     strides=(strides, strides),
                     padding='SAME',
                     name='conv%d' % next(conv_index))
            x = norm(x, name='norm%d' % next(bn_index))
            add_summary('postnorm', x)
            if std_penalty_mult > 0:
                penalty += std_penalty(x)

            if activation:
                x = activation_f(x, features=x.shape[-1])
            add_summary('postact', x)
            return x, penalty

        # Main network code.
        num_res_blocks = (depth - 2) // 6

        if (depth - 2) % 6 != 0:
            raise ValueError('depth must be 6n+2 (e.g. 20, 32, 44).')

        inputs = x
        # Two calls so the input summary occupies one full index pair.
        add_summary('input', x)
        add_summary('inputb', x)
        if benchmark in ['cifar10', 'cifar100']:
            x, penalty = resnet_layer(inputs,
                                      penalty,
                                      filters=filters,
                                      activation=True)
            head_kernel_init = nn.initializers.lecun_normal()
        elif benchmark in ['imagenet']:
            head_kernel_init = nn.initializers.zeros
            x, penalty = resnet_layer(inputs,
                                      penalty,
                                      filters=filters,
                                      activation=False,
                                      kernel_size=7,
                                      strides=2)
            # TODO(basv): evaluate max pool v/s avg_pool in an experiment?
            # if compensate_padding:
            #   x = nn.avg_pool(x, (2, 2), strides=(2, 2), padding="VALID")
            # else:
            x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
        else:
            raise ValueError('Model def not prepared for benchmark %s' %
                             benchmark)

        for stack in range(3):
            for res_block in range(num_res_blocks):
                strides = 1
                if stack > 0 and res_block == 0:  # First layer but not first stack.
                    strides = 2  # Downsample.
                y, penalty = resnet_layer(
                    x,
                    penalty,
                    filters=filters,
                    strides=strides,
                    activation=True,
                )
                y, penalty = resnet_layer(
                    y,
                    penalty,
                    filters=filters,
                    activation=False,
                )
                if stack > 0 and res_block == 0:  # First layer but not first stack.
                    # Linear projection residual shortcut to match changed dims.
                    x, penalty = resnet_layer(
                        x,
                        penalty,
                        filters=filters,
                        kernel_size=1,
                        strides=strides,
                        activation=False,
                    )

                if use_residual == 1:
                    # Plain identity shortcut (any projection happened above).
                    x = x + y
                elif use_residual == 2:
                    x = (x + y) / jnp.sqrt(
                        1**2 + 1**2)  # Sum of independent normals.
                else:
                    x = y

                add_summary('postres', x)
                x = activation_f(x, features=x.shape[-1])
                add_summary('postresact', x)
            filters *= 2

        # V1 does not use BN after last shortcut connection-ReLU.
        if not no_head:
            x = jnp.mean(x, axis=(1, 2))
            add_summary('postpool', x)
            x = x.reshape((x.shape[0], -1))

            x = nn.Dense(x, num_outputs, kernel_init=head_kernel_init)
        return x, penalty, summaries
Beispiel #16
0
 def apply(self, x):
     """Maps `x` through two 256-unit ReLU layers and a 2-unit linear layer.

     Args:
       x: Input array; each Dense layer acts on its last axis.

     Returns:
       Outputs with 2 features on the last axis (no final activation).
     """
     hidden = x
     # Two identical hidden layers, created in the same order as before.
     for _ in range(2):
         hidden = nn.relu(nn.Dense(hidden, features=256))
     return nn.Dense(hidden, features=2)
Beispiel #17
0
    def apply(self, x, labels, config, train=True):
        """DDPM-style U-Net over `x`, optionally conditioned on noise levels.

        Args:
          x: Input image batch. `x.shape[1]` is compared against
            `config.model.attn_resolutions`, so presumably NHWC — TODO confirm.
          labels: Noise-level labels. They are used directly as the timesteps of
            the sinusoidal embedding and, when `config.model.scale_by_sigma` is
            set, to index `utils.get_sigmas(config)`.
          config: Config providing the `model.*` fields read below and
            `data.centered`.
          train: Forwarded to every residual block.

        Returns:
          An array whose last dimension equals `x.shape[-1]` (the output conv
          maps back to the input channel count). It is divided by the
          per-example sigma when `config.model.scale_by_sigma` is set.
        """
        # config parsing
        act = get_act(config)
        normalize = get_normalization(config)
        sigmas = utils.get_sigmas(config)

        nf = config.model.nf
        ch_mult = config.model.ch_mult
        num_res_blocks = config.model.num_res_blocks
        attn_resolutions = config.model.attn_resolutions
        dropout = config.model.dropout
        resamp_with_conv = config.model.resamp_with_conv
        num_resolutions = len(ch_mult)

        # Timestep/scale embedding. The labels are used directly as timesteps.
        # An alternative would be the normalized sigmas[labels] / max(sigmas).
        # NOTE: these two Dense layers are created even when
        # `config.model.conditional` is False, in which case temb is unused.
        # Removing them would change the parameter tree.
        timesteps = labels
        temb = layers.get_timestep_embedding(timesteps, nf)
        temb = nn.Dense(temb, nf * 4, kernel_init=default_initializer())
        temb = nn.Dense(act(temb), nf * 4, kernel_init=default_initializer())

        AttnBlock = layers.AttnBlock.partial(normalize=normalize)

        # Pass the time embedding to the residual blocks only when conditioning
        # on noise levels explicitly.
        ResnetBlock = ResnetBlockDDPM.partial(
            act=act,
            normalize=normalize,
            dropout=dropout,
            temb=temb if config.model.conditional else None,
            train=train)

        if config.data.centered:
            # Input is in [-1, 1]
            h = x
        else:
            # Input is in [0, 1]; rescale it to [-1, 1].
            h = 2 * x - 1.

        # Downsampling block. `hs` keeps every intermediate output so that the
        # upsampling path can consume them as skip connections.
        hs = [conv3x3(h, nf)]
        for i_level in range(num_resolutions):
            # Residual blocks for this resolution
            for _ in range(num_res_blocks):
                h = ResnetBlock(hs[-1], out_ch=nf * ch_mult[i_level])
                if h.shape[1] in attn_resolutions:
                    h = AttnBlock(h)
                hs.append(h)
            if i_level != num_resolutions - 1:
                hs.append(Downsample(hs[-1], with_conv=resamp_with_conv))

        # Middle block
        h = hs[-1]
        h = ResnetBlock(h)
        h = AttnBlock(h)
        h = ResnetBlock(h)

        # Upsampling block. Each level consumes num_res_blocks + 1 skip tensors.
        for i_level in reversed(range(num_resolutions)):
            for _ in range(num_res_blocks + 1):
                h = ResnetBlock(jnp.concatenate([h, hs.pop()], axis=-1),
                                out_ch=nf * ch_mult[i_level])
            # Attention is applied once per level, after its residual blocks.
            if h.shape[1] in attn_resolutions:
                h = AttnBlock(h)
            if i_level != 0:
                h = Upsample(h, with_conv=resamp_with_conv)

        # Every skip connection must have been consumed.
        assert not hs

        h = act(normalize(h))
        h = conv3x3(h, x.shape[-1], init_scale=0.)

        if config.model.scale_by_sigma:
            # Divide the output by sigmas. This is useful for training with the
            # NCSN loss. The DDPM loss scales the network output by sigma in the
            # loss function, so there is no need to do it here.
            used_sigmas = sigmas[labels].reshape(
                (x.shape[0], *([1] * len(x.shape[1:]))))
            h = h / used_sigmas

        return h
Beispiel #18
0
    def apply(self,
              x,
              *,
              num_layers,
              num_heads,
              dim_hidden=None,
              pooling=Pooling.NONE,
              is_training):
        """Transformer encoder with optional pooling of the output tokens.

        Args:
          x: Input tensor of shape (batch, sequence_length, dim).
          num_layers: Number of layers.
          num_heads: Number of attention heads.
          dim_hidden: Dimension of the representations. Defaults to the last
            dimension of `x`; when it differs, `x` is first projected with a
            Dense layer.
          pooling: Optional pooling of the output token representations to
            obtain a single representation of the sequence. Any value accepted
            by `Pooling(...)`.
          is_training: Whether the model is being trained. Attention is run
            deterministically otherwise.

        Returns:
          With `Pooling.NONE`, the sequence representations of shape
          (batch, sequence_length, dim_hidden). With MEAN / SUM / MAX / CLS
          pooling, a single representation per sequence of shape
          (batch, dim_hidden).

        Raises:
          ValueError: if `pooling` is a `Pooling` member this method does not
            handle.
        """
        pooling = Pooling(pooling)
        dim_hidden = dim_hidden or x.shape[-1]

        if dim_hidden != x.shape[-1]:
            # Project the inputs to the requested representation size.
            x = nn.Dense(x, features=dim_hidden)

        if pooling == Pooling.CLS:
            # Prepend a learned token; its output representation is returned.
            cls_token = self.param("cls_token",
                                   shape=(1, 1, dim_hidden),
                                   initializer=jax.nn.initializers.normal(
                                       1. / dim_hidden))
            batch_size = x.shape[0]
            cls_token = cls_token.repeat(batch_size, axis=0)
            x = jnp.concatenate([cls_token, x], axis=1)

        self_attention = nn.MultiHeadDotProductAttention.partial(
            num_heads=num_heads, deterministic=not is_training, inputs_kv=None)

        layer = TransformerLayer.partial(self_attention_module=self_attention,
                                         dim_intermediate=2 * dim_hidden,
                                         with_aux_outputs=False)

        transformer = stacked_layers.StackedLayers.partial(
            layer=layer, num_layers=num_layers, with_aux_outputs=False)

        representations = transformer(x)

        if pooling == Pooling.NONE:
            output = representations
        elif pooling == Pooling.MEAN:
            output = representations.mean(axis=1)
        elif pooling == Pooling.SUM:
            output = representations.sum(axis=1)
        elif pooling == Pooling.MAX:
            output = representations.max(axis=1)
        elif pooling == Pooling.CLS:
            # The CLS token was prepended at position 0.
            output = representations[:, 0, :]
        else:
            # Fail loudly instead of hitting an unbound `output` below.
            raise ValueError(f"Unsupported pooling: {pooling}")

        return output
Beispiel #19
0
            def apply(self, x):
                """Applies Dense('l1') -> ReLU -> Dense('l2').

                Both layers have `hidden_reps_dim` output features and a bias.
                """
                first = nn.Dense(x, hidden_reps_dim, bias=True, name='l1')
                return nn.Dense(nn.relu(first),
                                hidden_reps_dim,
                                bias=True,
                                name='l2')
Beispiel #20
0
 def apply(self, x):
     """Projects `x` to `hidden_reps_dim` features with a biased Dense 'l1'."""
     return nn.Dense(x, hidden_reps_dim, bias=True, name='l1')
Beispiel #21
0
    def apply(self,
              inputs,
              num_outputs,
              num_filters=64,
              num_layers=50,
              dropout_rate=0.0,
              input_dropout_rate=0.0,
              train=True,
              dtype=jnp.float32,
              head_bias_init=jnp.zeros,
              return_activations=False,
              input_layer_key='input',
              has_discriminator=False,
              discriminator=False):
        """Apply a ResNet network on the input.

        Args:
          inputs: jnp array; Inputs.
          num_outputs: int; Number of output units.
          num_filters: int; Determines base number of filters. Number of filters
            in block i is num_filters * 2 ** i.
          num_layers: int; Number of layers (should be one of the predefined
            ones).
          dropout_rate: float; Rate of dropping out the output of different
            hidden layers.
          input_dropout_rate: float; Rate of dropping out the input units.
          train: bool; Is train?
          dtype: jnp type; Type of the intermediate computations. The returned
            logits are always cast to float32.
          head_bias_init: fn(rng_key, shape)--> jnp array; Initializer for head
            bias parameters.
          return_activations: bool; If True hidden activations are also
            returned.
          input_layer_key: str; Determines where to plug in the input (this
            enables providing inputs to slices of the model). If
            `input_layer_key` is `layer_i` we assume the inputs are the
            activations of `layer_i` and pass them to `layer_{i+1}`. Valid keys
            are 'input', 'init_conv', 'block_{stage}+{j}' (stage is 1-based, j is
            0-based; note the '+'), 'avg_pool' and 'head'.
          has_discriminator: bool; Whether the model should have a discriminator
            layer.
          discriminator: bool; Whether we should return discriminator logits.

        Returns:
          Unnormalized logits with shape `[bs, num_outputs]`, as float32.
          If return_activations: a tuple of the logits, an ordered dict of hidden
            activations, and the key of the representation which will be used as
            ``The Representation'', e.g., for computing losses.
          If discriminator: the discriminator logits are appended to that tuple,
            or `(logits, discriminator_logits)` is returned when
            return_activations is False.

        Raises:
          ValueError: if `num_layers` is not a predefined depth, if
            `input_layer_key` names no layer of this network, if `discriminator`
            is set without `has_discriminator`, or if `has_discriminator` is set
            together with `input_layer_key='head'` (no features exist before the
            head to feed the discriminator).
        """
        if num_layers not in ResNet._block_size_options:
            raise ValueError('Please provide a valid number of layers')
        # Validate up front instead of after building the whole network.
        if discriminator and not has_discriminator:
            raise ValueError(
                'Inconsistent values passed for discriminator and has_discriminator'
            )

        block_sizes = ResNet._block_size_options[num_layers]

        def block_rep_key(i, j):
            # NOTE(review): this key uses '+' while the module name uses '_'
            # (f'block_{i + 1}_{j}'). It is kept as-is because callers address
            # layers and read activations by these keys.
            return f'block_{i + 1}+{j}'

        valid_keys = (['input', 'init_conv'] + [
            block_rep_key(i, j)
            for i, block_size in enumerate(block_sizes)
            for j in range(block_size)
        ] + ['avg_pool', 'head'])
        if input_layer_key not in valid_keys:
            # Otherwise `x` would never be assigned and fail obscurely later.
            raise ValueError(f'Unknown input_layer_key: {input_layer_key}')
        if has_discriminator and input_layer_key == 'head':
            raise ValueError(
                "has_discriminator requires input_layer_key before 'head'")

        layer_activations = collections.OrderedDict()
        input_is_set = False

        current_rep_key = 'input'
        if input_layer_key == current_rep_key:
            input_is_set = True
            # Input dropout
            x = nn.dropout(inputs, input_dropout_rate, deterministic=not train)
            layer_activations[current_rep_key] = x
            rep_key = current_rep_key

        current_rep_key = 'init_conv'
        if input_layer_key == current_rep_key:
            x = inputs
            input_is_set = True
            layer_activations[current_rep_key] = x
            rep_key = current_rep_key
        elif input_is_set:
            # First block
            x = nn.Conv(x,
                        num_filters, (7, 7), (2, 2),
                        padding=[(3, 3), (3, 3)],
                        bias=False,
                        dtype=dtype,
                        name='init_conv')
            x = nn.BatchNorm(x,
                             use_running_average=not train,
                             momentum=0.9,
                             epsilon=1e-5,
                             dtype=dtype,
                             name='init_bn')
            x = nn.relu(x)
            x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

            layer_activations[current_rep_key] = x
            rep_key = current_rep_key

        # Residual blocks
        for i, block_size in enumerate(block_sizes):

            # Stage i (each stage contains blocks of the same size). The first
            # block of every stage after the first downsamples.
            for j in range(block_size):
                strides = (2, 2) if i > 0 and j == 0 else (1, 1)
                current_rep_key = block_rep_key(i, j)
                if input_layer_key == current_rep_key:
                    x = inputs
                    input_is_set = True
                    layer_activations[current_rep_key] = x
                    rep_key = current_rep_key
                elif input_is_set:
                    x = ResidualBlock(x,
                                      num_filters * 2**i,
                                      strides=strides,
                                      dropout_rate=dropout_rate,
                                      train=train,
                                      dtype=dtype,
                                      name=f'block_{i + 1}_{j}')
                    layer_activations[current_rep_key] = x
                    rep_key = current_rep_key

        current_rep_key = 'avg_pool'
        if input_layer_key == current_rep_key:
            x = inputs
            input_is_set = True
            layer_activations[current_rep_key] = x
            rep_key = current_rep_key
        elif input_is_set:
            # Global Average Pool
            x = jnp.mean(x, axis=(1, 2))
            layer_activations[current_rep_key] = x
            rep_key = current_rep_key

        # DANN module: the discriminator sees the pooled features through a
        # gradient-flipping identity.
        if has_discriminator:
            z = dann_utils.flip_grad_identity(x)
            z = nn.Dense(z, 2, name='disc_l1', bias=True)
            z = nn.relu(z)
            z = nn.Dense(z, 2, name='disc_l2', bias=True)

        current_rep_key = 'head'
        if input_layer_key == current_rep_key:
            x = inputs
            layer_activations[current_rep_key] = x
            rep_key = current_rep_key

            logging.warning('Input was never used')
        elif input_is_set:
            x = nn.Dense(x,
                         num_outputs,
                         dtype=dtype,
                         bias_init=head_bias_init,
                         name='head')

        # Make sure that the output is float32, even if our previous computations
        # are in float16, or other types.
        x = jnp.asarray(x, jnp.float32)

        # `discriminator` implies `has_discriminator` (checked above), so `z`
        # is defined whenever it is returned.
        if return_activations:
            outputs = (x, layer_activations, rep_key)
            if discriminator:
                outputs = outputs + (z, )
            return outputs
        if discriminator:
            return x, z
        return x
Beispiel #22
0
    def apply(self,
              x,
              *,
              patch_size,
              k,
              downscale,
              scorer_has_se,
              normalization_str="identity",
              selection_method,
              selection_method_kwargs=None,
              selection_method_inference=None,
              patch_dropout=0.,
              hard_topk_probability=0.,
              random_patch_probability=0.,
              use_iterative_extraction,
              append_position_to_input,
              feature_network,
              aggregation_method,
              aggregation_method_kwargs=None,
              train):
        """Process a high resolution image by selecting a subset of useful patches.

        This model processes the input as follows:
        1. Compute scores per patch on a downscaled version of the input.
        2. Select "important" patches using sampling or top-k methods.
        3. Extract the patches from the high-resolution image.
        4. Compute a representation vector for each patch with a feature network.
        5. Aggregate the patch representations to obtain an image representation.

        Args:
          x: Input tensor of shape (batch, height, width, channels).
          patch_size: Size of the (square) patches to extract.
          k: Number of patches to extract per image.
          downscale: Downscale multiplier for the input of the scorer network.
          scorer_has_se: Whether the scorer network has Squeeze-excite layers.
          normalization_str: String specifying the normalization of the scores.
          selection_method: Method that selects which patches should be
            extracted, based on their scores. Either returns indices (hard
            selection) or indicator vectors (which could yield interpolated
            patches).
          selection_method_kwargs: Keyword args for the selection_method.
          selection_method_inference: Selection method used at inference.
          patch_dropout: Probability to replace a patch by 0 values.
          hard_topk_probability: Probability to use the true topk on the scores
            to select the patches. This operation has no gradient so the
            scorer's weights won't be trained.
          random_patch_probability: Probability to replace each patch by a
            random patch in the image during training.
          use_iterative_extraction: If True, uses a for loop instead of patch
            indexing for memory efficiency.
          append_position_to_input: Append normalized (height, width) position
            to the channels of the input.
          feature_network: Network to be applied on each patch individually to
            obtain patch representation vectors.
          aggregation_method: Method to aggregate the representations of the k
            patches of each image to obtain the image representation.
          aggregation_method_kwargs: Keyword arguments for aggregation_method.
          train: If the model is being trained. Disables dropout otherwise.

        Returns:
          A tuple `(representations, stats)`. `representations` holds the
          aggregated patch features of each image passed through a final Dense
          layer and swish. `stats` is a dict of intermediate values (scores,
          entropies, extracted patches, ...).

        Raises:
          ValueError: if the model uses `SelectionMethod.RANDOM` (so no scorer is
            built) but a different `selection_method_inference` is requested at
            inference, or if the selection method is not handled here.
        """
        # NOTE: unnamed submodules (Dense, Transformer, LayerNorm, the feature
        # network) are named by call order, and `nn.make_rng()` draws are
        # consumed in order. Keep the statement order stable.
        selection_method = SelectionMethod(selection_method)
        aggregation_method = AggregationMethod(aggregation_method)
        if selection_method_inference:
            selection_method_inference = SelectionMethod(
                selection_method_inference)

        selection_method_kwargs = selection_method_kwargs or {}
        aggregation_method_kwargs = aggregation_method_kwargs or {}

        stats = {}

        # Compute new dimension of the scoring image.
        b, h, w, c = x.shape
        scoring_shape = (b, h // downscale, w // downscale, c)

        # === Compute the scores with a small CNN.
        if selection_method == SelectionMethod.RANDOM:
            # No scorer is built; only the size of the patch grid is needed.
            scores_h, scores_w = Scorer.compute_output_size(
                h // downscale, w // downscale)
            num_patches = scores_h * scores_w
        else:
            # Downscale input to run scorer on.
            scoring_x = jax.image.resize(x, scoring_shape, method="bilinear")
            scores = Scorer(scoring_x,
                            use_squeeze_excite=scorer_has_se,
                            name="scorer")
            flatten_scores = einops.rearrange(scores, "b h w -> b (h w)")
            num_patches = flatten_scores.shape[-1]
            scores_h, scores_w = scores.shape[1:3]

            # Compute entropy before normalization
            prob_scores = jax.nn.softmax(flatten_scores)
            stats["entropy_before_normalization"] = jax.scipy.special.entr(
                prob_scores).sum(axis=1).mean(axis=0)

            # Normalize the flatten scores
            normalization_fn = create_normalization_fn(normalization_str)
            flatten_scores = normalization_fn(flatten_scores)
            scores = flatten_scores.reshape(scores.shape)
            stats["scores"] = scores[Ellipsis, None]

        # Concatenate height and width position to the input channels.
        if append_position_to_input:
            coords = utils.create_grid([h, w], value_range=(0., 1.))
            x = jnp.concatenate(
                [x, coords[jnp.newaxis, Ellipsis].repeat(b, axis=0)], axis=-1)
            c += 2

        # Overwrite the selection method at inference
        if selection_method_inference and not train:
            if (selection_method is SelectionMethod.RANDOM and
                    selection_method_inference is not SelectionMethod.RANDOM):
                # The scorer was skipped above, so there are no scores to use.
                raise ValueError(
                    "selection_method_inference=%s needs patch scores, but no "
                    "scorer is built when selection_method is RANDOM." %
                    selection_method_inference)
            selection_method = selection_method_inference

        # === Patch selection

        # Select the patches by sampling or top-k. Some methods return the
        # indices of the selected patches; others return indicator vectors.
        extract_by_indices = selection_method in [
            SelectionMethod.HARD_TOPK, SelectionMethod.RANDOM
        ]
        if selection_method is SelectionMethod.SINKHORN_TOPK:
            indicators = select_patches_sinkhorn_topk(
                flatten_scores, k=k, **selection_method_kwargs)
        elif selection_method is SelectionMethod.PERTURBED_TOPK:
            sigma = selection_method_kwargs["sigma"]
            num_samples = selection_method_kwargs["num_samples"]
            # Sigma is scaled by a state variable initialized to 1. The
            # misspelled name is part of the state tree, so keep it.
            # NOTE(review): presumably updated outside this module (e.g. an
            # annealing schedule) — TODO confirm.
            sigma *= self.state("sigma_mutiplier",
                                shape=(),
                                initializer=nn.initializers.ones).value
            stats["sigma"] = sigma
            indicators = select_patches_perturbed_topk(flatten_scores,
                                                       k=k,
                                                       sigma=sigma,
                                                       num_samples=num_samples)
        elif selection_method is SelectionMethod.HARD_TOPK:
            indices = select_patches_hard_topk(flatten_scores, k=k)
        elif selection_method is SelectionMethod.RANDOM:
            batch_random_indices_fn = jax.vmap(
                functools.partial(jax.random.choice,
                                  a=num_patches,
                                  shape=(k, ),
                                  replace=False))
            indices = batch_random_indices_fn(
                jax.random.split(nn.make_rng(), b))
        else:
            # Neither `indices` nor `indicators` would be defined below.
            raise ValueError("Unsupported selection method: %s" %
                             selection_method)

        # Compute scores entropy for regularization
        if selection_method not in [SelectionMethod.RANDOM]:
            prob_scores = flatten_scores
            # Normalize the scores if it is not already done.
            if "softmax" not in normalization_str:
                prob_scores = jax.nn.softmax(prob_scores)
            stats["entropy"] = jax.scipy.special.entr(prob_scores).sum(
                axis=1).mean(axis=0)

        # Randomly use hard topk at training (per image).
        if (train and hard_topk_probability > 0 and selection_method
                not in [SelectionMethod.HARD_TOPK, SelectionMethod.RANDOM]):
            true_indices = select_patches_hard_topk(flatten_scores, k=k)
            random_values = jax.random.uniform(nn.make_rng(), (b, ))
            use_hard = random_values < hard_topk_probability
            if extract_by_indices:
                indices = jnp.where(use_hard[:, None], true_indices, indices)
            else:
                true_indicators = make_indicators(true_indices, num_patches)
                indicators = jnp.where(use_hard[:, None, None],
                                       true_indicators, indicators)

        # Sample some random patches during training with
        # random_patch_probability (per patch slot).
        if (train and random_patch_probability > 0
                and selection_method is not SelectionMethod.RANDOM):
            single_random_patches = functools.partial(jax.random.choice,
                                                      a=num_patches,
                                                      shape=(k, ),
                                                      replace=False)
            random_indices = jax.vmap(single_random_patches)(jax.random.split(
                nn.make_rng(), b))
            random_values = jax.random.uniform(nn.make_rng(), (b, k))
            use_random = random_values < random_patch_probability
            if extract_by_indices:
                indices = jnp.where(use_random, random_indices, indices)
            else:
                random_indicators = make_indicators(random_indices,
                                                    num_patches)
                indicators = jnp.where(use_random[:, None, :],
                                       random_indicators, indicators)

        # === Patch extraction
        if extract_by_indices:
            patches = extract_patches_from_indices(x,
                                                   indices,
                                                   patch_size=patch_size,
                                                   grid_shape=(scores_h,
                                                               scores_w))
            # Indicators are still needed for the transformer aggregation.
            indicators = make_indicators(indices, num_patches)
        else:
            patches = extract_patches_from_indicators(
                x,
                indicators,
                patch_size,
                grid_shape=(scores_h, scores_w),
                iterative=use_iterative_extraction,
                patch_dropout=patch_dropout,
                train=train)

        chex.assert_shape(patches, (b, k, patch_size, patch_size, c))

        stats["extracted_patches"] = einops.rearrange(
            patches, "b k i j c -> b i (k j) c")
        # Remove position channels for plotting.
        if append_position_to_input:
            stats["extracted_patches"] = (
                stats["extracted_patches"][Ellipsis, :-2])

        # === Compute patch features
        flatten_patches = einops.rearrange(patches, "b k i j c -> (b k) i j c")
        representations = feature_network(flatten_patches, train=train)
        if representations.ndim > 2:
            # Average any remaining spatial axes into one vector per patch.
            collapse_axis = tuple(range(1, representations.ndim - 1))
            representations = representations.mean(axis=collapse_axis)
        representations = einops.rearrange(representations,
                                           "(b k) d -> b k d",
                                           k=k)

        stats["patch_representations"] = representations

        # === Aggregate the k patches

        # - for sampling we are forced to take an expectation
        # - for topk we have multiple options: mean, max, transformer.
        if aggregation_method is AggregationMethod.TRANSFORMER:
            patch_pos_encoding = nn.Dense(einops.rearrange(
                indicators, "b d k -> b k d"),
                                          features=representations.shape[-1])

            chex.assert_equal_shape([representations, patch_pos_encoding])
            representations += patch_pos_encoding
            representations = transformer.Transformer(
                representations,
                **aggregation_method_kwargs,
                is_training=train)

        elif aggregation_method is AggregationMethod.MEANPOOLING:
            representations = representations.mean(axis=1)
        elif aggregation_method is AggregationMethod.MAXPOOLING:
            representations = representations.max(axis=1)
        elif aggregation_method is AggregationMethod.SUM_LAYERNORM:
            representations = representations.sum(axis=1)
            representations = nn.LayerNorm(representations)

        representations = nn.Dense(representations,
                                   features=representations.shape[-1],
                                   name="classification_dense1")
        representations = nn.swish(representations)

        return representations, stats
Beispiel #23
0
    def apply(self, x, config, num_classes, train=True):
        """Creates a model definition.

        Builds the feature network selected by `config.model` and maps its
        features to `num_classes` logits with a Dense layer named "final".

        Args:
          x: Input image batch. It is unpacked as (batch, height, width, _) when
            `append_position_to_input` is set.
          config: Experiment config. `model` selects the architecture, and the
            model-specific fields read below configure it.
          num_classes: Number of output logits.
          train: Forwarded to the feature networks that accept it.

        Returns:
          `(logits, stats)`. `stats` is None except for "patchnet", where it is
          the stats dict returned by `sample_patches.PatchNet` with the input
          stored under "x".

        Raises:
          RuntimeError: if `config.model` is not a known model type.
        """
        if config.get("append_position_to_input", False):
            batch, height, width, _ = x.shape
            coords = utils.create_grid([height, width], value_range=(0., 1.))
            x = jnp.concatenate(
                [x, coords[jnp.newaxis, Ellipsis].repeat(batch, axis=0)],
                axis=-1)

        model_name = config.model.lower()
        stats = None
        # Spatial averages use a tuple `axis`, as NumPy-style APIs expect.
        if model_name == "cnn":
            h = models.SimpleCNNImageClassifier(x)
            h = nn.relu(h)
        elif model_name == "resnet":
            smallinputs = config.get("resnet.small_inputs", False)
            blocks = config.get("resnet.blocks", [3, 4, 6, 3])
            h = models.ResNet(x,
                              train=train,
                              block_sizes=blocks,
                              small_inputs=smallinputs)
            h = jnp.mean(h, axis=(1, 2))  # global average pool
        elif model_name == "resnet18":
            h = models.ResNet18(x, train=train)
            h = jnp.mean(h, axis=(1, 2))  # global average pool
        elif model_name == "resnet50":
            h = models.ResNet50(x, train=train)
            h = jnp.mean(h, axis=(1, 2))  # global average pool
        elif model_name == "ats-traffic":
            h = models.ATSFeatureNetwork(x, train=train)
        elif model_name == "patchnet":
            feature_network = {
                "resnet18":
                models.ResNet18,
                "resnet18-fourth":
                models.ResNet.partial(num_filters=16,
                                      block_sizes=(2, 2, 2, 2),
                                      block=models.BasicBlock),
                "resnet50":
                models.ResNet50,
                "ats-traffic":
                models.ATSFeatureNetwork,
            }[config.feature_network.lower()]

            # Only the top-k selection methods take extra keyword arguments.
            selection_method = sample_patches.SelectionMethod(
                config.selection_method)
            selection_method_kwargs = {}
            if selection_method is sample_patches.SelectionMethod.SINKHORN_TOPK:
                selection_method_kwargs = config.sinkhorn_topk_kwargs
            if selection_method is sample_patches.SelectionMethod.PERTURBED_TOPK:
                selection_method_kwargs = config.perturbed_topk_kwargs

            h, stats = sample_patches.PatchNet(
                x,
                patch_size=config.patch_size,
                k=config.k,
                downscale=config.downscale,
                scorer_has_se=config.get("scorer_has_se", False),
                selection_method=config.selection_method,
                selection_method_kwargs=selection_method_kwargs,
                selection_method_inference=config.get(
                    "selection_method_inference", None),
                normalization_str=config.normalization_str,
                aggregation_method=config.aggregation_method,
                aggregation_method_kwargs=config.get(
                    "aggregation_method_kwargs", {}),
                append_position_to_input=config.get("append_position_to_input",
                                                    False),
                feature_network=feature_network,
                use_iterative_extraction=config.use_iterative_extraction,
                hard_topk_probability=config.get("hard_topk_probability", 0.),
                random_patch_probability=config.get("random_patch_probability",
                                                    0.),
                train=train)
            stats["x"] = x
        else:
            raise RuntimeError("Unknown classification model type: %s" %
                               model_name)
        out = nn.Dense(h, num_classes, name="final")
        return out, stats
Beispiel #24
0
    def apply(self,
              x,
              num_outputs,
              pyramid_alpha=200,
              pyramid_depth=272,
              train=True):
        """PyramidNet built from bottleneck ShakeDrop blocks.

        Args:
          x: Input image batch. NOTE(review): the final fixed 8x8 average pool
            suggests 32x32 inputs (two stride-2 stages) — TODO confirm.
          num_outputs: Number of output logits.
          pyramid_alpha: Total channel widening, spread evenly over all blocks
            (alpha in https://arxiv.org/abs/1610.02915).
          pyramid_depth: Network depth. It must be 9 * n + 2 with n >= 1; n is
            the number of blocks per group.
          train: If False, batch norm uses running averages. Also forwarded to
            the ShakeDrop blocks.

        Returns:
          Logits of shape (batch, num_outputs).

        Raises:
          ValueError: if `pyramid_depth` is not of the form 9 * n + 2, n >= 1.
        """
        # Explicit validation: `assert` is stripped under -O, and depth 2 would
        # otherwise divide by zero when computing `delta_channels`.
        if pyramid_depth < 11 or (pyramid_depth - 2) % 9 != 0:
            raise ValueError(
                f'pyramid_depth must be 9 * n + 2 with n >= 1, got {pyramid_depth}')

        # Shake-drop hyper-params
        mask_prob = 0.5
        alpha_min, alpha_max = (-1.0, 1.0)
        beta_min, beta_max = (0.0, 1.0)

        # Bottleneck network size
        blocks_per_group = (pyramid_depth - 2) // 9
        # See Eqn 2 in https://arxiv.org/abs/1610.02915
        num_channels = 16
        # N in https://arxiv.org/abs/1610.02915
        total_blocks = blocks_per_group * 3
        delta_channels = pyramid_alpha / total_blocks

        x = nn.Conv(x, 16, (3, 3), padding='SAME', name='init_conv')
        x = nn.BatchNorm(x,
                         use_running_average=not train,
                         momentum=0.9,
                         epsilon=1e-5,
                         name='init_bn')

        # Three groups of blocks. The first block of the second and third group
        # downsamples. The width grows by `delta_channels` per block; it is
        # accumulated as a float and truncated per block.
        layer_num = 0  # 1-based index of the current block across all groups.
        for group in range(3):
            for block_i in range(blocks_per_group):
                layer_num += 1
                num_channels += delta_channels
                strides = (2, 2) if group > 0 and block_i == 0 else (1, 1)
                layer_mask_prob = _calc_shakedrop_mask_prob(
                    layer_num, total_blocks, mask_prob)
                x = BottleneckShakeDrop(x,
                                        int(num_channels),
                                        strides,
                                        layer_mask_prob,
                                        alpha_min,
                                        alpha_max,
                                        beta_min,
                                        beta_max,
                                        train=train)

        x = nn.BatchNorm(x,
                         use_running_average=not train,
                         momentum=0.9,
                         epsilon=1e-5,
                         name='final_bn')
        x = jax.nn.relu(x)
        x = nn.avg_pool(x, (8, 8))
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(x, num_outputs)
        return x
  def apply(
      self,
      inputs,
      blocks_per_group,
      channel_multiplier,
      num_outputs,
      kernel_size=(3, 3),
      strides=None,
      maxpool=False,
      dropout_rate=0.0,
      dtype=jnp.float32,
      norm_layer='group_norm',
      train=True,
      return_activations=False,
      input_layer_key='input',
      has_discriminator=False,
      discriminator=False,
  ):
    """Wide ResNet forward pass that can be entered at an intermediate stage.

    The network is the ordered stages 'input', 'init_conv', 'l1', 'l2', 'l3',
    'l4', followed by a dense 'head'. `inputs` is injected at the stage named
    `input_layer_key`: earlier stages are skipped and every later stage is
    applied to it.

    Args:
      inputs: Array injected at stage `input_layer_key`.
      blocks_per_group: Number of blocks in each WideResnetGroup.
      channel_multiplier: Width multiplier for the 16/32/64-channel groups.
      num_outputs: Number of features produced by the dense head.
      kernel_size: Kernel size of the initial convolution.
      strides: Strides of the initial convolution.
      maxpool: If True, a 3x3 stride-2 max-pool follows the initial conv.
      dropout_rate: Dropout rate forwarded to every WideResnetGroup.
      dtype: dtype of the dense head.
      norm_layer: 'batch_norm', 'group_norm' (16 groups), or a non-string
        normalization module used as-is. Any other string raises ValueError.
      train: Training mode flag (also selects BatchNorm running averages).
      return_activations: If True, also return per-stage activations and the
        key of the last recorded stage.
      input_layer_key: Name of the stage at which `inputs` is injected.
      has_discriminator: If True, build the DANN discriminator on the
        representation that feeds the head.
      discriminator: If True, also return the discriminator output; requires
        `has_discriminator`.

    Returns:
      If `return_activations`: `(x, activations, rep_key)`, with the
      discriminator output `z` appended when `discriminator` is set.
      `activations` is an OrderedDict mapping stage name to that stage's
      output, for `input_layer_key` and every later stage. The 'head' entry is
      only recorded when `inputs` was never consumed. `rep_key` is the last
      recorded stage name.
      Otherwise: `(x, z)` when `discriminator` is set, else just `x` (the head
      output).

    Raises:
      ValueError: If `discriminator` is set without `has_discriminator`, or if
        `norm_layer` is an unrecognized string.
    """
    # Validate the configuration before any module is built.
    if discriminator and not has_discriminator:
      raise ValueError(
          'Inconsistent values passed for discriminator and has_discriminator')

    if norm_layer == 'batch_norm':
      norm_layer = nn.BatchNorm.partial(use_running_average=not train)
      norm_layer_name = 'bn'
    elif norm_layer == 'group_norm':
      norm_layer = nn.GroupNorm.partial(num_groups=16)
      norm_layer_name = 'gn'
    elif isinstance(norm_layer, str):
      # An unknown string would otherwise fail later as "str is not callable".
      raise ValueError(f'Unknown norm_layer: {norm_layer!r}')
    else:
      norm_layer_name = ''

    def _init_conv(y):
      """Stem convolution to 16 channels, optionally followed by max-pool."""
      y = nn.Conv(
          y,
          16,
          kernel_size=kernel_size,
          strides=strides,
          padding='SAME',
          name='init_conv')
      if maxpool:
        y = nn.max_pool(y, (3, 3), strides=(2, 2), padding='SAME')
      return y

    def _wrn_group(name, channels, *group_args):
      """Returns a stage applying one WideResnetGroup named `name`."""

      def stage(y):
        # `group_args` carries the optional positional strides argument; l1
        # passes none so WideResnetGroup keeps its own default.
        return WideResnetGroup(
            y,
            blocks_per_group,
            channels,
            *group_args,
            dropout_rate=dropout_rate,
            norm_layer=norm_layer,
            train=train,
            name=name)

      return stage

    def _pool(y):
      """Final norm + ReLU + 8x8 average pool, flattened to (batch, -1)."""
      y = norm_layer(y, name=norm_layer_name)
      y = jax.nn.relu(y)
      y = nn.avg_pool(y, (8, 8))
      return y.reshape((y.shape[0], -1))

    stages = (
        ('input', lambda y: y),  # never applied: no stage precedes 'input'
        ('init_conv', _init_conv),
        ('l1', _wrn_group('l1', 16 * channel_multiplier)),
        ('l2', _wrn_group('l2', 32 * channel_multiplier, (2, 2))),
        ('l3', _wrn_group('l3', 64 * channel_multiplier, (2, 2))),
        ('l4', _pool),
    )

    layer_activations = collections.OrderedDict()
    input_is_set = False
    for stage_key, stage_fn in stages:
      if input_is_set:
        x = stage_fn(x)
      elif input_layer_key == stage_key:
        # Skip everything before this stage and inject `inputs` here.
        x = inputs
        input_is_set = True
      else:
        continue
      layer_activations[stage_key] = x
      rep_key = stage_key

    # DANN module: `flip_grad_identity` is presumably a gradient-reversal
    # layer (confirm in dann_utils).
    if has_discriminator:
      z = dann_utils.flip_grad_identity(x)
      z = nn.Dense(z, 2, name='disc_l1', bias=True)
      z = nn.relu(z)
      z = nn.Dense(z, 2, name='disc_l2', bias=True)

    if input_is_set:
      x = nn.Dense(x, num_outputs, dtype=dtype, name='head')
    else:
      x = inputs
      layer_activations['head'] = x
      rep_key = 'head'

      logging.warning('Input was never used')

    # `discriminator` implies `has_discriminator` (validated above).
    if return_activations:
      outputs = (x, layer_activations, rep_key)
      if discriminator:
        outputs += (z,)
    elif discriminator:
      outputs = (x, z)
    else:
      outputs = x
    return outputs
Beispiel #26
0
    def apply(self, x, communication=Communication.NONE, train=True):
        """Predicts anchor score maps at three resolutions.

        Three 3x3, 128-channel ReLU convolutions ("down1" stride 1, "down2"
        and "down3" stride 2) build a small pyramid. Depending on
        `communication`, either the input is squeeze-excited, or the pyramid
        levels are flattened into one sequence and mixed (squeeze-excite
        gating or a 2-layer Transformer) before being split back. 1x1 "tidy"
        convolutions then emit 6, 6 and 9 score channels per level.

        Args:
          x: Input feature map; presumably NHWC (the levels are rearranged
            with "b h w c").
          communication: A `Communication` mode selecting the mixing scheme.
          train: Forwarded to the Transformer as `is_training`.

        Returns:
          (scores, stats): `scores` is the list of 21 per-anchor maps produced
          by `zeroone` from the raw splits and the per-example min/max over
          all tidy outputs, with the trailing size-1 dim squeezed.
          stats["scores"] holds the unsqueezed maps and stats["raw_scores"]
          the tidy outputs flattened to (batch, -1).
        """
        batch = x.shape[0]

        if communication is Communication.SQUEEZE_EXCITE_X:
            x = sample_patches.SqueezeExciteLayer(x)

        # Feature pyramid; each level is fed into the next one.
        levels = []
        feat = x
        for stride, layer_name in ((1, "down1"), (2, "down2"), (2, "down3")):
            feat = nn.relu(
                nn.Conv(feat,
                        128,
                        kernel_size=(3, 3),
                        strides=(stride, stride),
                        bias=True,
                        name=layer_name))
            levels.append(feat)

        mixes_levels = (communication is Communication.SQUEEZE_EXCITE_D
                        or communication is Communication.TRANSFORMER)
        if mixes_levels:
            # Treat every spatial position of every level as one token.
            tokens = [
                einops.rearrange(lvl, "b h w c -> b (h w) c") for lvl in levels
            ]
            len1, len2 = tokens[0].shape[1], tokens[1].shape[1]
            seq = jnp.concatenate(tokens, axis=1)

            if communication is Communication.SQUEEZE_EXCITE_D:
                width = seq.shape[-1]
                gate = nn.Dense(seq.mean(axis=1),
                                features=width // 4,
                                bias=False)
                gate = nn.relu(gate)
                gate = nn.Dense(gate, features=width, bias=False)
                seq = seq * nn.sigmoid(gate)[:, None, :]
            else:
                pos_emb = self.param(
                    "scale_ratio_position_encodings",
                    shape=(1, ) + seq.shape[1:],
                    initializer=jax.nn.initializers.normal(1. /
                                                           seq.shape[-1]))
                seq = transformer.Transformer(seq + pos_emb,
                                              num_layers=2,
                                              num_heads=8,
                                              is_training=train)

            # Undo the concatenation and restore each level's shape.
            bounds = (0, len1, len1 + len2, None)
            levels = [
                seq[:, start:stop].reshape(lvl.shape)
                for lvl, start, stop in zip(levels, bounds, bounds[1:])
            ]

        anchors_per_level = (6, 6, 9)
        tidy = [
            nn.Conv(lvl,
                    n_anchors,
                    kernel_size=(1, 1),
                    strides=(1, 1),
                    bias=True,
                    name=layer_name)
            for lvl, n_anchors, layer_name in zip(
                levels, anchors_per_level, ("tidy1", "tidy2", "tidy3"))
        ]

        # One single-channel map per anchor, in level order.
        raw_scores = []
        for tidy_map, n_anchors in zip(tidy, anchors_per_level):
            raw_scores += jnp.split(tidy_map, n_anchors, axis=-1)

        # Normalize jointly over every anchor map of each example.
        flat = jnp.concatenate(
            [jnp.reshape(tidy_map, [batch, -1]) for tidy_map in tidy], axis=1)
        score_min = jnp.reshape(jnp.min(flat, axis=-1), [batch, 1, 1, 1])
        score_max = jnp.reshape(jnp.max(flat, axis=-1), [batch, 1, 1, 1])
        normalized_scores = zeroone(raw_scores, score_min, score_max)

        stats = {"scores": normalized_scores, "raw_scores": flat}
        # Drop the size-1 split dimension: each map becomes b x h' x w'.
        return [s.squeeze(-1) for s in normalized_scores], stats
Beispiel #27
0
    def apply(self,
              inputs,
              vocab_size,
              emb_dim=512,
              num_heads=8,
              num_layers=6,
              qkv_dim=512,
              mlp_dim=2048,
              max_len=2048,
              train=False,
              dropout_rate=0.1,
              attention_dropout_rate=0.1,
              causal=True,
              cache=None,
              positional_encoding_module=AddLearnedPositionalEncodings,
              self_attention_module=nn.SelfAttention,
              attention_fn=None,
              pad_token=None,
              output_head='logits'):
        """Applies Transformer model on the inputs.

        Args:
          inputs: An array of shape (batch_size, length) or (batch_size,
            length, vocab_size) with the input sequences. When 2-dimensional,
            the array contains sequences of int tokens. Otherwise, the array
            contains next-token distributions over tokens (e.g. one-hot
            representations).
          vocab_size: An int with the size of the vocabulary.
          emb_dim: An int with the token embedding dimension.
          num_heads: An int with the number of attention heads.
          num_layers: An int with the number of transformer encoder layers.
          qkv_dim: An int with the dimension of the query/key/value vectors.
          mlp_dim: An int with the inner dimension of the feed-forward network
            which follows the attention block.
          max_len: An int with the maximum training sequence length.
          train: A bool denoting whether we are currently training.
          dropout_rate: A float with the dropout rate.
          attention_dropout_rate: A float with a dropout rate for attention
            weights.
          causal: Whether to apply causal masking.
          cache: Cache for decoding.
          positional_encoding_module: A module used for adding positional
            encodings.
          self_attention_module: Self attention module.
          attention_fn: Method to use in place of dot product attention.
          pad_token: Token to ignore in attention.
          output_head: A single head name (string) or any iterable of head
            names. Available heads: 'input_emb', 'layer_<i>' for each layer
            i, 'output_emb' (final LayerNorm output with PAD positions
            zeroed), 'logits' and 'regression'.

        Returns:
          Output of a transformer decoder. If output_head is a string, the
          single requested head output; otherwise a dict mapping each
          requested head name to its output, in request order.

        Raises:
          ValueError: If `inputs` is not 2- or 3-dimensional.
          KeyError: If a requested head name is not an available head.
        """
        if inputs.ndim != 2 and inputs.ndim != 3:
            raise ValueError('Expected 2 or 3 dimensions, found %d.' %
                             inputs.ndim)

        # A bare string names one head; anything else is an iterable of names.
        # Materializing it once makes sets/generators work and makes the
        # membership tests below exact name matches (not substring matches).
        single_head = isinstance(output_head, str)
        requested_heads = (output_head, ) if single_head else tuple(
            output_head)

        if inputs.ndim == 3:
            padding_mask = jnp.ones_like(inputs[Ellipsis, 0])
        elif pad_token is None:
            padding_mask = jnp.ones_like(inputs)
        else:
            # Mask out padding tokens.
            padding_mask = jnp.where(inputs != pad_token, 1,
                                     0).astype(jnp.float32)
        padding_mask = padding_mask[Ellipsis, None]  # Add embedding dimension.

        heads = dict()
        x = inputs
        if inputs.ndim == 2:
            x = x.astype('int32')
        x = Embed(x,
                  num_embeddings=vocab_size,
                  num_features=emb_dim,
                  name='embed')

        if positional_encoding_module == AddLearnedPositionalEncodings:
            x = positional_encoding_module(
                x,
                max_len=max_len,
                cache=cache,
                posemb_init=sinusoidal_init(max_len=max_len))
        else:
            x = positional_encoding_module(x, max_len=max_len)
        x = nn.dropout(x, rate=dropout_rate, deterministic=not train)
        heads['input_emb'] = x
        for i in range(num_layers):
            x = Transformer1DBlock(
                x,
                qkv_dim=qkv_dim,
                mlp_dim=mlp_dim,
                num_heads=num_heads,
                causal_mask=causal,
                padding_mask=padding_mask,
                dropout_rate=dropout_rate,
                attention_dropout_rate=attention_dropout_rate,
                self_attention_module=self_attention_module,
                deterministic=not train,
                attention_fn=attention_fn,
                cache=cache,
            )
            heads['layer_%s' % i] = x
        x = nn.LayerNorm(x)
        heads['output_emb'] = x * padding_mask  # Zero out PAD positions.

        # Output projections are only built when requested.
        if 'logits' in requested_heads:
            logits = nn.Dense(x,
                              vocab_size,
                              kernel_init=nn.initializers.xavier_uniform(),
                              bias_init=nn.initializers.normal(stddev=1e-6))
            heads['logits'] = logits

        if 'regression' in requested_heads:
            regression = nn.Dense(
                x,
                1,
                kernel_init=nn.initializers.xavier_uniform(),
                bias_init=nn.initializers.normal(stddev=1e-6))
            regression = jnp.squeeze(regression, axis=-1)
            heads['regression'] = regression

        if single_head:
            return heads[output_head]
        return {head: heads[head] for head in requested_heads}
Beispiel #28
0
    def apply(self, x, config, num_classes, train=True):
        """Classifies `x` from the whole image plus k selected part crops.

        A shared ResNet50 (name "ResNet_0") extracts features from `x`.
        ProposalNet scores anchors on the stop-gradient features. A perturbed
        top-k selection picks `config.k` anchors. `weighted_anchor_aggregator`
        turns them into part images that go through the same ResNet50.

        Args:
          x: Input images; batch size is read from `x.shape[0]` and channels
            from `x.shape[3]` (presumably NHWC).
          config: Reads `k`, `ptopk_sigma`, `ptopk_num_samples` and
            `communication`.
          num_classes: Number of classes of every logits head.
          train: Enables dropout and training mode of the feature extractor.

        Returns:
          (all_logits, stats): `all_logits` holds "raw_logits" (whole image,
          (b, num_classes)), "concat_logits" (whole image plus mean of parts,
          (b, num_classes)) and "part_logits" ((b, k, num_classes)). `stats`
          collects intermediate values, including "rpn_scores_entropy".
        """
        b, c = x.shape[0], x.shape[3]
        k = config.k
        sigma = config.ptopk_sigma
        num_samples = config.ptopk_num_samples

        # Scale sigma by a state variable initialized to 1. The key keeps its
        # existing spelling because it names stored state entries.
        sigma *= self.state("sigma_mutiplier",
                            shape=(),
                            initializer=nn.initializers.ones).value

        stats = {"x": x, "sigma": sigma}

        # A single ResNet50 instance shared by the whole image and the parts.
        feature_extractor = models.ResNet50.shared(train=train,
                                                   name="ResNet_0")

        rpn_feature = feature_extractor(x)
        # The proposal net sees stop-gradient features, so no gradient flows
        # into the backbone through this path.
        rpn_scores, rpn_stats = ProposalNet(jax.lax.stop_gradient(rpn_feature),
                                            communication=Communication(
                                                config.communication),
                                            train=train)
        stats.update(rpn_stats)

        # rpn_scores are a list of score images. We keep track of the structure
        # because it is used in the aggregation step later-on.
        rpn_scores_shapes = [s.shape for s in rpn_scores]
        rpn_scores_flat = jnp.concatenate(
            [jnp.reshape(s, [b, -1]) for s in rpn_scores], axis=1)
        top_k_indicators = sample_patches.select_patches_perturbed_topk(
            rpn_scores_flat, k=k, sigma=sigma, num_samples=num_samples)
        top_k_indicators = jnp.transpose(top_k_indicators, [0, 2, 1])
        # Cut the flat indicators back into one (b, k, h, w) map per level.
        offset = 0
        weights = []
        for sh in rpn_scores_shapes:
            cur = top_k_indicators[:, :, offset:offset + sh[1] * sh[2]]
            cur = jnp.reshape(cur, [b, k, sh[1], sh[2]])
            weights.append(cur)
            offset += sh[1] * sh[2]
        chex.assert_equal(offset, top_k_indicators.shape[-1])

        part_imgs = weighted_anchor_aggregator(x, weights)
        chex.assert_shape(part_imgs, (b * k, 224, 224, c))
        stats["part_imgs"] = jnp.reshape(part_imgs, [b, k * 224, 224, c])

        part_features = feature_extractor(part_imgs)
        # GAP the spatial dims; the feature width comes from the extractor
        # instead of being hard-coded.
        part_features = jnp.mean(part_features, axis=(1, 2))
        num_features = part_features.shape[-1]

        # No explicit `rng=`: nn.dropout draws its own PRNG key only when it
        # is not deterministic, so the eval path does not consume randomness.
        part_features = nn.dropout(  # features from parts
            jnp.reshape(part_features, [b * k, num_features]),
            0.5,
            deterministic=not train)
        features = nn.dropout(  # features from whole image
            jnp.reshape(jnp.mean(rpn_feature, axis=(1, 2)), [b, -1]),
            0.5,
            deterministic=not train)

        # Mean pool all part features, add it to features and predict logits.
        concat_out = jnp.mean(jnp.reshape(part_features, [b, k, num_features]),
                              axis=1) + features
        concat_logits = nn.Dense(concat_out, num_classes)
        raw_logits = nn.Dense(features, num_classes)
        part_logits = jnp.reshape(nn.Dense(part_features, num_classes),
                                  [b, k, -1])

        all_logits = {
            "raw_logits": raw_logits,
            "concat_logits": concat_logits,
            "part_logits": part_logits,
        }
        # add entropy into it for entropy regularization.
        stats["rpn_scores_entropy"] = jax.scipy.special.entr(
            jax.nn.softmax(stats["raw_scores"])).sum(axis=1).mean(axis=0)
        return all_logits, stats