def __init__(self, *, dim, pre_norm=False, ff_dim=None, **kwargs):
    super().__init__()

    # Token-mixing sub-block: parameter-free Fourier transform with a residual
    # connection, wrapped in (pre- or post-) layer normalization.
    self.fourier_block = LayerNorm(
        Residual(FNetFourierTransform()),
        dim=dim,
        use_pre_norm=pre_norm,
    )
    # Position-wise feed-forward sub-block; expands to `ff_dim`, or 4 * dim by default.
    self.ff_block = LayerNorm(
        Residual(
            FeedForward(
                dim=dim,
                expand_dim=ff_dim if ff_dim is not None else 4 * dim,
                **kwargs,
            )
        ),
        dim=dim,
        use_pre_norm=pre_norm,
    )

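# NOTE: the forward pass is not part of this excerpt. The sketch below is an
# assumption about how the two wrapped sub-blocks are chained; the same
# two-sub-block pattern applies to the analogous attention + feed-forward
# blocks further down.
def forward(self, x):
    x = self.fourier_block(x)   # Fourier token mixing + residual (+ norm)
    return self.ff_block(x)     # position-wise feed-forward + residual (+ norm)
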
def __init__(self, dim, pre_norm=False, ff_dim=None, **kwargs):
    super().__init__()

    # Self-attention sub-block using gated positional self-attention (GPSA),
    # with a residual connection and (pre- or post-) layer normalization.
    self.attention_block = LayerNorm(
        Residual(GatedPositionalSelfAttention(dim, **kwargs)),
        dim=dim,
        use_pre_norm=pre_norm,
    )
    # Position-wise feed-forward sub-block; expands to `ff_dim`, or 4 * dim by default.
    self.ff_block = LayerNorm(
        Residual(
            FeedForward(
                dim=dim,
                expand_dim=ff_dim if ff_dim is not None else 4 * dim,
                **kwargs,
            )
        ),
        dim=dim,
        use_pre_norm=pre_norm,
    )

def __init__(
    self,
    *,
    cross_kv_dim,
    cross_heads,
    dim,
    heads,
    latent_transformer_depth=1,
    pre_norm=False,
    ff_expand_scale=4,
    **kwargs,
):
    super().__init__()

    if cross_heads > 1:
        warnings.warn(
            f"[{self.__class__.__name__}] `cross_heads` is set to {cross_heads}, "
            "but it's 1 in the original paper."
        )

    # Cross-attention sub-block: the latent array attends to the input features
    # of dimension `cross_kv_dim`, followed by a feed-forward sub-block.
    self.cross_attention_block = nn.ModuleList([
        LayerNorm(
            dim=dim,
            cross_dim=cross_kv_dim,
            use_pre_norm=pre_norm,
            use_cross_attention=True,
            fn=Residual(
                fn=MultiheadAttention(dim, kv_dim=cross_kv_dim, heads=cross_heads, **kwargs)
            ),
        ),
        LayerNorm(
            dim=dim,
            use_pre_norm=pre_norm,
            fn=Residual(fn=FeedForward(dim, expand_dim=ff_expand_scale * dim, **kwargs)),
        ),
    ])

    # Latent transformer: a stack of self-attention + feed-forward layers that
    # operates on the latent array only.
    self.latent_transformers = nn.ModuleList([
        nn.ModuleList([
            LayerNorm(
                fn=Residual(fn=MultiheadAttention(dim, heads=heads, **kwargs)),
                dim=dim,
                use_pre_norm=pre_norm,
            ),
            LayerNorm(
                fn=Residual(fn=FeedForward(dim, expand_dim=ff_expand_scale * dim, **kwargs)),
                dim=dim,
                use_pre_norm=pre_norm,
            ),
        ])
        for _ in range(latent_transformer_depth)
    ])

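# NOTE: a forward sketch only; how the two module lists are consumed is not
# shown in this excerpt, and the call signature of the cross-attention wrapper
# (taking the latents plus the byte-array features) is an assumption. The flow
# follows the usual Perceiver recipe: cross-attend, then run the latent stack.
def forward(self, latents, features):
    cross_attn, cross_ff = self.cross_attention_block
    latents = cross_attn(latents, features)   # latents query the input features
    latents = cross_ff(latents)
    for self_attn, ff in self.latent_transformers:
        latents = self_attn(latents)
        latents = ff(latents)
    return latents
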
def __init__(self, dim, heads=None, head_dim=None, pre_norm=False, ff_dim=None, **kwargs):
    super().__init__()

    # Multi-head self-attention sub-block with residual connection and
    # (pre- or post-) layer normalization.
    self.attention_block = LayerNorm(
        Residual(MultiheadAttention(dim, heads=heads, head_dim=head_dim, **kwargs)),
        dim=dim,
        use_pre_norm=pre_norm,
    )
    # Position-wise feed-forward sub-block; expands to `ff_dim`, or 4 * dim by default.
    self.ff_block = LayerNorm(
        Residual(
            FeedForward(dim, expand_dim=ff_dim if ff_dim is not None else 4 * dim, **kwargs)
        ),
        dim=dim,
        use_pre_norm=pre_norm,
    )

def __init__(self, dim, num_tokens, attention_dim=None, **kwargs):
    super().__init__()

    # The gating path operates on half of the channels (`dim // 2`).
    self.norm = nn.LayerNorm(dim // 2)

    # Token-wise (spatial) projection across the sequence dimension.
    self.spatial_proj = nn.Conv1d(num_tokens, num_tokens, 1)
    # Zero weights and unit bias so the spatial gate starts close to identity.
    nn.init.zeros_(self.spatial_proj.weight)
    nn.init.ones_(self.spatial_proj.bias)

    # Optional tiny attention branch (aMLP); enabled when `attention_dim` is given.
    self.attention = Residual(
        MultiheadAttention(dim // 2, 1, proj_dim=attention_dim, **kwargs)
    ) if attention_dim is not None else None

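# NOTE: a forward sketch only, assuming the gating described in gMLP/aMLP:
# split the channels in half, normalize and spatially project one half, and use
# it to gate the other element-wise. The exact placement of the optional tiny
# attention branch is an assumption, not shown in this excerpt.
def forward(self, x):
    u, v = x.chunk(2, dim=-1)       # (b, n, dim) -> two (b, n, dim // 2) halves
    v = self.spatial_proj(self.norm(v))   # mix information across the `num_tokens` axis
    if self.attention is not None:
        v = self.attention(v)       # optional tiny attention (aMLP variant)
    return u * v                    # element-wise gate
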
def __init__(self, *, dim, heads, ff_expand_scale=4, pre_norm=False, **kwargs):
    super().__init__()

    self.attention = LayerNorm(
        dim=dim,
        use_pre_norm=pre_norm,
        fn=Residual(fn=MultiheadAttention(dim, heads=heads, **kwargs)),
    )
    self.ff = LayerNorm(
        dim=dim,
        use_pre_norm=pre_norm,
        fn=Residual(fn=FeedForward(dim, expand_dim=ff_expand_scale * dim, **kwargs)),
    )

def __init__(self, dim, ffn_dim, num_tokens, *, attention_dim=None, **kwargs):
    super().__init__()

    # gMLP block: LayerNorm -> channel expansion -> GELU -> spatial gating unit
    # (which halves the channels) -> projection back to `dim`, all wrapped in a
    # residual connection.
    self.net = Residual(
        nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, ffn_dim),
            nn.GELU(),
            SpatialGatingUnit(ffn_dim, num_tokens, attention_dim, **kwargs),
            nn.Linear(ffn_dim // 2, dim),
        )
    )

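# NOTE: forward sketch only (not part of the excerpt). Assuming inputs of shape
# (batch, num_tokens, dim): the inner Sequential expands the channels to
# `ffn_dim`, the spatial gating unit halves them to `ffn_dim // 2`, and the last
# Linear projects back to `dim`, with the residual added around the whole path.
def forward(self, x):
    return self.net(x)  # (b, num_tokens, dim) -> (b, num_tokens, dim)
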