Example #1
    def apply(self,
              input_ids,
              input_mask,
              type_ids,
              *,
              config,
              deterministic=False):
        """Applies BERT model on the inputs."""

        word_embeddings = nn.Embed(input_ids,
                                   num_embeddings=config.vocab_size,
                                   features=config.d_emb,
                                   embedding_init=kernel_initializer,
                                   name="word_embeddings")
        position_embeddings = layers.PositionalEncoding(
            word_embeddings,
            max_len=config.max_len,
            posemb_init=kernel_initializer,
            name="position_embeddings")
        type_embeddings = nn.Embed(type_ids,
                                   num_embeddings=config.type_vocab_size,
                                   features=config.d_emb,
                                   embedding_init=kernel_initializer,
                                   name="type_embeddings")

        embeddings = word_embeddings + position_embeddings + type_embeddings
        embeddings = nn.LayerNorm(embeddings,
                                  epsilon=LAYER_NORM_EPSILON,
                                  name="embeddings_layer_norm")
        embeddings = nn.Dense(embeddings,
                              config.d_model,
                              name="embedding_hidden_mapping_in")
        embeddings = nn.dropout(embeddings,
                                rate=config.dropout_rate,
                                deterministic=deterministic)

        # Transformer blocks
        feed_forward = layers.FeedForward.partial(
            d_ff=config.d_ff,
            dropout_rate=config.dropout_rate,
            intermediate_activation=hidden_activation,
            kernel_init=kernel_initializer)

        self_attention = efficient_attention.BertSelfAttention.partial(
            num_heads=config.num_heads,
            num_parallel_heads=config.num_parallel_heads,
            d_qkv=config.d_model // config.num_heads,
            attention_dropout_rate=config.attention_dropout_rate,
            output_dropout_rate=config.dropout_rate,
            kernel_init=kernel_initializer,
            output_kernel_init=kernel_initializer)

        hidden_states = embeddings
        mask = input_mask.astype(jnp.int32)
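        # A single TransformerBlock whose parameters are reused for every layer
        # (cross-layer parameter sharing).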
        shared_encoder_layer = layers.TransformerBlock.shared(
            feed_forward=feed_forward,
            attention=self_attention,
            deterministic=deterministic,
            name="encoder_layer_0")
        for _ in range(config.num_layers):
            hidden_states = shared_encoder_layer(hidden_states, mask)

        pooled_output = nn.Dense(hidden_states[:, 0],
                                 config.d_model,
                                 kernel_init=kernel_initializer,
                                 name="pooler")
        pooled_output = jnp.tanh(pooled_output)

        return hidden_states, pooled_output
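The apply() above reads every hyperparameter it needs from config. The sketch below is a hypothetical, minimal configuration covering exactly those fields, written with ml_collections.ConfigDict; the class, the helper name and all values are illustrative assumptions, not taken from the original project.

import ml_collections


def small_bert_config():
    """Hypothetical config covering the fields read by the apply() above."""
    config = ml_collections.ConfigDict()
    config.vocab_size = 30522           # word-piece vocabulary (nn.Embed)
    config.type_vocab_size = 2          # segment / token-type vocabulary
    config.d_emb = 128                  # embedding width before embedding_hidden_mapping_in
    config.d_model = 768                # Transformer hidden width
    config.d_ff = 3072                  # feed-forward width
    config.max_len = 512                # maximum sequence length for positional encodings
    config.num_heads = 12
    config.num_parallel_heads = None    # passed through to BertSelfAttention
    config.num_layers = 12              # number of (shared) encoder layers
    config.dropout_rate = 0.1
    config.attention_dropout_rate = 0.1
    return config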
Example #2
    def apply(self,
              input_ids,
              input_mask,
              type_ids,
              *,
              config,
              deterministic=False):
        """Applies BERT model on the inputs."""

        word_embeddings = nn.Embed(input_ids,
                                   num_embeddings=config.vocab_size,
                                   features=config.hidden_size,
                                   embedding_init=get_kernel_init(config),
                                   name='word_embeddings')
        position_embeddings = layers.PositionalEncoding(
            word_embeddings,
            max_len=config.max_position_embeddings,
            posemb_init=get_kernel_init(config),
            name='position_embeddings')
        type_embeddings = nn.Embed(type_ids,
                                   num_embeddings=config.type_vocab_size,
                                   features=config.hidden_size,
                                   embedding_init=get_kernel_init(config),
                                   name='type_embeddings')

        embeddings = word_embeddings + position_embeddings + type_embeddings
        embeddings = nn.LayerNorm(embeddings,
                                  epsilon=LAYER_NORM_EPSILON,
                                  name='embeddings_layer_norm')
        embeddings = nn.dropout(embeddings,
                                rate=config.hidden_dropout_prob,
                                deterministic=deterministic)

        # Transformer blocks
        feed_forward = layers.FeedForward.partial(
            d_ff=config.intermediate_size,
            dropout_rate=config.hidden_dropout_prob,
            intermediate_activation=get_hidden_activation(config),
            kernel_init=get_kernel_init(config))

        attention = efficient_attention.BertSelfAttention.partial(
            num_heads=config.num_attention_heads,
            num_parallel_heads=None,
            d_qkv=config.hidden_size // config.num_attention_heads,
            attention_dropout_rate=config.attention_probs_dropout_prob,
            output_dropout_rate=config.hidden_dropout_prob,
            kernel_init=get_kernel_init(config),
            output_kernel_init=get_kernel_init(config))

        hidden_states = embeddings
        mask = input_mask.astype(jnp.int32)
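        # Unlike Example #1, each encoder layer here has its own parameters.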
        for layer_num in range(config.num_hidden_layers):
            hidden_states = layers.TransformerBlock(
                hidden_states,
                mask,
                feed_forward=feed_forward,
                attention=attention,
                deterministic=deterministic,
                name=f'encoder_layer_{layer_num}')

        pooled_output = nn.Dense(hidden_states[:, 0],
                                 config.hidden_size,
                                 kernel_init=get_kernel_init(config),
                                 name='pooler')
        pooled_output = jnp.tanh(pooled_output)

        return hidden_states, pooled_output
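For reference, input_ids, input_mask and type_ids all share a [batch, seq_len] shape. Below is a small sketch of dummy inputs; the names, shapes and values are assumptions for illustration, not taken from the original code.

import jax.numpy as jnp

batch, seq_len = 2, 16
input_ids = jnp.ones((batch, seq_len), dtype=jnp.int32)    # token ids
type_ids = jnp.zeros((batch, seq_len), dtype=jnp.int32)    # segment ids (0 or 1)
input_mask = jnp.ones((batch, seq_len), dtype=jnp.int32)   # 1 = real token, 0 = padding
# apply(...) then returns:
#   hidden_states: [batch, seq_len, config.hidden_size]
#   pooled_output: [batch, config.hidden_size], a tanh-projected first-token vector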
Example #3
  def apply(self, x, labels, y=None, config=None, train=True):
    # config parsing
    nf = config.model.nf
    act = get_act(config)
    normalize = get_normalization(config)
    sigmas = utils.get_sigmas(config)

    ch_mult = config.model.ch_mult
    num_res_blocks = config.model.num_res_blocks
    attn_resolutions = config.model.attn_resolutions
    dropout = config.model.dropout
    resamp_with_conv = config.model.resamp_with_conv
    num_resolutions = len(ch_mult)

    conditional = config.model.conditional  # noise-conditional
    fir = config.model.fir
    fir_kernel = config.model.fir_kernel
    skip_rescale = config.model.skip_rescale
    resblock_type = config.model.resblock_type
    progressive = config.model.progressive
    progressive_input = config.model.progressive_input
    init_scale = config.model.init_scale
    assert progressive.lower() in ['none', 'output_skip', 'residual']
    assert config.model.embedding_type.lower() in ['gaussian', 'positional']
    combine_method = config.model.progressive_combine
    combiner = Combine.partial(method=combine_method)

    # timestep/noise_level embedding
    if config.model.embedding_type == 'gaussian':
      # Gaussian Fourier features embeddings.
      used_sigmas = sigmas[labels]
      temb = layersv3.GaussianFourierProjection(
          jnp.log(used_sigmas),
          embedding_size=nf,
          scale=config.model.fourier_scale)
    elif config.model.embedding_type == 'positional':
      # Sinusoidal positional embeddings.
      timesteps = labels
      temb = layers.get_timestep_embedding(timesteps, nf)
    else:
      raise ValueError(f'embedding type {config.model.embedding_type} unknown.')

    # Two Dense layers, with the model activation in between, expand the
    # noise-level embedding to width nf * 4 for the ResNet blocks.
    temb = nn.Dense(temb, nf * 4, kernel_init=default_initializer())
    temb = nn.Dense(act(temb), nf * 4, kernel_init=default_initializer())

    if y is not None:  # class-conditional image generation
      class_embed = nn.Embed(y, config.data.num_classes, nf * 4)
      class_embed = nn.Dense(
          class_embed, nf * 4, kernel_init=default_initializer())
      class_embed = nn.Dense(
          act(class_embed), nf * 4, kernel_init=default_initializer())

      temb += class_embed

    AttnBlock = layersv3.AttnBlockv3.partial(
        normalize=normalize,
        init_scale=init_scale,
        skip_rescale=skip_rescale)

    Upsample = layersv3.Upsample.partial(
        with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
    if progressive == 'output_skip':
      pyramid_upsample = layersv3.Upsample.partial(
          fir=fir, fir_kernel=fir_kernel, with_conv=False)
    elif progressive == 'residual':
      pyramid_upsample = layersv3.Upsample.partial(
          fir=fir, fir_kernel=fir_kernel, with_conv=True)

    Downsample = layersv3.Downsample.partial(
        with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)

    if progressive_input == 'input_skip':
      pyramid_downsample = layersv3.Downsample.partial(
          fir=fir, fir_kernel=fir_kernel, with_conv=False)
    elif progressive_input == 'residual':
      pyramid_downsample = layersv3.Downsample.partial(
          fir=fir, fir_kernel=fir_kernel, with_conv=True)

    if resblock_type == 'ddpm':
      ResnetBlock = ResnetBlockDDPM.partial(
          act=act,
          normalize=normalize,
          dropout=dropout,
          temb=temb if conditional else None,
          train=train,
          init_scale=init_scale,
          skip_rescale=skip_rescale)

    elif resblock_type == 'biggan':
      ResnetBlock = ResnetBlockBigGAN.partial(
          act=act,
          normalize=normalize,
          temb=temb if conditional else None,
          train=train,
          dropout=dropout,
          fir=fir,
          fir_kernel=fir_kernel,
          init_scale=init_scale,
          skip_rescale=skip_rescale)

    else:
      raise ValueError(f'resblock_type {resblock_type} unrecognized.')

    if not config.data.centered:
      # Inputs in [0, 1] are rescaled to [-1, 1]
      x = 2 * x - 1.

    # Downsampling block

    input_pyramid = None
    if progressive_input != 'none':
      input_pyramid = x

    hs = [conv3x3(x, nf)]
    for i_level in range(num_resolutions):
      # Residual blocks for this resolution
      for i_block in range(num_res_blocks):
        h = ResnetBlock(hs[-1], out_ch=nf * ch_mult[i_level])
        if h.shape[1] in attn_resolutions:
          h = AttnBlock(h)
        hs.append(h)

      if i_level != num_resolutions - 1:
        if resblock_type == 'ddpm':
          h = Downsample(hs[-1])
        else:
          h = ResnetBlock(hs[-1], down=True)

        if progressive_input == 'input_skip':
          input_pyramid = pyramid_downsample(input_pyramid)
          h = combiner(input_pyramid, h)

        elif progressive_input == 'residual':
          input_pyramid = pyramid_downsample(
              input_pyramid, out_ch=h.shape[-1])
          if skip_rescale:
            input_pyramid = (input_pyramid + h) / np.sqrt(2.)
          else:
            input_pyramid = input_pyramid + h
          h = input_pyramid

        hs.append(h)

    # Bottleneck at the lowest resolution: ResNet block, attention, ResNet block.
    h = hs[-1]
    h = ResnetBlock(h)
    h = AttnBlock(h)
    h = ResnetBlock(h)

    pyramid = None

    # Upsampling block
    for i_level in reversed(range(num_resolutions)):
      for i_block in range(num_res_blocks + 1):
        h = ResnetBlock(
            jnp.concatenate([h, hs.pop()], axis=-1),
            out_ch=nf * ch_mult[i_level])

      if h.shape[1] in attn_resolutions:
        h = AttnBlock(h)

      if progressive != 'none':
        if i_level == num_resolutions - 1:
          if progressive == 'output_skip':
            pyramid = conv3x3(
                act(normalize(h, num_groups=min(h.shape[-1] // 4, 32))),
                x.shape[-1],
                bias=True,
                init_scale=init_scale)
          elif progressive == 'residual':
            pyramid = conv3x3(
                act(normalize(h, num_groups=min(h.shape[-1] // 4, 32))),
                h.shape[-1],
                bias=True)
          else:
            raise ValueError(f'{progressive} is not a valid name.')
        else:
          if progressive == 'output_skip':
            pyramid = pyramid_upsample(pyramid)
            pyramid = pyramid + conv3x3(
                act(normalize(h, num_groups=min(h.shape[-1] // 4, 32))),
                x.shape[-1],
                bias=True,
                init_scale=init_scale)
          elif progressive == 'residual':
            pyramid = pyramid_upsample(pyramid, out_ch=h.shape[-1])
            if skip_rescale:
              pyramid = (pyramid + h) / np.sqrt(2.)
            else:
              pyramid = pyramid + h
            h = pyramid
          else:
            raise ValueError(f'{progressive} is not a valid name.')

      if i_level != 0:
        if resblock_type == 'ddpm':
          h = Upsample(h)
        else:
          h = ResnetBlock(h, up=True)

    assert not hs  # every skip connection stored in hs has been consumed

    if progressive == 'output_skip':
      h = pyramid
    else:
      h = act(normalize(h, num_groups=min(h.shape[-1] // 4, 32)))
      h = conv3x3(h, x.shape[-1], init_scale=init_scale)

    if config.model.scale_by_sigma:
      used_sigmas = sigmas[labels].reshape((x.shape[0],
                                            *([1] * len(x.shape[1:]))))
      h = h / used_sigmas

    return h
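Like the BERT snippets, this apply() is driven entirely by its config. The sketch below lists the config.model and config.data fields it reads directly; ml_collections.ConfigDict, the helper name and all values are illustrative assumptions (loosely in the style of an NCSN++ setup), not copied from the original configuration files. get_act, get_normalization and utils.get_sigmas read additional fields that are not shown.

import ml_collections


def small_score_model_config():
    """Hypothetical config listing the fields read directly by the apply() above."""
    config = ml_collections.ConfigDict()
    config.model = ml_collections.ConfigDict()
    config.data = ml_collections.ConfigDict()

    config.model.nf = 128                        # base channel count
    config.model.ch_mult = (1, 2, 2, 2)          # channel multiplier per resolution
    config.model.num_res_blocks = 4
    config.model.attn_resolutions = (16,)        # feature-map sizes that get attention
    config.model.dropout = 0.1
    config.model.resamp_with_conv = True
    config.model.conditional = True              # noise-conditional model
    config.model.fir = True
    config.model.fir_kernel = (1, 3, 3, 1)
    config.model.skip_rescale = True
    config.model.resblock_type = 'biggan'        # 'ddpm' or 'biggan'
    config.model.progressive = 'none'            # 'none' | 'output_skip' | 'residual'
    config.model.progressive_input = 'residual'  # 'none' | 'input_skip' | 'residual'
    config.model.progressive_combine = 'sum'     # only used with 'input_skip'
    config.model.embedding_type = 'gaussian'     # 'gaussian' or 'positional'
    config.model.fourier_scale = 16              # only used with 'gaussian' embeddings
    config.model.init_scale = 0.
    config.model.scale_by_sigma = True

    config.data.centered = False                 # inputs in [0, 1]; rescaled to [-1, 1]
    config.data.num_classes = 10                 # only used when y is passed
    return config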