Example #1
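    # Swin-style window-attention block: (shifted) window attention and a
    # feed-forward sub-layer, each wrapped in LayerNorm, plus a PathDropout
    # (stochastic depth) module for each sub-layer.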
    def __init__(self,
                 dim,
                 window_size,
                 shifts=None,
                 input_resolution=None,
                 ff_dim=None,
                 use_pre_norm=False,
                 **kwargs):
        super().__init__()

        self.attention_block = LayerNorm(
            ShiftWindowAttention(dim=dim,
                                 window_size=window_size,
                                 shifts=shifts,
                                 input_resolution=input_resolution,
                                 **kwargs) if shifts is not None else
            WindowAttention(dim=dim, window_size=window_size, **kwargs),
            dim=dim,
            use_pre_norm=use_pre_norm)
        self.attention_path_dropout = PathDropout(
            kwargs.get("path_dropout", 0.))

        self.ff_block = LayerNorm(FeedForward(
            dim=dim,
            expand_dim=ff_dim if ff_dim is not None else 4 * dim,
            **kwargs),
                                  dim=dim,
                                  use_pre_norm=use_pre_norm)
        self.ff_path_dropout = PathDropout(
            kwargs.get("path_dropout", 0.))
    def __init__(self,
                 *,
                 dim,
                 heads,
                 num_seeds=1,
                 attention_dropout=0.0,
                 ff_expand_scale=4,
                 pre_norm=False,
                 head_dim=None,
                 **kwargs):
        super().__init__()
        self.dim = dim
        self.heads = heads
        self.num_seeds = num_seeds
        assert self.num_seeds > 0, "Number of seeds must be greater than zero."

        self.seeds = torch.nn.Parameter(torch.zeros(self.num_seeds, self.dim))
        torch.nn.init.kaiming_normal_(self.seeds)

        self.MAB = MAB(dim=self.dim,
                       heads=self.heads,
                       head_dim=head_dim,
                       attention_dropout=attention_dropout,
                       ff_expand_scale=ff_expand_scale,
                       pre_norm=pre_norm,
                       **kwargs)
        self.ff = FeedForward(dim, expand_dim=dim, **kwargs)
Example #3
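    # Convolutional attention block: a convolutional position encoder, a
    # LayerNorm/ConvAttentionalModule pair, and a LayerNorm/FeedForward pair,
    # with a separate PathDropout for the attention and feed-forward paths.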
    def __init__(
        self,
        dim: int,
        ff_expand_scale: int = 4,
        path_dropout: float = 0.,
        conv_position_encoder: Optional[nn.Module] = None,
        use_cls: bool = True,
        **kwargs,
    ) -> None:
        super().__init__()

        self.conv_position_encoder = ConvolutionalPositionEncoding(
            dim, use_cls=use_cls
        ) if conv_position_encoder is None else conv_position_encoder

        self.norm_0 = nn.LayerNorm(dim)
        self.conv_attn_module = ConvAttentionalModule(
            dim,
            use_cls=use_cls,
            use_conv_position_encoder=False,
            conv_position_encoder=None,
            **kwargs,
        )
        self.path_dropout_0 = PathDropout(path_dropout)

        self.norm_1 = nn.LayerNorm(dim)
        self.ff_block = FeedForward(
            dim,
            ff_expand_scale=ff_expand_scale,
            ff_dropout=kwargs.get("ff_dropout", 0.),
        )
        self.path_dropout_1 = PathDropout(path_dropout)
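
    # Perceiver-style block: cross-attention from inputs of width cross_kv_dim
    # onto the latent array, followed by a stack of latent transformer layers
    # (latent_transformer_depth of them).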
    def __init__(self,
                 *,
                 cross_kv_dim,
                 cross_heads,
                 dim,
                 heads,
                 latent_transformer_depth=1,
                 pre_norm=False,
                 ff_expand_scale=4,
                 **kwargs):
        super().__init__()

        if cross_heads > 1:
            warnings.warn(
                f"[{self.__class__.__name__}] `cross_heads` is set to {cross_heads}, but its 1 in the original paper."
            )

        self.cross_attention_block = nn.ModuleList([
            LayerNorm(
                dim=dim,
                cross_dim=cross_kv_dim,
                use_pre_norm=pre_norm,
                use_cross_attention=True,
                fn=Residual(fn=MultiheadAttention(
                    dim, kv_dim=cross_kv_dim, heads=cross_heads, **kwargs))),
            LayerNorm(dim=dim,
                      use_pre_norm=pre_norm,
                      fn=Residual(fn=FeedForward(
                          dim, expand_dim=ff_expand_scale * dim, **kwargs)))
        ])

        self.latent_transformers = nn.ModuleList([
            nn.ModuleList([
                LayerNorm(
                    fn=Residual(
                        fn=MultiheadAttention(dim, heads=heads, **kwargs)),
                    dim=dim,
                    use_pre_norm=pre_norm,
                ),
                LayerNorm(
                    fn=Residual(fn=FeedForward(
                        dim, expand_dim=ff_expand_scale * dim, **kwargs)),
                    dim=dim,
                    use_pre_norm=pre_norm,
                )
            ]) for _ in range(latent_transformer_depth)
        ])
Example #5
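    # FNet block: a parameter-free Fourier-transform mixing sub-layer and a
    # feed-forward sub-layer, each wrapped in a residual connection and
    # LayerNorm.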
    def __init__(self, *, dim, pre_norm=False, ff_dim=None, **kwargs):
        super().__init__()

        self.fourier_block = LayerNorm(Residual(FNetFourierTransform()),
                                       dim=dim,
                                       use_pre_norm=pre_norm)
        self.ff_block = LayerNorm(Residual(
            FeedForward(dim=dim,
                        expand_dim=ff_dim if ff_dim is not None else 4 * dim,
                        **kwargs)),
                                  dim=dim,
                                  use_pre_norm=pre_norm)
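
    # Gated positional self-attention block (as in ConViT): the attention and
    # feed-forward sub-layers are each wrapped in a residual connection and
    # LayerNorm.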
    def __init__(self, dim, pre_norm=False, ff_dim=None, **kwargs):
        super().__init__()

        self.attention_block = LayerNorm(Residual(
            GatedPositionalSelfAttention(dim, **kwargs)),
                                         dim=dim,
                                         use_pre_norm=pre_norm)
        self.ff_block = LayerNorm(Residual(
            FeedForward(dim=dim,
                        expand_dim=ff_dim if ff_dim is not None else 4 * dim,
                        **kwargs)),
                                  dim=dim,
                                  use_pre_norm=pre_norm)
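
    # Standard transformer encoder block: multihead attention followed by a
    # feed-forward sub-layer, each with a residual connection and LayerNorm.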
    def __init__(self,
                 dim,
                 heads=None,
                 head_dim=None,
                 pre_norm=False,
                 ff_dim=None,
                 **kwargs):
        super().__init__()

        self.attention_block = LayerNorm(Residual(
            MultiheadAttention(dim, heads=heads, head_dim=head_dim, **kwargs)),
                                         dim=dim,
                                         use_pre_norm=pre_norm)
        self.ff_block = LayerNorm(Residual(
            FeedForward(dim,
                        expand_dim=ff_dim if ff_dim is not None else 4 * dim,
                        **kwargs)),
                                  dim=dim,
                                  use_pre_norm=pre_norm)
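
    # Same encoder-block structure as above, but with keyword-only arguments
    # and a feed-forward width of ff_expand_scale * dim instead of an explicit
    # ff_dim.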
    def __init__(self,
                 *,
                 dim,
                 heads,
                 ff_expand_scale=4,
                 pre_norm=False,
                 **kwargs):
        super().__init__()

        self.attention = LayerNorm(
            dim=dim,
            use_pre_norm=pre_norm,
            fn=Residual(fn=MultiheadAttention(dim, heads=heads, **kwargs)),
        )
        self.ff = LayerNorm(
            dim=dim,
            use_pre_norm=pre_norm,
            fn=Residual(
                fn=FeedForward(dim, expand_dim=ff_expand_scale * dim,
                               **kwargs)),
        )
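
    # Set-Transformer-style module built from pooling (PMA), self-attention
    # (SAB), and feed-forward components, all sharing the same keyword
    # arguments.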
    def __init__(self, **kwargs):
        super().__init__()

        self.PMA = PMA(**kwargs)
        self.SAB = SAB(**kwargs)
        self.ff = FeedForward(**kwargs)