Example 1
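# NOTE: the snippets below are __init__ excerpts only. Each belongs to a
# separate nn.Module subclass whose class statement and forward() are not
# shown here, and the building blocks (LayerNorm, Residual, FeedForward,
# MultiheadAttention, FNetFourierTransform, GatedPositionalSelfAttention,
# SpatialGatingUnit) are assumed to be defined elsewhere in the same codebase.
import warnings

import torch.nn as nn


# FNet-style encoder block: a Fourier-transform token-mixing sub-layer followed
# by a feed-forward sub-layer, each wrapped in Residual and LayerNorm.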
    def __init__(self, *, dim, pre_norm=False, ff_dim=None, **kwargs):
        super().__init__()

        self.fourier_block = LayerNorm(Residual(FNetFourierTransform()),
                                       dim=dim,
                                       use_pre_norm=pre_norm)
        self.ff_block = LayerNorm(Residual(
            FeedForward(dim=dim,
                        expand_dim=ff_dim if ff_dim is not None else 4 * dim,
                        **kwargs)),
                                  dim=dim,
                                  use_pre_norm=pre_norm)
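
# Gated positional self-attention block (ConViT-style): a gated positional
# self-attention sub-layer followed by a feed-forward sub-layer, each wrapped
# in Residual and LayerNorm.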
    def __init__(self, dim, pre_norm=False, ff_dim=None, **kwargs):
        super().__init__()

        self.attention_block = LayerNorm(Residual(
            GatedPositionalSelfAttention(dim, **kwargs)),
                                         dim=dim,
                                         use_pre_norm=pre_norm)
        self.ff_block = LayerNorm(Residual(
            FeedForward(dim=dim,
                        expand_dim=ff_dim if ff_dim is not None else 4 * dim,
                        **kwargs)),
                                  dim=dim,
                                  use_pre_norm=pre_norm)
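
# Perceiver-style block: a cross-attention sub-layer that reads the inputs into
# the latent array, followed by a small latent Transformer
# (latent_transformer_depth pairs of self-attention and feed-forward layers).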
    def __init__(self,
                 *,
                 cross_kv_dim,
                 cross_heads,
                 dim,
                 heads,
                 latent_transformer_depth=1,
                 pre_norm=False,
                 ff_expand_scale=4,
                 **kwargs):
        super().__init__()

        if cross_heads > 1:
            warnings.warn(
                f"[{self.__class__.__name__}] `cross_heads` is set to {cross_heads}, but its 1 in the original paper."
            )

        self.cross_attention_block = nn.ModuleList([
            LayerNorm(
                dim=dim,
                cross_dim=cross_kv_dim,
                use_pre_norm=pre_norm,
                use_cross_attention=True,
                fn=Residual(fn=MultiheadAttention(
                    dim, kv_dim=cross_kv_dim, heads=cross_heads, **kwargs))),
            LayerNorm(dim=dim,
                      use_pre_norm=pre_norm,
                      fn=Residual(fn=FeedForward(
                          dim, expand_dim=ff_expand_scale * dim, **kwargs)))
        ])

        self.latent_transformers = nn.ModuleList([
            nn.ModuleList([
                LayerNorm(
                    fn=Residual(
                        fn=MultiheadAttention(dim, heads=heads, **kwargs)),
                    dim=dim,
                    use_pre_norm=pre_norm,
                ),
                LayerNorm(
                    fn=Residual(fn=FeedForward(
                        dim, expand_dim=ff_expand_scale * dim, **kwargs)),
                    dim=dim,
                    use_pre_norm=pre_norm,
                )
            ]) for _ in range(latent_transformer_depth)
        ])
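
# Standard Transformer encoder block: multi-head self-attention followed by a
# feed-forward sub-layer, each wrapped in Residual and LayerNorm.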
    def __init__(self,
                 dim,
                 heads=None,
                 head_dim=None,
                 pre_norm=False,
                 ff_dim=None,
                 **kwargs):
        super().__init__()

        self.attention_block = LayerNorm(Residual(
            MultiheadAttention(dim, heads=heads, head_dim=head_dim, **kwargs)),
                                         dim=dim,
                                         use_pre_norm=pre_norm)
        self.ff_block = LayerNorm(Residual(
            FeedForward(dim,
                        expand_dim=ff_dim if ff_dim is not None else 4 * dim,
                        **kwargs)),
                                  dim=dim,
                                  use_pre_norm=pre_norm)
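
# gMLP spatial gating unit: LayerNorm on half of the channels plus a learned
# token-wise (spatial) projection, initialized to near-identity gating (zero
# weights, unit bias). The optional tiny-attention branch corresponds to the
# aMLP variant.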
    def __init__(self, dim, num_tokens, attention_dim=None, **kwargs):
        super().__init__()

        self.norm = nn.LayerNorm(dim // 2)
        self.spatial_proj = nn.Conv1d(num_tokens, num_tokens, 1)
        nn.init.zeros_(self.spatial_proj.weight)
        nn.init.ones_(self.spatial_proj.bias)

        # Optional tiny-attention branch; disabled when attention_dim is None.
        if attention_dim is not None:
            self.attention = Residual(
                MultiheadAttention(dim // 2, 1, proj_dim=attention_dim,
                                   **kwargs))
        else:
            self.attention = None
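
# Transformer encoder block variant with a configurable feed-forward expansion
# factor (ff_expand_scale).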
    def __init__(self,
                 *,
                 dim,
                 heads,
                 ff_expand_scale=4,
                 pre_norm=False,
                 **kwargs):
        super().__init__()

        self.attention = LayerNorm(
            dim=dim,
            use_pre_norm=pre_norm,
            fn=Residual(fn=MultiheadAttention(dim, heads=heads, **kwargs)),
        )
        self.ff = LayerNorm(
            dim=dim,
            use_pre_norm=pre_norm,
            fn=Residual(
                fn=FeedForward(dim, expand_dim=ff_expand_scale * dim,
                               **kwargs)),
        )
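
# gMLP block: LayerNorm, channel expansion, GELU, spatial gating unit, then a
# projection back to the model dimension, all wrapped in a single Residual.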
    def __init__(self,
                 dim,
                 ffn_dim,
                 num_tokens,
                 *,
                 attention_dim=None,
                 **kwargs):
        super().__init__()

        self.net = Residual(
            nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, ffn_dim),
                nn.GELU(),
                SpatialGatingUnit(ffn_dim, num_tokens, attention_dim, **kwargs),
                nn.Linear(ffn_dim // 2, dim),
            ))
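
# For reference, a minimal sketch of the Residual wrapper these blocks rely on,
# assuming it simply adds the wrapped module's output to its input; the actual
# implementation in the source codebase may forward extra arguments differently.
class Residual(nn.Module):

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        # Skip connection: y = x + fn(x), the pattern shared by every block above.
        return x + self.fn(x, **kwargs)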