def setup(self): self.layers = [nn.Dense(self.ch) for _ in range(self.n_layers)]
def setup(self):
    self.predictions = FlaxBertLMPredictionHead(self.config, dtype=self.dtype)
    self.seq_relationship = nn.Dense(2, dtype=self.dtype)
def __call__(self, inputs, train, inputs_positions=None, inputs_segmentation=None):
  """Applies Transformer model on the inputs.

  Args:
    inputs: input data
    train: bool: if model is training.
    inputs_positions: input subsequence positions for packed examples.
    inputs_segmentation: input segmentation info for packed examples.

  Returns:
    output of a transformer decoder.
  """
  assert inputs.ndim == 2  # (batch, len)
  dtype = utils.dtype_from_str(self.model_dtype)

  if self.decode:
    # for fast autoregressive decoding we use no decoder mask
    decoder_mask = None
  else:
    decoder_mask = nn.combine_masks(
        nn.make_attention_mask(inputs > 0, inputs > 0, dtype=dtype),
        nn.make_causal_mask(inputs, dtype=dtype))

  if inputs_segmentation is not None:
    decoder_mask = nn.combine_masks(
        decoder_mask,
        nn.make_attention_mask(inputs_segmentation, inputs_segmentation,
                               jnp.equal, dtype=dtype))

  y = inputs.astype('int32')
  if not self.decode:
    y = shift_inputs(y, segment_ids=inputs_segmentation)

  # TODO(gdahl,znado): this code appears to be accessing out-of-bounds
  # indices for dataset_lib:proteins_test. This will break when jnp.take() is
  # updated to return NaNs for out-of-bounds indices.
  # Debug why this is the case.
  y = jnp.clip(y, 0, self.vocab_size - 1)

  if self.shared_embedding is None:
    output_embed = nn.Embed(
        num_embeddings=self.vocab_size,
        features=self.emb_dim,
        embedding_init=nn.initializers.normal(stddev=1.0))
  else:
    output_embed = self.shared_embedding

  y = output_embed(y)
  y = AddPositionEmbs(
      max_len=self.max_len,
      posemb_init=sinusoidal_init(max_len=self.max_len),
      decode=self.decode,
      name='posembed_output')(
          y, inputs_positions=inputs_positions)
  y = nn.Dropout(rate=self.dropout_rate)(y, deterministic=not train)

  y = y.astype(dtype)

  for _ in range(self.num_layers):
    y = Transformer1DBlock(
        qkv_dim=self.qkv_dim,
        mlp_dim=self.mlp_dim,
        num_heads=self.num_heads,
        dropout_rate=self.dropout_rate,
        attention_dropout_rate=self.attention_dropout_rate,
        attention_fn=self.attention_fn,
        normalizer=self.normalizer,
        dtype=dtype)(
            inputs=y,
            train=train,
            decoder_mask=decoder_mask,
            encoder_decoder_mask=None,
            inputs_positions=None,
            inputs_segmentation=None,
        )
  if self.normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm']:
    maybe_normalize = model_utils.get_normalizer(self.normalizer, train,
                                                 dtype=dtype)
    y = maybe_normalize()(y)

  if self.logits_via_embedding:
    # Use the transpose of embedding matrix for logit transform.
    logits = output_embed.attend(y.astype(jnp.float32))
    # Correctly normalize pre-softmax logits for this shared case.
    logits = logits / jnp.sqrt(y.shape[-1])
  else:
    logits = nn.Dense(
        self.vocab_size,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6),
        dtype=dtype,
        name='logits_dense')(y)
  return logits.astype(dtype)
def setup(self):
    self.dense = nn.Dense(
        self.config.hidden_size,
        kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
        dtype=self.dtype,
    )
def setup(self):
    self.transform = FlaxBertPredictionHeadTransform(self.config, dtype=self.dtype)
    self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype, use_bias=False)
    self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,))
def setup(self):
    self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype)
    self.activation = ACT2FN[self.config.hidden_act]
    self.LayerNorm = FlaxBertLayerNorm(hidden_size=self.config.hidden_size, dtype=self.dtype)
def embed(inputs): return nn.Dense(latent_size)(inputs)
def setup(self):
    self.roformer = FlaxRoFormerModule(config=self.config, dtype=self.dtype)
    self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)
def setup(self):
    self.roformer = FlaxRoFormerModule(config=self.config, dtype=self.dtype)
    self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
    self.classifier = nn.Dense(1, dtype=self.dtype)
def __call__(self, x, t, y, *, train):
  assert x.dtype == jnp.int32
  x_onehot = jax.nn.one_hot(x, num_classes=self.num_pixel_vals)
  # Convert to float and scale image to [-1, 1]
  x = utils.normalize_data(x.astype(jnp.float32))

  batch_size, height, width, _ = x.shape
  assert height == width
  assert x.dtype in (jnp.float32, jnp.float64)
  assert t.shape == (batch_size,)  # and t.dtype == jnp.int32
  num_resolutions = len(self.ch_mult)
  ch = self.ch

  # Class embedding
  assert self.num_classes >= 1
  if self.num_classes > 1:
    logging.info('conditional: num_classes=%d', self.num_classes)
    assert y.shape == (batch_size,) and y.dtype == jnp.int32
    y = jax.nn.one_hot(y, num_classes=self.num_classes, dtype=x.dtype)
    y = nn.Dense(features=ch * 4, name='class_emb')(y)
    assert y.shape == (batch_size, ch * 4)
  else:
    logging.info('unconditional: num_classes=%d', self.num_classes)
    y = None

  # Timestep embedding
  logging.info('model max_time: %f', self.max_time)
  temb = get_timestep_embedding(t, ch, max_time=self.max_time)
  temb = nn.Dense(features=ch * 4, name='dense0')(temb)
  temb = nn.Dense(features=ch * 4, name='dense1')(nonlinearity(temb))
  assert temb.shape == (batch_size, ch * 4)

  # Downsampling
  hs = [nn.Conv(
      features=ch, kernel_size=(3, 3), strides=(1, 1), name='conv_in')(x)]
  for i_level in range(num_resolutions):
    # Residual blocks for this resolution
    for i_block in range(self.num_res_blocks):
      h = ResnetBlock(
          out_ch=ch * self.ch_mult[i_level],
          dropout=self.dropout,
          name=f'down_{i_level}.block_{i_block}')(
              hs[-1], temb=temb, y=y, deterministic=not train)
      if h.shape[1] in self.attn_resolutions:
        h = AttnBlock(
            num_heads=self.num_heads,
            name=f'down_{i_level}.attn_{i_block}')(h)
      hs.append(h)
    # Downsample
    if i_level != num_resolutions - 1:
      hs.append(self._downsample(hs[-1], name=f'down_{i_level}.downsample'))

  # Middle
  h = hs[-1]
  h = ResnetBlock(dropout=self.dropout, name='mid.block_1')(
      h, temb=temb, y=y, deterministic=not train)
  h = AttnBlock(num_heads=self.num_heads, name='mid.attn_1')(h)
  h = ResnetBlock(dropout=self.dropout, name='mid.block_2')(
      h, temb=temb, y=y, deterministic=not train)

  # Upsampling
  for i_level in reversed(range(num_resolutions)):
    # Residual blocks for this resolution
    for i_block in range(self.num_res_blocks + 1):
      h = ResnetBlock(
          out_ch=ch * self.ch_mult[i_level],
          dropout=self.dropout,
          name=f'up_{i_level}.block_{i_block}')(
              jnp.concatenate([h, hs.pop()], axis=-1),
              temb=temb, y=y, deterministic=not train)
      if h.shape[1] in self.attn_resolutions:
        h = AttnBlock(
            num_heads=self.num_heads,
            name=f'up_{i_level}.attn_{i_block}')(h)
    # Upsample
    if i_level != 0:
      h = self._upsample(h, name=f'up_{i_level}.upsample')
  assert not hs

  # End.
  h = nonlinearity(Normalize(name='norm_out')(h))

  if self.model_output == 'logistic_pars':
    # The output represents logits or the log scale and loc of a
    # logistic distribution.
    h = nn.Conv(
        features=self.out_ch * 2,
        kernel_size=(3, 3),
        strides=(1, 1),
        kernel_init=nn.initializers.zeros,
        name='conv_out')(h)
    loc, log_scale = jnp.split(h, 2, axis=-1)
    # ensure loc is between [-1, 1], just like normalized data.
    loc = jnp.tanh(loc + x)
    return loc, log_scale
  elif self.model_output == 'logits':
    h = nn.Conv(
        features=self.out_ch * self.num_pixel_vals,
        kernel_size=(3, 3),
        strides=(1, 1),
        kernel_init=nn.initializers.zeros,
        name='conv_out')(h)
    h = jnp.reshape(h, (*x.shape[:3], self.out_ch, self.num_pixel_vals))
    return x_onehot + h
  else:
    raise ValueError(
        f'self.model_output = {self.model_output} but must be '
        'logits or logistic_pars')
def __call__(self, x, logsnr, y, *, train):
  B, H, W, _ = x.shape  # pylint: disable=invalid-name
  assert H == W
  assert x.dtype in (jnp.float32, jnp.float64)
  assert logsnr.shape == (B,) and logsnr.dtype in (jnp.float32, jnp.float64)
  num_resolutions = len(self.ch_mult)
  ch = self.ch
  emb_ch = self.emb_ch

  # Timestep embedding
  if self.logsnr_input_type == 'linear':
    logging.info('LogSNR representation: linear')
    logsnr_input = (logsnr - self.logsnr_scale_range[0]) / (
        self.logsnr_scale_range[1] - self.logsnr_scale_range[0])
  elif self.logsnr_input_type == 'sigmoid':
    logging.info('LogSNR representation: sigmoid')
    logsnr_input = nn.sigmoid(logsnr)
  elif self.logsnr_input_type == 'inv_cos':
    logging.info('LogSNR representation: inverse cosine')
    logsnr_input = (jnp.arctan(jnp.exp(-0.5 * jnp.clip(logsnr, -20., 20.)))
                    / (0.5 * jnp.pi))
  else:
    raise NotImplementedError(self.logsnr_input_type)

  emb = get_timestep_embedding(logsnr_input, embedding_dim=ch, max_time=1.)
  emb = nn.Dense(features=emb_ch, name='dense0')(emb)
  emb = nn.Dense(features=emb_ch, name='dense1')(nonlinearity(emb))
  assert emb.shape == (B, emb_ch)

  # Class embedding
  assert self.num_classes >= 1
  if self.num_classes > 1:
    logging.info('conditional: num_classes=%d', self.num_classes)
    assert y.shape == (B,) and y.dtype == jnp.int32
    y_emb = jax.nn.one_hot(y, num_classes=self.num_classes, dtype=x.dtype)
    y_emb = nn.Dense(features=emb_ch, name='class_emb')(y_emb)
    assert y_emb.shape == emb.shape == (B, emb_ch)
    emb += y_emb
  else:
    logging.info('unconditional: num_classes=%d', self.num_classes)
    del y

  # Downsampling
  hs = [nn.Conv(
      features=ch, kernel_size=(3, 3), strides=(1, 1), name='conv_in')(x)]
  for i_level in range(num_resolutions):
    # Residual blocks for this resolution
    for i_block in range(self.num_res_blocks):
      h = ResnetBlock(
          out_ch=ch * self.ch_mult[i_level],
          dropout=self.dropout,
          name=f'down_{i_level}.block_{i_block}')(
              hs[-1], emb=emb, deterministic=not train)
      if h.shape[1] in self.attn_resolutions:
        h = AttnBlock(
            num_heads=self.num_heads,
            head_dim=self.head_dim,
            name=f'down_{i_level}.attn_{i_block}')(h)
      hs.append(h)
    # Downsample
    if i_level != num_resolutions - 1:
      hs.append(self._downsample(
          hs[-1], name=f'down_{i_level}.downsample', emb=emb, train=train))

  # Middle
  h = hs[-1]
  h = ResnetBlock(dropout=self.dropout, name='mid.block_1')(
      h, emb=emb, deterministic=not train)
  h = AttnBlock(
      num_heads=self.num_heads, head_dim=self.head_dim, name='mid.attn_1')(h)
  h = ResnetBlock(dropout=self.dropout, name='mid.block_2')(
      h, emb=emb, deterministic=not train)

  # Upsampling
  for i_level in reversed(range(num_resolutions)):
    # Residual blocks for this resolution
    for i_block in range(self.num_res_blocks + 1):
      h = ResnetBlock(
          out_ch=ch * self.ch_mult[i_level],
          dropout=self.dropout,
          name=f'up_{i_level}.block_{i_block}')(
              jnp.concatenate([h, hs.pop()], axis=-1),
              emb=emb, deterministic=not train)
      if h.shape[1] in self.attn_resolutions:
        h = AttnBlock(
            num_heads=self.num_heads,
            head_dim=self.head_dim,
            name=f'up_{i_level}.attn_{i_block}')(h)
    # Upsample
    if i_level != 0:
      h = self._upsample(
          h, name=f'up_{i_level}.upsample', emb=emb, train=train)
  assert not hs

  # End
  h = nonlinearity(Normalize(name='norm_out')(h))
  h = nn.Conv(
      features=self.out_ch,
      kernel_size=(3, 3),
      strides=(1, 1),
      kernel_init=nn.initializers.zeros,
      name='conv_out')(h)
  assert h.shape == (*x.shape[:3], self.out_ch)
  return h
def __call__(self, x):
  x = nn.Dense(features=alpha * x.shape[-1], dtype=jax.numpy.float32)(x)
  x = jnp.log(jnp.cosh(x))
  return jnp.sum(x, axis=-1)
def __call__(self, x):
  counter = self.variable('counter', 'i', jnp.zeros, ())
  counter.value += 1
  x = nn.Dense(1)(x)
  return x
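# A minimal usage sketch for a __call__ like the one above (assumption: it lives
# in a compact Flax module; the wrapper class name `CounterDense` is hypothetical).
# It shows that updating a non-'params' collection requires passing `mutable=`
# to apply().
import jax
import jax.numpy as jnp
import flax.linen as nn

class CounterDense(nn.Module):
  @nn.compact
  def __call__(self, x):
    counter = self.variable('counter', 'i', jnp.zeros, ())
    counter.value += 1
    return nn.Dense(1)(x)

variables = CounterDense().init(jax.random.PRNGKey(0), jnp.ones((4, 3)))
y, updated_state = CounterDense().apply(variables, jnp.ones((4, 3)), mutable=['counter'])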
def setup(self): self.dense_out = nn.Dense(self.n_out)
def setup(self): self.bar = nn.Dense(3)
def __call__(self, x): return nn.Dense(1)(x)
def setup(self):
    self.bert = FlaxBertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
    self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
    self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)
def __call__(self, x): return nn.Dense(1, parent=self)(x)
def setup(self):
    self.transform = FlaxBertPredictionHeadTransform(self.config, dtype=self.dtype)
    self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype)
def __call__(self, x):
  for width in self.widths[:-1]:
    x = nn.relu(nn.Dense(width)(x))
  return nn.Dense(self.widths[-1])(x)
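# Hedged sketch of a full module around the __call__ above (the class name `MLP`
# and the concrete widths are assumptions, not taken from the snippet).
import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Sequence

class MLP(nn.Module):
  widths: Sequence[int]

  @nn.compact
  def __call__(self, x):
    for width in self.widths[:-1]:
      x = nn.relu(nn.Dense(width)(x))
    return nn.Dense(self.widths[-1])(x)

params = MLP(widths=[64, 64, 1]).init(jax.random.PRNGKey(0), jnp.ones((2, 8)))
out = MLP(widths=[64, 64, 1]).apply(params, jnp.ones((2, 8)))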
def __call__(self,
             encoded,
             targets,
             targets_positions=None,
             decoder_mask=None,
             encoder_decoder_mask=None):
  """Applies Transformer model on the inputs.

  Args:
    encoded: encoded input data from encoder.
    targets: target inputs.
    targets_positions: input subsequence positions for packed examples.
    decoder_mask: decoder self-attention mask.
    encoder_decoder_mask: encoder-decoder attention mask.

  Returns:
    output of a transformer decoder.
  """
  cfg = self.config

  assert encoded.ndim == 3  # (batch, len, depth)
  assert targets.ndim == 2  # (batch, len)

  # Target Embedding
  if self.shared_embedding is None:
    output_embed = nn.Embed(
        num_embeddings=cfg.output_vocab_size,
        features=cfg.emb_dim,
        embedding_init=nn.initializers.normal(stddev=1.0))
  else:
    output_embed = self.shared_embedding

  y = targets.astype('int32')
  if not cfg.decode:
    y = shift_right(y)
  y = output_embed(y)
  y = AddPositionEmbs(
      config=cfg, decode=cfg.decode, name='posembed_output')(
          y, inputs_positions=targets_positions)
  y = nn.Dropout(rate=cfg.dropout_rate)(y, deterministic=cfg.deterministic)

  y = y.astype(cfg.dtype)

  # Target-Input Decoder
  for lyr in range(cfg.num_layers):
    y = EncoderDecoder1DBlock(
        config=cfg, name=f'encoderdecoderblock_{lyr}')(
            y,
            encoded,
            decoder_mask=decoder_mask,
            encoder_decoder_mask=encoder_decoder_mask)
  y = nn.LayerNorm(dtype=cfg.dtype, name='encoderdecoder_norm')(y)

  # Decoded Logits
  if cfg.logits_via_embedding:
    # Use the transpose of embedding matrix for logit transform.
    logits = output_embed.attend(y.astype(jnp.float32))
    # Correctly normalize pre-softmax logits for this shared case.
    logits = logits / jnp.sqrt(y.shape[-1])
  else:
    logits = nn.Dense(
        cfg.output_vocab_size,
        dtype=cfg.dtype,
        kernel_init=cfg.kernel_init,
        bias_init=cfg.bias_init,
        name='logitdense')(y)
  return logits
def setup(self): self.layers = [nn.Dense(10), nn.relu, nn.Dense(10)]
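# A possible companion __call__ for a setup() like the one above (the class name
# `SequentialMLP` and the loop are assumptions): Flax registers the Dense entries
# of the list as submodules, while nn.relu is applied as a plain function.
import jax
import jax.numpy as jnp
import flax.linen as nn

class SequentialMLP(nn.Module):
  def setup(self):
    self.layers = [nn.Dense(10), nn.relu, nn.Dense(10)]

  def __call__(self, x):
    for layer in self.layers:
      x = layer(x)
    return x

params = SequentialMLP().init(jax.random.PRNGKey(0), jnp.ones((1, 4)))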
def setup(self):
    self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype)
    self.activation = ACT2FN[self.config.hidden_act]
    self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
def test_module_is_hashable(self):
  module_a = nn.Dense(10)
  module_a_2 = nn.Dense(10)
  module_b = nn.Dense(5)
  self.assertEqual(hash(module_a), hash(module_a_2))
  self.assertNotEqual(hash(module_a), hash(module_b))
def setup(self): self.seq_relationship = nn.Dense(2, dtype=self.dtype)
def test_module_with_scope_is_not_hashable(self):
  module_a = nn.Dense(10, parent=Scope({}))
  with self.assertRaisesWithLiteralMatch(
      ValueError, 'Can\'t call __hash__ on modules that hold variables.'):
    hash(module_a)
def setup(self):
    self.bert = FlaxBertModule(config=self.config, dtype=self.dtype)
    self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
    self.classifier = nn.Dense(1, dtype=self.dtype)
def __call__(self, x):
  for size in self.sizes:
    x = nn.Dense(size)(x)
    x = self.act(x)
  return repr(self)
def setup(self):
    self.albert = FlaxAlbertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False)
    self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)
def setup(self): self.b = B(nn.Dense(2))