def __init__(self, dim, window_size, shifts=None, input_resolution=None,
             ff_dim=None, use_pre_norm=False, **kwargs):
    super().__init__()

    # Shifted-window attention when `shifts` is given, plain window attention otherwise.
    self.attention_block = LayerNorm(
        ShiftWindowAttention(
            dim=dim,
            window_size=window_size,
            shifts=shifts,
            input_resolution=input_resolution,
            **kwargs,
        ) if shifts is not None else WindowAttention(
            dim=dim,
            window_size=window_size,
            **kwargs,
        ),
        dim=dim,
        use_pre_norm=use_pre_norm,
    )
    self.attention_path_dropout = PathDropout(kwargs.get("path_dropout", 0.))

    # Feed-forward defaults to the usual 4x expansion when `ff_dim` is not given.
    self.ff_block = LayerNorm(
        FeedForward(
            dim=dim,
            expand_dim=ff_dim if ff_dim is not None else 4 * dim,
            **kwargs,
        ),
        dim=dim,
        use_pre_norm=use_pre_norm,
    )
    self.ff_path_dropout = PathDropout(kwargs.get("path_dropout", 0.))
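# Note that, unlike the blocks further down, the attention/feed-forward
# sub-modules above are not wrapped in `Residual`, so the residual connections
# and stochastic depth presumably live in the class's `forward`. A minimal
# sketch of that wiring, written as a free function over a hypothetical `block`
# instance (names and call signatures are assumptions, not the actual method):
def _window_block_forward_sketch(block, x):
    x = x + block.attention_path_dropout(block.attention_block(x))
    x = x + block.ff_path_dropout(block.ff_block(x))
    return x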
def __init__(self, *, dim, heads, num_seeds=1, attention_dropout=0.0,
             ff_expand_scale=4, pre_norm=False, head_dim=None, **kwargs):
    super().__init__()

    self.dim = dim
    self.heads = heads
    self.num_seeds = num_seeds
    assert self.num_seeds > 0, "Number of seeds must be greater than zero."

    # Learnable seed vectors act as the queries that pool the input set.
    self.seeds = torch.nn.Parameter(torch.zeros(self.num_seeds, self.dim))
    torch.nn.init.kaiming_normal_(self.seeds)

    self.MAB = MAB(
        dim=self.dim,
        heads=self.heads,
        head_dim=head_dim,
        attention_dropout=attention_dropout,
        ff_expand_scale=ff_expand_scale,
        pre_norm=pre_norm,
        **kwargs,
    )
    self.ff = FeedForward(dim, expand_dim=dim, **kwargs)
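# Downstream, PMA pools a variable-size set into `num_seeds` vectors by using
# the learned seeds as queries: PMA(X) = MAB(S, rFF(X)) in the Set Transformer
# paper. A hedged sketch, assuming `MAB` takes (queries, keys/values); the
# actual forward may differ:
def _pma_forward_sketch(pma, x):
    # x: (batch, n, dim) -> (batch, num_seeds, dim)
    seeds = pma.seeds.unsqueeze(0).expand(x.shape[0], -1, -1)
    return pma.MAB(seeds, pma.ff(x))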
def __init__(
    self,
    dim: int,
    ff_expand_scale: int = 4,
    path_dropout: float = 0.,
    conv_position_encoder: Optional[nn.Module] = None,
    use_cls: bool = True,
    **kwargs,
) -> None:
    super().__init__()

    # Fall back to a fresh convolutional position encoding if none is injected.
    self.conv_position_encoder = (
        ConvolutionalPositionEncoding(dim, use_cls=use_cls)
        if conv_position_encoder is None else conv_position_encoder
    )
    self.norm_0 = nn.LayerNorm(dim)
    self.conv_attn_module = ConvAttentionalModule(
        dim,
        use_cls=use_cls,
        use_conv_position_encoder=False,
        conv_position_encoder=None,
        **kwargs,
    )
    self.path_dropout_0 = PathDropout(path_dropout)

    self.norm_1 = nn.LayerNorm(dim)
    self.ff_block = FeedForward(
        dim,
        ff_expand_scale=ff_expand_scale,
        ff_dropout=kwargs.get("ff_dropout", 0.),  # avoid a KeyError when not provided
    )
    self.path_dropout_1 = PathDropout(path_dropout)
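# A hypothetical sketch of how these CoaT-style pieces are typically wired:
# convolutional position encoding first, then pre-norm residual attention and
# feed-forward branches with stochastic depth. The `size` argument (spatial
# height/width for the conv ops) and the wiring are assumptions; the actual
# `forward` may differ:
def _serial_block_forward_sketch(block, x, size):
    x = block.conv_position_encoder(x, size)  # inject convolutional position info
    x = x + block.path_dropout_0(block.conv_attn_module(block.norm_0(x), size))
    x = x + block.path_dropout_1(block.ff_block(block.norm_1(x)))
    return x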
def __init__(self, *, cross_kv_dim, cross_heads, dim, heads,
             latent_transformer_depth=1, pre_norm=False, ff_expand_scale=4,
             **kwargs):
    super().__init__()

    if cross_heads > 1:
        warnings.warn(
            f"[{self.__class__.__name__}] `cross_heads` is set to {cross_heads}, "
            "but it is 1 in the original paper."
        )

    # Cross-attention: latents (dim) attend to the input byte array (cross_kv_dim).
    self.cross_attention_block = nn.ModuleList([
        LayerNorm(
            dim=dim,
            cross_dim=cross_kv_dim,
            use_pre_norm=pre_norm,
            use_cross_attention=True,
            fn=Residual(fn=MultiheadAttention(
                dim, kv_dim=cross_kv_dim, heads=cross_heads, **kwargs)),
        ),
        LayerNorm(
            dim=dim,
            use_pre_norm=pre_norm,
            fn=Residual(fn=FeedForward(
                dim, expand_dim=ff_expand_scale * dim, **kwargs)),
        ),
    ])

    # Latent transformer: self-attention layers operating on the latents only.
    self.latent_transformers = nn.ModuleList([
        nn.ModuleList([
            LayerNorm(
                dim=dim,
                use_pre_norm=pre_norm,
                fn=Residual(fn=MultiheadAttention(dim, heads=heads, **kwargs)),
            ),
            LayerNorm(
                dim=dim,
                use_pre_norm=pre_norm,
                fn=Residual(fn=FeedForward(
                    dim, expand_dim=ff_expand_scale * dim, **kwargs)),
            ),
        ])
        for _ in range(latent_transformer_depth)
    ])
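# Perceiver data flow: a small latent array cross-attends to the (possibly very
# large) input byte array, then a latent-only transformer refines the result.
# A hedged sketch, assuming the cross-attention wrapper takes
# (latents, byte_array); the actual forward may differ:
def _perceiver_block_forward_sketch(block, latents, byte_array):
    cross_attention, cross_ff = block.cross_attention_block
    latents = cross_ff(cross_attention(latents, byte_array))
    for self_attention, ff in block.latent_transformers:
        latents = ff(self_attention(latents))
    return latents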
def __init__(self, *, dim, pre_norm=False, ff_dim=None, **kwargs):
    super().__init__()

    self.fourier_block = LayerNorm(
        Residual(FNetFourierTransform()),
        dim=dim,
        use_pre_norm=pre_norm,
    )
    self.ff_block = LayerNorm(
        Residual(FeedForward(
            dim=dim,
            expand_dim=ff_dim if ff_dim is not None else 4 * dim,
            **kwargs,
        )),
        dim=dim,
        use_pre_norm=pre_norm,
    )
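# For reference, the token mixing `FNetFourierTransform` presumably performs is
# the parameter-free 2-D DFT from the FNet paper, keeping only the real part.
# A standalone sketch (relies on `torch`, already imported in this module; the
# actual module may differ in details):
def _fnet_mixing_sketch(x):
    # FFT over the hidden dimension, then over the sequence dimension; the
    # imaginary part is discarded so the rest of the block stays real-valued.
    return torch.fft.fft(torch.fft.fft(x, dim=-1), dim=-2).real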
def __init__(self, dim, pre_norm=False, ff_dim=None, **kwargs):
    super().__init__()

    self.attention_block = LayerNorm(
        Residual(GatedPositionalSelfAttention(dim, **kwargs)),
        dim=dim,
        use_pre_norm=pre_norm,
    )
    self.ff_block = LayerNorm(
        Residual(FeedForward(
            dim=dim,
            expand_dim=ff_dim if ff_dim is not None else 4 * dim,
            **kwargs,
        )),
        dim=dim,
        use_pre_norm=pre_norm,
    )
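# Context for GPSA (ConViT): each head gates between content-based attention
# and a learned positional attention,
#   attn = (1 - sigmoid(g)) * softmax(Q K^T / sqrt(d)) + sigmoid(g) * softmax(pos_scores),
# with one gate g per head. A toy standalone illustration of just the gating
# (shapes and names are hypothetical; assumes `torch` is imported above):
def _gpsa_gating_sketch(content_scores, positional_scores, gate_logits):
    # content_scores, positional_scores: (batch, heads, n, n); gate_logits: (heads,)
    gate = torch.sigmoid(gate_logits).view(1, -1, 1, 1)
    content_attn = content_scores.softmax(dim=-1)
    positional_attn = positional_scores.softmax(dim=-1)
    return (1.0 - gate) * content_attn + gate * positional_attn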
def __init__(self, dim, heads=None, head_dim=None, pre_norm=False, ff_dim=None,
             **kwargs):
    super().__init__()

    self.attention_block = LayerNorm(
        Residual(MultiheadAttention(dim, heads=heads, head_dim=head_dim, **kwargs)),
        dim=dim,
        use_pre_norm=pre_norm,
    )
    self.ff_block = LayerNorm(
        Residual(FeedForward(
            dim,
            expand_dim=ff_dim if ff_dim is not None else 4 * dim,
            **kwargs,
        )),
        dim=dim,
        use_pre_norm=pre_norm,
    )
def __init__(self, *, dim, heads, ff_expand_scale=4, pre_norm=False, **kwargs):
    super().__init__()

    self.attention = LayerNorm(
        dim=dim,
        use_pre_norm=pre_norm,
        fn=Residual(fn=MultiheadAttention(dim, heads=heads, **kwargs)),
    )
    self.ff = LayerNorm(
        dim=dim,
        use_pre_norm=pre_norm,
        fn=Residual(fn=FeedForward(dim, expand_dim=ff_expand_scale * dim, **kwargs)),
    )
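# The `LayerNorm` wrapper used throughout these blocks presumably switches
# between pre-norm (normalize, then apply fn) and post-norm (apply fn, then
# normalize). A hypothetical re-implementation of the non-cross-attention path,
# for orientation only (assumes `nn` is imported above; the real wrapper also
# handles cross-attention and may differ):
class _PreOrPostNormSketch(nn.Module):
    def __init__(self, dim, fn, use_pre_norm=False):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
        self.use_pre_norm = use_pre_norm

    def forward(self, x, **kwargs):
        if self.use_pre_norm:
            return self.fn(self.norm(x), **kwargs)  # pre-norm: norm feeds the residual fn
        return self.norm(self.fn(x, **kwargs))      # post-norm: norm after the block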
def __init__(self, **kwargs):
    super().__init__()

    # Set Transformer-style decoder: pool with PMA, let the pooled outputs
    # interact via SAB, then project with a feed-forward layer.
    self.PMA = PMA(**kwargs)
    self.SAB = SAB(**kwargs)
    self.ff = FeedForward(**kwargs)
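# A hypothetical forward for this decoder, matching
# Decoder(Z) = rFF(SAB(PMA_k(Z))) from the Set Transformer paper; the actual
# method may differ:
def _set_transformer_decoder_forward_sketch(decoder, x):
    return decoder.ff(decoder.SAB(decoder.PMA(x)))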