Example #1
 def setup(self):
   self.layers = [nn.Dense(self.ch) for _ in range(self.n_layers)]
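For context, a minimal runnable sketch of how a setup-defined list of Dense layers like the one above might be wired into a full module; the module name, field defaults, __call__ body, and the init call are assumptions added here, not part of the original snippet.

import jax
import jax.numpy as jnp
import flax.linen as nn

class StackedDense(nn.Module):  # hypothetical wrapper around the snippet above
  ch: int = 32
  n_layers: int = 3

  def setup(self):
    # Submodules assigned in setup() are bound to this module instance.
    self.layers = [nn.Dense(self.ch) for _ in range(self.n_layers)]

  def __call__(self, x):
    for layer in self.layers:
      x = nn.relu(layer(x))
    return x

# Example usage: initialize parameters for a batch of 8-dimensional inputs.
params = StackedDense().init(jax.random.PRNGKey(0), jnp.ones((1, 8)))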
 def setup(self):
     # Pre-training heads: an MLM prediction head plus a 2-way next-sentence classifier.
     self.predictions = FlaxBertLMPredictionHead(self.config, dtype=self.dtype)
     self.seq_relationship = nn.Dense(2, dtype=self.dtype)
Example #3
    def __call__(self,
                 inputs,
                 train,
                 inputs_positions=None,
                 inputs_segmentation=None):
        """Applies Transformer model on the inputs.

    Args:
      inputs: input data
      train: bool: if model is training.
      inputs_positions: input subsequence positions for packed examples.
      inputs_segmentation: input segmentation info for packed examples.

    Returns:
      output of a transformer decoder.
    """
        assert inputs.ndim == 2  # (batch, len)
        dtype = utils.dtype_from_str(self.model_dtype)

        if self.decode:
            # for fast autoregressive decoding we use no decoder mask
            decoder_mask = None
        else:
            decoder_mask = nn.combine_masks(
                nn.make_attention_mask(inputs > 0, inputs > 0, dtype=dtype),
                nn.make_causal_mask(inputs, dtype=dtype))

        if inputs_segmentation is not None:
            decoder_mask = nn.combine_masks(
                decoder_mask,
                nn.make_attention_mask(inputs_segmentation,
                                       inputs_segmentation,
                                       jnp.equal,
                                       dtype=dtype))

        y = inputs.astype('int32')
        if not self.decode:
            y = shift_inputs(y, segment_ids=inputs_segmentation)

        # TODO(gdahl,znado): this code appears to be accessing out-of-bounds
        # indices for dataset_lib:proteins_test. This will break when jnp.take() is
        # updated to return NaNs for out-of-bounds indices.
        # Debug why this is the case.
        y = jnp.clip(y, 0, self.vocab_size - 1)

        if self.shared_embedding is None:
            output_embed = nn.Embed(
                num_embeddings=self.vocab_size,
                features=self.emb_dim,
                embedding_init=nn.initializers.normal(stddev=1.0))
        else:
            output_embed = self.shared_embedding

        y = output_embed(y)

        y = AddPositionEmbs(max_len=self.max_len,
                            posemb_init=sinusoidal_init(max_len=self.max_len),
                            decode=self.decode,
                            name='posembed_output')(
                                y, inputs_positions=inputs_positions)
        y = nn.Dropout(rate=self.dropout_rate)(y, deterministic=not train)

        y = y.astype(dtype)

        for _ in range(self.num_layers):
            y = Transformer1DBlock(
                qkv_dim=self.qkv_dim,
                mlp_dim=self.mlp_dim,
                num_heads=self.num_heads,
                dropout_rate=self.dropout_rate,
                attention_dropout_rate=self.attention_dropout_rate,
                attention_fn=self.attention_fn,
                normalizer=self.normalizer,
                dtype=dtype)(
                    inputs=y,
                    train=train,
                    decoder_mask=decoder_mask,
                    encoder_decoder_mask=None,
                    inputs_positions=None,
                    inputs_segmentation=None,
                )
        if self.normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm']:
            maybe_normalize = model_utils.get_normalizer(self.normalizer,
                                                         train,
                                                         dtype=dtype)
            y = maybe_normalize()(y)

        if self.logits_via_embedding:
            # Use the transpose of embedding matrix for logit transform.
            logits = output_embed.attend(y.astype(jnp.float32))
            # Correctly normalize pre-softmax logits for this shared case.
            logits = logits / jnp.sqrt(y.shape[-1])
        else:
            logits = nn.Dense(self.vocab_size,
                              kernel_init=nn.initializers.xavier_uniform(),
                              bias_init=nn.initializers.normal(stddev=1e-6),
                              dtype=dtype,
                              name='logits_dense')(y)

        return logits.astype(dtype)
 def setup(self):
     self.dense = nn.Dense(
         self.config.hidden_size,
         kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
         dtype=self.dtype,
     )
 def setup(self):
     # MLM head: hidden-state transform, a vocab-size projection without a built-in bias,
     # and a separately declared bias parameter.
     self.transform = FlaxBertPredictionHeadTransform(self.config, dtype=self.dtype)
     self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype, use_bias=False)
     self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,))
 def setup(self):
     self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype)
     self.activation = ACT2FN[self.config.hidden_act]
     self.LayerNorm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size,
                                        dtype=self.dtype)
Example #7
 def embed(inputs):
     return nn.Dense(latent_size)(inputs)
Example #8
 def setup(self):
     self.roformer = FlaxRoFormerModule(config=self.config,
                                        dtype=self.dtype)
     self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)
Example #9
 def setup(self):
     self.roformer = FlaxRoFormerModule(config=self.config,
                                        dtype=self.dtype)
     self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
     self.classifier = nn.Dense(1, dtype=self.dtype)
Example #10
  def __call__(self, x, t, y, *, train):

    assert x.dtype == jnp.int32

    x_onehot = jax.nn.one_hot(x, num_classes=self.num_pixel_vals)
    # Convert to float and scale image to [-1, 1]
    x = utils.normalize_data(x.astype(jnp.float32))

    batch_size, height, width, _ = x.shape
    assert height == width
    assert x.dtype in (jnp.float32, jnp.float64)
    assert t.shape == (batch_size,)  # and t.dtype == jnp.int32
    num_resolutions = len(self.ch_mult)
    ch = self.ch

    # Class embedding
    assert self.num_classes >= 1
    if self.num_classes > 1:
      logging.info('conditional: num_classes=%d', self.num_classes)
      assert y.shape == (batch_size,) and y.dtype == jnp.int32
      y = jax.nn.one_hot(y, num_classes=self.num_classes, dtype=x.dtype)
      y = nn.Dense(features=ch * 4, name='class_emb')(y)
      assert y.shape == (batch_size, ch * 4)
    else:
      logging.info('unconditional: num_classes=%d', self.num_classes)
      y = None

    # Timestep embedding
    logging.info('model max_time: %f', self.max_time)
    temb = get_timestep_embedding(t, ch, max_time=self.max_time)
    temb = nn.Dense(features=ch * 4, name='dense0')(temb)
    temb = nn.Dense(features=ch * 4, name='dense1')(nonlinearity(temb))
    assert temb.shape == (batch_size, ch * 4)

    # Downsampling
    hs = [nn.Conv(
        features=ch, kernel_size=(3, 3), strides=(1, 1), name='conv_in')(x)]
    for i_level in range(num_resolutions):
      # Residual blocks for this resolution
      for i_block in range(self.num_res_blocks):
        h = ResnetBlock(
            out_ch=ch * self.ch_mult[i_level],
            dropout=self.dropout,
            name=f'down_{i_level}.block_{i_block}')(
                hs[-1], temb=temb, y=y, deterministic=not train)
        if h.shape[1] in self.attn_resolutions:
          h = AttnBlock(
              num_heads=self.num_heads,
              name=f'down_{i_level}.attn_{i_block}')(h)
        hs.append(h)
      # Downsample
      if i_level != num_resolutions - 1:
        hs.append(self._downsample(hs[-1], name=f'down_{i_level}.downsample'))

    # Middle
    h = hs[-1]
    h = ResnetBlock(dropout=self.dropout, name='mid.block_1')(
        h, temb=temb, y=y, deterministic=not train)
    h = AttnBlock(num_heads=self.num_heads, name='mid.attn_1')(h)
    h = ResnetBlock(dropout=self.dropout, name='mid.block_2')(
        h, temb=temb, y=y, deterministic=not train)

    # Upsampling
    for i_level in reversed(range(num_resolutions)):
      # Residual blocks for this resolution
      for i_block in range(self.num_res_blocks + 1):
        h = ResnetBlock(
            out_ch=ch * self.ch_mult[i_level],
            dropout=self.dropout,
            name=f'up_{i_level}.block_{i_block}')(
                jnp.concatenate([h, hs.pop()], axis=-1),
                temb=temb, y=y, deterministic=not train)
        if h.shape[1] in self.attn_resolutions:
          h = AttnBlock(
              num_heads=self.num_heads,
              name=f'up_{i_level}.attn_{i_block}')(h)
      # Upsample
      if i_level != 0:
        h = self._upsample(h, name=f'up_{i_level}.upsample')
    assert not hs

    # End.
    h = nonlinearity(Normalize(name='norm_out')(h))

    if self.model_output == 'logistic_pars':
      # The output represents logits or the log scale and loc of a
      # logistic distribution.
      h = nn.Conv(
          features=self.out_ch * 2,
          kernel_size=(3, 3),
          strides=(1, 1),
          kernel_init=nn.initializers.zeros,
          name='conv_out')(
              h)
      loc, log_scale = jnp.split(h, 2, axis=-1)

      # ensure loc is between [-1, 1], just like normalized data.
      loc = jnp.tanh(loc + x)
      return loc, log_scale

    elif self.model_output == 'logits':
      h = nn.Conv(
          features=self.out_ch * self.num_pixel_vals,
          kernel_size=(3, 3),
          strides=(1, 1),
          kernel_init=nn.initializers.zeros,
          name='conv_out')(
              h)
      h = jnp.reshape(h, (*x.shape[:3], self.out_ch, self.num_pixel_vals))
      return x_onehot + h

    else:
      raise ValueError(
          f'self.model_output = {self.model_output} but must be '
          'logits or logistic_pars')
Example #11
  def __call__(self, x, logsnr, y, *, train):
    B, H, W, _ = x.shape  # pylint: disable=invalid-name
    assert H == W
    assert x.dtype in (jnp.float32, jnp.float64)
    assert logsnr.shape == (B,) and logsnr.dtype in (jnp.float32, jnp.float64)
    num_resolutions = len(self.ch_mult)
    ch = self.ch
    emb_ch = self.emb_ch

    # Timestep embedding
    if self.logsnr_input_type == 'linear':
      logging.info('LogSNR representation: linear')
      logsnr_input = (logsnr - self.logsnr_scale_range[0]) / (
          self.logsnr_scale_range[1] - self.logsnr_scale_range[0])
    elif self.logsnr_input_type == 'sigmoid':
      logging.info('LogSNR representation: sigmoid')
      logsnr_input = nn.sigmoid(logsnr)
    elif self.logsnr_input_type == 'inv_cos':
      logging.info('LogSNR representation: inverse cosine')
      logsnr_input = (jnp.arctan(jnp.exp(-0.5 * jnp.clip(logsnr, -20., 20.)))
                      / (0.5 * jnp.pi))
    else:
      raise NotImplementedError(self.logsnr_input_type)

    emb = get_timestep_embedding(logsnr_input, embedding_dim=ch, max_time=1.)
    emb = nn.Dense(features=emb_ch, name='dense0')(emb)
    emb = nn.Dense(features=emb_ch, name='dense1')(nonlinearity(emb))
    assert emb.shape == (B, emb_ch)

    # Class embedding
    assert self.num_classes >= 1
    if self.num_classes > 1:
      logging.info('conditional: num_classes=%d', self.num_classes)
      assert y.shape == (B,) and y.dtype == jnp.int32
      y_emb = jax.nn.one_hot(y, num_classes=self.num_classes, dtype=x.dtype)
      y_emb = nn.Dense(features=emb_ch, name='class_emb')(y_emb)
      assert y_emb.shape == emb.shape == (B, emb_ch)
      emb += y_emb
    else:
      logging.info('unconditional: num_classes=%d', self.num_classes)
    del y

    # Downsampling
    hs = [nn.Conv(
        features=ch, kernel_size=(3, 3), strides=(1, 1), name='conv_in')(x)]
    for i_level in range(num_resolutions):
      # Residual blocks for this resolution
      for i_block in range(self.num_res_blocks):
        h = ResnetBlock(
            out_ch=ch * self.ch_mult[i_level],
            dropout=self.dropout,
            name=f'down_{i_level}.block_{i_block}')(
                hs[-1], emb=emb, deterministic=not train)
        if h.shape[1] in self.attn_resolutions:
          h = AttnBlock(
              num_heads=self.num_heads,
              head_dim=self.head_dim,
              name=f'down_{i_level}.attn_{i_block}')(h)
        hs.append(h)
      # Downsample
      if i_level != num_resolutions - 1:
        hs.append(self._downsample(
            hs[-1], name=f'down_{i_level}.downsample', emb=emb, train=train))

    # Middle
    h = hs[-1]
    h = ResnetBlock(dropout=self.dropout, name='mid.block_1')(
        h, emb=emb, deterministic=not train)
    h = AttnBlock(
        num_heads=self.num_heads, head_dim=self.head_dim, name='mid.attn_1')(h)
    h = ResnetBlock(dropout=self.dropout, name='mid.block_2')(
        h, emb=emb, deterministic=not train)

    # Upsampling
    for i_level in reversed(range(num_resolutions)):
      # Residual blocks for this resolution
      for i_block in range(self.num_res_blocks + 1):
        h = ResnetBlock(
            out_ch=ch * self.ch_mult[i_level],
            dropout=self.dropout,
            name=f'up_{i_level}.block_{i_block}')(
                jnp.concatenate([h, hs.pop()], axis=-1),
                emb=emb, deterministic=not train)
        if h.shape[1] in self.attn_resolutions:
          h = AttnBlock(
              num_heads=self.num_heads,
              head_dim=self.head_dim,
              name=f'up_{i_level}.attn_{i_block}')(h)
      # Upsample
      if i_level != 0:
        h = self._upsample(
            h, name=f'up_{i_level}.upsample', emb=emb, train=train)
    assert not hs

    # End
    h = nonlinearity(Normalize(name='norm_out')(h))
    h = nn.Conv(
        features=self.out_ch,
        kernel_size=(3, 3),
        strides=(1, 1),
        kernel_init=nn.initializers.zeros,
        name='conv_out')(h)
    assert h.shape == (*x.shape[:3], self.out_ch)
    return h
Example #12
 def __call__(self, x):
     x = nn.Dense(features=alpha * x.shape[-1], dtype=jax.numpy.float32)(x)
     x = jnp.log(jnp.cosh(x))
     return jnp.sum(x, axis=-1)
Example #13
 def __call__(self, x):
   counter = self.variable('counter', 'i', jnp.zeros, ())
   counter.value += 1
   x = nn.Dense(1)(x)
   return x
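A usage sketch for the counter pattern above (the module name, input shape, and RNG key below are assumed for illustration): the 'counter' collection has to be listed as mutable when calling apply, otherwise the in-place increment cannot be written back.

import jax
import jax.numpy as jnp
import flax.linen as nn

class CounterDense(nn.Module):  # hypothetical name for the snippet above
  @nn.compact
  def __call__(self, x):
    counter = self.variable('counter', 'i', jnp.zeros, ())
    counter.value += 1
    x = nn.Dense(1)(x)
    return x

model = CounterDense()
variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 4)))
# mutable=['counter'] makes apply return the updated counter state alongside the output.
y, updated_state = model.apply(variables, jnp.ones((1, 4)), mutable=['counter'])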
Example #14
 def setup(self):
   self.dense_out = nn.Dense(self.n_out)
Example #15
 def setup(self):
     self.bar = nn.Dense(3)
Example #16
 def __call__(self, x):
     return nn.Dense(1)(x)
 def setup(self):
     self.bert = FlaxBertModule(config=self.config,
                                dtype=self.dtype,
                                add_pooling_layer=False)
     self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
     self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)
Example #18
 def __call__(self, x):
     return nn.Dense(1, parent=self)(x)
 def setup(self):
     self.transform = FlaxBertPredictionHeadTransform(self.config,
                                                      dtype=self.dtype)
     self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype)
Example #20
 def __call__(self, x):
     for width in self.widths[:-1]:
         x = nn.relu(nn.Dense(width)(x))
     return nn.Dense(self.widths[-1])(x)
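For completeness, a hedged sketch of the surrounding module and an init call for the widths-based MLP above; the class name, the widths field, and the example shapes are assumptions.

import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Sequence

class MLP(nn.Module):  # hypothetical container for the __call__ shown above
  widths: Sequence[int]

  @nn.compact
  def __call__(self, x):
    # Hidden layers with ReLU, then a final linear projection.
    for width in self.widths[:-1]:
      x = nn.relu(nn.Dense(width)(x))
    return nn.Dense(self.widths[-1])(x)

params = MLP(widths=[64, 64, 10]).init(jax.random.PRNGKey(0), jnp.ones((1, 8)))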
Example #21
  def __call__(self,
               encoded,
               targets,
               targets_positions=None,
               decoder_mask=None,
               encoder_decoder_mask=None):
    """Applies Transformer model on the inputs.

    Args:
      encoded: encoded input data from encoder.
      targets: target inputs.
      targets_positions: input subsequence positions for packed examples.
      decoder_mask: decoder self-attention mask.
      encoder_decoder_mask: encoder-decoder attention mask.

    Returns:
      output of a transformer decoder.
    """
    cfg = self.config

    assert encoded.ndim == 3  # (batch, len, depth)
    assert targets.ndim == 2  # (batch, len)

    # Target Embedding
    if self.shared_embedding is None:
      output_embed = nn.Embed(
          num_embeddings=cfg.output_vocab_size,
          features=cfg.emb_dim,
          embedding_init=nn.initializers.normal(stddev=1.0))
    else:
      output_embed = self.shared_embedding

    y = targets.astype('int32')
    if not cfg.decode:
      y = shift_right(y)
    y = output_embed(y)
    y = AddPositionEmbs(config=cfg, decode=cfg.decode, name='posembed_output')(
        y, inputs_positions=targets_positions)
    y = nn.Dropout(rate=cfg.dropout_rate)(
        y, deterministic=cfg.deterministic)

    y = y.astype(cfg.dtype)

    # Target-Input Decoder
    for lyr in range(cfg.num_layers):
      y = EncoderDecoder1DBlock(
          config=cfg, name=f'encoderdecoderblock_{lyr}')(
              y,
              encoded,
              decoder_mask=decoder_mask,
              encoder_decoder_mask=encoder_decoder_mask)
    y = nn.LayerNorm(dtype=cfg.dtype, name='encoderdecoder_norm')(y)

    # Decoded Logits
    if cfg.logits_via_embedding:
      # Use the transpose of embedding matrix for logit transform.
      logits = output_embed.attend(y.astype(jnp.float32))
      # Correctly normalize pre-softmax logits for this shared case.
      logits = logits / jnp.sqrt(y.shape[-1])
    else:
      logits = nn.Dense(
          cfg.output_vocab_size,
          dtype=cfg.dtype,
          kernel_init=cfg.kernel_init,
          bias_init=cfg.bias_init,
          name='logitdense')(y)
    return logits
Example #22
 def setup(self):
     self.layers = [nn.Dense(10), nn.relu, nn.Dense(10)]
 def setup(self):
     self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype)
     self.activation = ACT2FN[self.config.hidden_act]
     self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
Example #24
 def test_module_is_hashable(self):
     module_a = nn.Dense(10)
     module_a_2 = nn.Dense(10)
     module_b = nn.Dense(5)
     self.assertEqual(hash(module_a), hash(module_a_2))
     self.assertNotEqual(hash(module_a), hash(module_b))
 def setup(self):
     self.seq_relationship = nn.Dense(2, dtype=self.dtype)
Example #26
 def test_module_with_scope_is_not_hashable(self):
     module_a = nn.Dense(10, parent=Scope({}))
     with self.assertRaisesWithLiteralMatch(
             ValueError,
             'Can\'t call __hash__ on modules that hold variables.'):
         hash(module_a)
 def setup(self):
     self.bert = FlaxBertModule(config=self.config, dtype=self.dtype)
     self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
     self.classifier = nn.Dense(1, dtype=self.dtype)
Example #28
 def __call__(self, x):
     for size in self.sizes:
         x = nn.Dense(size)(x)
         x = self.act(x)
     return repr(self)
Example #29
 def setup(self):
     self.albert = FlaxAlbertModule(config=self.config,
                                    dtype=self.dtype,
                                    add_pooling_layer=False)
     self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)
Example #30
 def setup(self):
   self.b = B(nn.Dense(2))