Example #1
    def _show_batch(
        self,
        batch: List[torch.Tensor],
        sample_length: int,
        mean: Tuple[int, int, int] = DEFAULT_MEAN,
        std: Tuple[int, int, int] = DEFAULT_STD,
    ) -> None:
        """
        Display a batch of images.

        Args:
            batch: List of sample (clip) tensors
            sample_length: Number of frames to show for each sample
            mean: Normalization mean
            std: Normalization std-dev
        """
        batch_size = len(batch)
        fig, axs = plt.subplots(
            batch_size,
            sample_length,
            figsize=(4 * sample_length, 3 * batch_size),
            squeeze=False,
        )

        for i, ax_row in enumerate(axs):
            clip = batch[i]
            # (channels, frames, height, width) -> (frames, channels, height, width)
            clip = Rearrange("c t h w -> t c h w")(clip)
            for j, a in enumerate(ax_row):
                a.axis("off")
                a.imshow(
                    np.moveaxis(denormalize(clip[j], mean, std).numpy(), 0, -1)
                )
        fig.tight_layout()
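As a quick sanity check (not part of the original snippet), the "c t h w -> t c h w" pattern used above just swaps the channel and time axes of a clip tensor; the shapes below are illustrative dummies:

import torch
from einops.layers.torch import Rearrange

clip = torch.rand(3, 8, 32, 32)                # (channels, frames, height, width)
frames_first = Rearrange("c t h w -> t c h w")(clip)
assert frames_first.shape == (8, 3, 32, 32)    # frame-major, ready for per-frame imshow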
Example #2
    def __init__(self,
                 *,
                 image_size,
                 patch_size,
                 num_classes,
                 dim,
                 depth,
                 ff_mult=4,
                 channels=3,
                 attn_dim=None,
                 prob_survival=1.):
        super().__init__()
        assert (
            image_size %
            patch_size) == 0, 'image size must be divisible by the patch size'
        dim_ff = dim * ff_mult
        num_patches = (image_size // patch_size)**2

        self.to_patch_embed = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)',
                      p1=patch_size,
                      p2=patch_size), nn.Linear(channels * patch_size**2, dim))

        self.prob_survival = prob_survival

        self.layers = nn.ModuleList([
            Residual(
                PreNorm(
                    dim,
                    gMLPBlock(dim=dim,
                              dim_ff=dim_ff,
                              seq_len=num_patches,
                              attn_dim=attn_dim))) for _ in range(depth)
        ])

        self.to_logits = nn.Sequential(nn.LayerNorm(dim),
                                       Reduce('b n d -> b d', 'mean'),
                                       nn.Linear(dim, num_classes))
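A minimal shape check for the patch-embedding rearrange above, using assumed sizes (patch size 16, 224x224 RGB) rather than anything from the original code: each image becomes 14*14 patch tokens of dimension channels * patch_size**2.

import torch
from einops.layers.torch import Rearrange

to_patches = Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1=16, p2=16)
x = torch.rand(2, 3, 224, 224)
assert to_patches(x).shape == (2, 14 * 14, 3 * 16 * 16)   # (batch, num_patches, patch dim)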
Example #3
    def __init__(
        self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim,
        pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.,
        t2t_layers = ((7, 4), (3, 2), (3, 2))):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * patch_size ** 2
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        layers = []
        layer_dim = channels

        for i, (kernel_size, stride) in enumerate(t2t_layers):
            layer_dim *= kernel_size ** 2
            is_first = i == 0

            layers.extend([
                RearrangeImage() if not is_first else nn.Identity(),
                nn.Unfold(kernel_size = kernel_size, stride = stride, padding = stride // 2),
                Rearrange('b c n -> b n c'),
                Transformer(dim = layer_dim, heads = 1, depth = 1, dim_head = layer_dim, mlp_dim = layer_dim, dropout = dropout),
            ])

        layers.append(nn.Linear(layer_dim, dim))
        self.to_patch_embedding = nn.Sequential(*layers)

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )
Example #4
    def __init__(
        self,
        config: Config,
        forward_model: Optional[ForwardModel] = None,
    ) -> None:
        super().__init__()
        # self.save_hyperparameters()
        self.config = config
        # self.config["num_wavelens"] = len(
        #     torch.load(Path("/data-new/alok/laser/data.pt"))["interpolated_wavelength"][0]
        # )
        if forward_model is None:
            self.forward_model = None
        else:
            self.forward_model = forward_model
            self.forward_model.freeze()

        self.trunk = nn.Sequential(
            Rearrange("b c -> b c 1 1"),
            MLPMixer(
                in_channels=self.config["num_wavelens"],
                image_size=1,
                patch_size=1,
                num_classes=1_000,
                dim=512,
                depth=8,
                token_dim=256,
                channel_dim=2048,
                dropout=0.5,
            ),
            nn.Flatten(),
        )

        self.continuous_head = nn.LazyLinear(2)
        self.discrete_head = nn.LazyLinear(12)
        # XXX this call *must* happen to initialize the lazy layers
        _dummy_input = torch.rand(2, self.config["num_wavelens"])
        self.forward(_dummy_input)
Example #5
    def __init__(self,
                 *,
                 image_size,
                 patch_size,
                 num_classes,
                 dim,
                 depth,
                 heads,
                 mlp_dim,
                 pool='cls',
                 channels=3,
                 dim_head=64,
                 dropout=0.,
                 emb_dropout=0.):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_size // patch_size)**2
        patch_dim = channels * patch_size**2
        assert pool in {
            'cls', 'mean'
        }, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
                      p1=patch_size,
                      p2=patch_size),
            nn.Linear(patch_dim, dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim,
                                       dropout)

        self.pool = pool
        self.to_latent = nn.Identity()
Example #6
    def __init__(self,
                 dim,
                 heads,
                 row_attn=True,
                 col_attn=True,
                 accept_edges=False,
                 global_query_attn=False,
                 **kwargs):
        super().__init__()
        assert row_attn or col_attn, 'row or column attention must be turned on'

        self.row_attn = row_attn
        self.col_attn = col_attn
        self.global_query_attn = global_query_attn

        self.norm = nn.LayerNorm(dim)

        self.attn = Attention(dim=dim, heads=heads, **kwargs)

        self.edges_to_attn_bias = nn.Sequential(
            nn.Linear(dim, heads, bias=False),
            Rearrange('b i j h -> b h i j')) if accept_edges else None
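A hedged illustration of the edges_to_attn_bias head above, with made-up dim=64 and heads=8 (not taken from the original module): the linear layer maps pairwise edge features to one scalar per head, and the rearrange moves the head axis forward so the result can be added to attention logits.

import torch
from torch import nn
from einops.layers.torch import Rearrange

edges_to_attn_bias = nn.Sequential(nn.Linear(64, 8, bias=False),
                                   Rearrange('b i j h -> b h i j'))
edges = torch.rand(2, 16, 16, 64)              # (batch, i, j, edge features)
assert edges_to_attn_bias(edges).shape == (2, 8, 16, 16)   # (batch, heads, i, j)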
Example #7
    def __init__(self,
                 size,
                 patch_size,
                 in_channels,
                 embedding_dim,
                 depth=1,
                 heads=8,
                 dim_head=8,
                 dim_mlp=64):
        super(ConvTokenConvTransformer, self).__init__()
        self.size = size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.embedding_dim = embedding_dim
        self.depth = depth
        self.heads = heads
        self.dim_head = dim_head
        self.dim_mlp = dim_mlp

        self.patch_height = int(size[0] / patch_size[0])
        self.patch_width = int(size[1] / patch_size[1])

        self.embedding = ConvPatchEmbbeding(in_channels=self.in_channels,
                                            patch_size=self.patch_size,
                                            embbeding_dim=self.embedding_dim)
        self.layernorm = nn.LayerNorm(self.embedding_dim)

        self.rearrange = Rearrange('b (h w) d -> b d h w',
                                   h=self.patch_height,
                                   w=self.patch_width)
        self.attention = SelfConvTransfomer(size=(self.patch_height,
                                                  self.patch_width),
                                            in_channels=self.embedding_dim,
                                            depth=self.depth,
                                            heads=self.heads,
                                            dim_head=self.dim_head,
                                            dim_mlp=self.dim_mlp)
Example #8
    def __init__(self,
                 image_size=224,
                 tokens_type='performer',
                 in_chans=3,
                 embed_dim=768,
                 dropout=0.1,
                 t2t_layers=((7, 4), (3, 2), (3, 2))):
        super().__init__()
        layers = []
        layer_dim = in_chans
        output_image_size = image_size

        for i, (kernel_size, stride) in enumerate(t2t_layers):
            layer_dim *= kernel_size**2
            is_first = i == 0
            output_image_size = conv_output_size(output_image_size,
                                                 kernel_size, stride,
                                                 stride // 2)

            layers.extend([
                RearrangeImage() if not is_first else nn.Identity(),
                nn.Unfold(kernel_size=kernel_size,
                          stride=stride,
                          padding=stride // 2),
                Rearrange('b c n -> b n c'),
                Transformer(dim=layer_dim,
                            heads=1,
                            depth=1,
                            dim_head=layer_dim,
                            mlp_dim=layer_dim,
                            dropout=dropout)
                if tokens_type == 'transformer' else Performer(
                    dim=layer_dim, inner_dim=layer_dim, kernel_ratio=0.5),
            ])
        layers.append(nn.Linear(layer_dim, embed_dim))
        self.output_image_size = output_image_size
        self.to_patch_embedding = nn.Sequential(*layers)
Example #9
    def __init__(self, *, image_size, patch_size, num_classes, dim, transformer, pool = 'cls', channels = 3):
        super().__init__()
        image_size_h, image_size_w = pair(image_size)
        assert image_size_h % patch_size == 0 and image_size_w % patch_size == 0, 'image dimensions must be divisible by the patch size'
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
        num_patches = (image_size_h // patch_size) * (image_size_w // patch_size)
        patch_dim = channels * patch_size ** 2

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
            nn.Linear(patch_dim, dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.transformer = transformer

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )
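For reference, the standard ViT patch embedding above flattens each patch_size x patch_size block into one token; a small shape check with assumed sizes (patch 16, 224x224 RGB), not taken from the snippet:

import torch
from einops.layers.torch import Rearrange

to_tokens = Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=16, p2=16)
x = torch.rand(2, 3, 224, 224)
assert to_tokens(x).shape == (2, 196, 768)     # 14*14 patches, 16*16*3 values each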
Example #10
    def __init__(self,
                 *,
                 image_size,
                 patch_size,
                 num_classes,
                 dim,
                 depth,
                 heads,
                 mlp_dim,
                 channels=3,
                 dim_head=64,
                 dropout=0.,
                 emb_dropout=0.,
                 use_rotary=True,
                 use_ds_conv=True,
                 use_glu=True):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_size // patch_size)**2
        patch_dim = channels * patch_size**2

        self.patch_size = patch_size
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
                      p1=patch_size,
                      p2=patch_size),
            nn.Linear(patch_dim, dim),
        )

        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim,
                                       dropout, use_rotary, use_ds_conv,
                                       use_glu)

        self.mlp_head = nn.Sequential(nn.LayerNorm(dim),
                                      nn.Linear(dim, num_classes))
Example #11
def G_transformer(incoming,
                  dim,
                  heads,
                  dim_head,
                  dropout,
                  mlp_dim,
                  curr_patchsize,
                  channel,
                  curr_dim,
                  curr_num,
                  num_channels,
                  to_dim,
                  to_sequential=True,
                  use_wscale=True,
                  use_pixelnorm=True):
    layers = incoming
    layers += [  #attention input b (h w) (p1 p2 c)
        Rearrange('b c (p1 h) (p2 w) -> b (h w) (p1 p2 c)',
                  p1=curr_patchsize,
                  p2=curr_patchsize,
                  c=channel,
                  h=curr_num,
                  w=curr_num),
        Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout),
        FeedForward(dim, mlp_dim, dropout=dropout),
        Rearrange('b (h w) (p1 p2 c) -> b c (p1 h) (p2 w)',
                  p1=curr_patchsize,
                  p2=curr_patchsize,
                  c=channel,
                  h=curr_num,
                  w=curr_num)
    ]
    # output: b c h w
    if use_pixelnorm:
        layers += [PixelNormLayer()]
    if to_sequential:
        return nn.Sequential(*layers)
    else:
        return layers
Example #12
    def __init__(self,
                 *,
                 image_size,
                 patch_dim,
                 pixel_dim,
                 patch_size,
                 pixel_size,
                 depth,
                 num_classes,
                 heads=8,
                 dim_head=64,
                 ff_dropout=0.,
                 attn_dropout=0.,
                 unfold_args=None):
        super().__init__()
        assert divisible_by(
            image_size,
            patch_size), 'image size must be divisible by patch size'
        assert divisible_by(
            patch_size,
            pixel_size), 'patch size must be divisible by pixel size for now'

        num_patch_tokens = (image_size // patch_size)**2

        self.image_size = image_size
        self.patch_size = patch_size
        self.patch_tokens = nn.Parameter(
            torch.randn(num_patch_tokens + 1, patch_dim))

        unfold_args = default(unfold_args, (pixel_size, pixel_size, 0))
        unfold_args = (*unfold_args,
                       0) if len(unfold_args) == 2 else unfold_args
        kernel_size, stride, padding = unfold_args

        pixel_width = unfold_output_size(patch_size, kernel_size, stride,
                                         padding)
        num_pixels = pixel_width**2

        self.to_pixel_tokens = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> (b h w) c p1 p2',
                      p1=patch_size,
                      p2=patch_size),
            nn.Unfold(kernel_size=kernel_size, stride=stride, padding=padding),
            Rearrange('... c n -> ... n c'),
            nn.Linear(3 * kernel_size**2, pixel_dim))

        self.patch_pos_emb = nn.Parameter(
            torch.randn(num_patch_tokens + 1, patch_dim))
        self.pixel_pos_emb = nn.Parameter(torch.randn(num_pixels, pixel_dim))

        layers = nn.ModuleList([])
        for _ in range(depth):

            pixel_to_patch = nn.Sequential(
                RMSNorm(pixel_dim),
                Rearrange('... n d -> ... (n d)'),
                nn.Linear(pixel_dim * num_pixels, patch_dim),
            )

            layers.append(
                nn.ModuleList([
                    PreNorm(
                        pixel_dim,
                        Attention(dim=pixel_dim,
                                  heads=heads,
                                  dim_head=dim_head,
                                  dropout=attn_dropout)),
                    PreNorm(pixel_dim,
                            FeedForward(dim=pixel_dim, dropout=ff_dropout)),
                    pixel_to_patch,
                    PreNorm(
                        patch_dim,
                        Attention(dim=patch_dim,
                                  heads=heads,
                                  dim_head=dim_head,
                                  dropout=attn_dropout)),
                    PreNorm(patch_dim,
                            FeedForward(dim=patch_dim, dropout=ff_dropout)),
                ]))

        self.layers = layers

        self.mlp_head = nn.Sequential(RMSNorm(patch_dim),
                                      nn.Linear(patch_dim, num_classes))
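The to_pixel_tokens rearrange above folds every patch into the batch dimension before unfolding it into pixel tokens; a rough check with assumed sizes (image 64, patch 16, 3 channels), not from the original:

import torch
from einops.layers.torch import Rearrange

split_patches = Rearrange('b c (h p1) (w p2) -> (b h w) c p1 p2', p1=16, p2=16)
x = torch.rand(2, 3, 64, 64)
# 2 images * (64 // 16) ** 2 patches = 32 patch "images" of shape (3, 16, 16)
assert split_patches(x).shape == (32, 3, 16, 16)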
Example #13
    def __init__(
        self,
        image_size=224,
        patch_size=16,
        channels=3,
        embedding_dim=768,
        hidden_dims=None,
        conv_patch=False,
        linear_patch=False,
        conv_stem=True,
        conv_stem_original=True,
        conv_stem_scaled_relu=False,
        position_embedding_dropout=None,
        cls_head=True,
    ):
        super(EmbeddingStem, self).__init__()

        assert (sum([conv_patch, conv_stem, linear_patch
                     ]) == 1), "Only one of three modes should be active"

        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert (image_height % patch_height == 0 and image_width % patch_width
                == 0), "Image dimensions must be divisible by the patch size."

        assert not (
            conv_stem and cls_head
        ), "Cannot use [CLS] token approach with full conv stems for ViT"

        if linear_patch or conv_patch:
            self.grid_size = (
                image_height // patch_height,
                image_width // patch_width,
            )
            num_patches = self.grid_size[0] * self.grid_size[1]

            if cls_head:
                self.cls_token = nn.Parameter(torch.zeros(1, 1, embedding_dim))
                num_patches += 1

            # positional embedding
            self.pos_embed = nn.Parameter(
                torch.zeros(1, num_patches, embedding_dim))
            self.pos_drop = nn.Dropout(p=position_embedding_dropout)

        if conv_patch:
            self.projection = nn.Sequential(
                nn.Conv2d(
                    channels,
                    embedding_dim,
                    kernel_size=patch_size,
                    stride=patch_size,
                ), )
        elif linear_patch:
            patch_dim = channels * patch_height * patch_width
            self.projection = nn.Sequential(
                Rearrange(
                    'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
                    p1=patch_height,
                    p2=patch_width,
                ),
                nn.Linear(patch_dim, embedding_dim),
            )
        elif conv_stem:
            assert (conv_stem_scaled_relu ^ conv_stem_original
                    ), "Can use either the original or the scaled relu stem"

            if not isinstance(hidden_dims, list):
                raise ValueError("Cannot create stem without list of sizes")

            if conv_stem_original:
                """
                Conv stem from https://arxiv.org/pdf/2106.14881.pdf
                """

                hidden_dims.insert(0, channels)
                modules = []
                for i, (in_ch, out_ch) in enumerate(
                        zip(hidden_dims[:-1], hidden_dims[1:])):
                    modules.append(
                        nn.Conv2d(
                            in_ch,
                            out_ch,
                            kernel_size=3,
                            stride=2 if in_ch != out_ch else 1,
                            padding=1,
                            bias=False,
                        ), )
                    modules.append(nn.BatchNorm2d(out_ch), )
                    modules.append(nn.ReLU(inplace=True))

                modules.append(
                    nn.Conv2d(
                        hidden_dims[-1],
                        embedding_dim,
                        kernel_size=1,
                        stride=1,
                    ), )
                self.projection = nn.Sequential(*modules)

            elif conv_stem_scaled_relu:
                """
                Conv stem from https://arxiv.org/pdf/2109.03810.pdf
                """
                assert (len(hidden_dims) == 1
                        ), "Only one value for hidden_dim is allowed"
                mid_ch = hidden_dims[0]

                # fmt: off
                self.projection = nn.Sequential(
                    nn.Conv2d(
                        channels,
                        mid_ch,
                        kernel_size=7,
                        stride=2,
                        padding=3,
                        bias=False,
                    ),
                    nn.BatchNorm2d(mid_ch),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(
                        mid_ch,
                        mid_ch,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                        bias=False,
                    ),
                    nn.BatchNorm2d(mid_ch),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(
                        mid_ch,
                        mid_ch,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                        bias=False,
                    ),
                    nn.BatchNorm2d(mid_ch),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(
                        mid_ch,
                        embedding_dim,
                        kernel_size=patch_size // 2,
                        stride=patch_size // 2,
                    ),
                )
                # fmt: on

            else:
                raise ValueError("Undefined convolutional stem type defined")

        self.conv_stem = conv_stem
        self.conv_patch = conv_patch
        self.linear_patch = linear_patch
        self.cls_head = cls_head

        self._init_weights()
Example #14
    def __init__(self,
                 *,
                 image_size,
                 patch_size,
                 num_classes,
                 dim,
                 heads,
                 num_hierarchies,
                 block_repeats,
                 mlp_mult=4,
                 channels=3,
                 dim_head=64,
                 dropout=0.):
        super().__init__()
        assert (image_size % patch_size
                ) == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_size // patch_size)**2
        patch_dim = channels * patch_size**2
        fmap_size = image_size // patch_size
        blocks = 2**(num_hierarchies - 1)

        seq_len = (fmap_size // blocks)**2  # sequence length is held constant across the hierarchy
        hierarchies = list(reversed(range(num_hierarchies)))
        mults = [2**i for i in reversed(hierarchies)]

        layer_heads = list(map(lambda t: t * heads, mults))
        layer_dims = list(map(lambda t: t * dim, mults))
        last_dim = layer_dims[-1]

        layer_dims = [*layer_dims, layer_dims[-1]]
        dim_pairs = zip(layer_dims[:-1], layer_dims[1:])

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (p1 p2 c) h w',
                      p1=patch_size,
                      p2=patch_size),
            nn.Conv2d(patch_dim, layer_dims[0], 1),
        )

        block_repeats = cast_tuple(block_repeats, num_hierarchies)

        self.layers = nn.ModuleList([])

        for level, heads, (dim_in, dim_out), block_repeat in zip(
                hierarchies, layer_heads, dim_pairs, block_repeats):
            is_last = level == 0
            depth = block_repeat

            self.layers.append(
                nn.ModuleList([
                    Transformer(dim_in, seq_len, depth, heads, mlp_mult,
                                dropout),
                    Aggregate(dim_in, dim_out)
                    if not is_last else nn.Identity()
                ]))

        self.mlp_head = nn.Sequential(LayerNorm(last_dim),
                                      Reduce('b c h w -> b c', 'mean'),
                                      nn.Linear(last_dim, num_classes))
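A small, assumption-laden check of the patch embedding above (patch size 4 and a 32x32 RGB input chosen arbitrarily): the pixels inside each patch are folded into the channel axis so a 1x1 Conv2d can project them.

import torch
from einops.layers.torch import Rearrange

to_grid = Rearrange('b c (h p1) (w p2) -> b (p1 p2 c) h w', p1=4, p2=4)
x = torch.rand(2, 3, 32, 32)
assert to_grid(x).shape == (2, 4 * 4 * 3, 8, 8)   # ready for nn.Conv2d(48, dim, 1)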
Example #15
def DynamicPositionBias(dim):
    return nn.Sequential(nn.Linear(2, dim), nn.LayerNorm(dim), nn.ReLU(),
                         nn.Linear(dim, dim), nn.LayerNorm(dim), nn.ReLU(),
                         nn.Linear(dim, dim), nn.LayerNorm(dim), nn.ReLU(),
                         nn.Linear(dim, 1), Rearrange('... () -> ...'))
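The trailing Rearrange('... () -> ...') simply drops the singleton dimension left by nn.Linear(dim, 1); a minimal sketch with illustrative shapes only:

import torch
from einops.layers.torch import Rearrange

x = torch.rand(8, 32, 32, 1)                   # e.g. a per-pair bias with a trailing 1
assert Rearrange('... () -> ...')(x).shape == (8, 32, 32)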
Example #16
# --------------------------------------------------------'
import random
import math
import numpy as np

import os
import torchvision.transforms as transforms
from einops.layers.torch import Rearrange
import torch
from PIL import Image

input_size = 224
patch_height, patch_width = 16, 16
h, w = int(input_size / patch_height), int(input_size / patch_width)
to_patch_embedding_2d = Rearrange(
    "(h p1) (w p2) -> (h w) (p1 p2)", p1=patch_height, p2=patch_width
)
_FOC_MASK_PATH = os.path.join("/disk1/data", "mask", "foc")


def haved_masked_lst(fname):
    mask_fname = os.path.join(_FOC_MASK_PATH, fname)
    image = Image.open(mask_fname)
    image = transforms.Resize([input_size, input_size])(image)
    image_arr = np.array(image)
    patch_image = to_patch_embedding_2d(torch.tensor(image_arr, dtype=torch.float32))
    image.close()

    index_lst = []
    for i in range(h * w):
        x = np.count_nonzero((patch_image[i, :] > 100))  # non-black elements
Example #17
    def __init__(self, size):
        self.rearrange = Rearrange("c (h p1) (w p2) -> (h w) (p1 p2 c)",
                                   p1=size,
                                   p2=size)
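This variant operates on a single, unbatched image (no leading b axis); a quick check assuming a 3x16x16 tensor and size=4:

import torch
from einops.layers.torch import Rearrange

rearrange = Rearrange("c (h p1) (w p2) -> (h w) (p1 p2 c)", p1=4, p2=4)
img = torch.rand(3, 16, 16)
assert rearrange(img).shape == (16, 4 * 4 * 3)   # (num_patches, patch dim)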
Example #18
    def __init__(self,
                 *,
                 image_size,
                 patch_dim,
                 pixel_dim,
                 patch_size,
                 pixel_size,
                 depth,
                 num_classes,
                 heads=8,
                 dim_head=64,
                 ff_dropout=0.,
                 attn_dropout=0.):
        super().__init__()
        assert image_size % patch_size == 0, 'image size must be divisible by patch size'
        assert patch_size % pixel_size == 0, 'patch size must be divisible by pixel size for now'

        num_patch_tokens = (image_size // patch_size)**2
        pixel_width = patch_size // pixel_size
        num_pixels = pixel_width**2

        self.image_size = image_size
        self.patch_size = patch_size
        self.patch_tokens = nn.Parameter(
            torch.randn(num_patch_tokens + 1, patch_dim))

        self.to_pixel_tokens = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> (b h w) c p1 p2',
                      p1=patch_size,
                      p2=patch_size), nn.Unfold(pixel_size, stride=pixel_size),
            Rearrange('... c n -> ... n c'),
            nn.Linear(4 * pixel_size**2, pixel_dim))

        self.patch_pos_emb = nn.Parameter(
            torch.randn(num_patch_tokens + 1, patch_dim))
        self.pixel_pos_emb = nn.Parameter(torch.randn(num_pixels, pixel_dim))

        layers = nn.ModuleList([])
        for _ in range(depth):

            pixel_to_patch = nn.Sequential(
                nn.LayerNorm(pixel_dim),
                Rearrange('... n d -> ... (n d)'),
                nn.Linear(pixel_dim * num_pixels, patch_dim),
            )

            layers.append(
                nn.ModuleList([
                    PreNorm(
                        pixel_dim,
                        Attention(dim=pixel_dim,
                                  heads=heads,
                                  dim_head=dim_head,
                                  dropout=attn_dropout)),
                    PreNorm(pixel_dim,
                            FeedForward(dim=pixel_dim, dropout=ff_dropout)),
                    pixel_to_patch,
                    PreNorm(
                        patch_dim,
                        Attention(dim=patch_dim,
                                  heads=heads,
                                  dim_head=dim_head,
                                  dropout=attn_dropout)),
                    PreNorm(patch_dim,
                            FeedForward(dim=patch_dim, dropout=ff_dropout)),
                ]))

        self.layers = layers

        self.mlp_head = nn.Sequential(nn.LayerNorm(patch_dim),
                                      nn.Linear(patch_dim, num_classes))
Example #19

def dataset_with_indices(cls):
    """
  Returns a modified class cls, which returns tuples like (X, y, indices) instead of just (X, y).
  """
    def __getitem__(self, index):
        data, target = cls.__getitem__(self, index)
        return data, target, index

    return type(cls.__name__, (cls, ), {"__getitem__": __getitem__})


FashionMNIST = dataset_with_indices(FashionMNIST)

model = nn.Sequential(Rearrange("b () h w -> b (h w)"), nn.Linear(28**2, 10))
DATASET_PATH = "/mnt/hdd_1tb/datasets/fashionmnist"
dataset = FashionMNIST(DATASET_PATH,
                       transform=Compose(
                           (ToTensor(), Normalize((0.286, ), (0.353, )))))
TRAIN_SIZE = 50000
BATCH_SIZE = 512
DEVICE = torch.device("cuda")
train_dataloader = DataLoader(
    dataset,
    BATCH_SIZE,
    sampler=SubsetRandomSampler(range(TRAIN_SIZE)),
    pin_memory=(DEVICE.type == "cuda"),
    drop_last=True,
)
val_dataloader = DataLoader(
Example #20
    def __init__(
            self,
            *,
            dim,
            max_seq_len=2048,
            depth=6,
            heads=8,
            dim_head=64,
            attn_types=('full', ),
            num_tokens=constants.NUM_AMINO_ACIDS,
            num_embedds=constants.NUM_EMBEDDS_TR,
            max_num_msas=constants.MAX_NUM_MSA,
            max_num_templates=constants.MAX_NUM_TEMPLATES,
            attn_dropout=0.,
            ff_dropout=0.,
            reversible=False,
            sparse_self_attn=False,
            cross_attn_compress_ratio=1,
            msa_tie_row_attn=False,
            template_attn_depth=2,
            num_backbone_atoms=1,  # number of atoms to reconstitute each residue to; 3 gives the backbone C, C-alpha, N
            predict_angles=False,
            symmetrize_omega=False,
            predict_coords=False,  # structure module related keyword arguments below
            predict_real_value_distances=False,
            trunk_embeds_to_se3_edges=0,  # feeds pairwise projected logits from the trunk embeddings into the equivariant transformer as edges
            se3_edges_fourier_encodings=4,  # number of fourier encodings for se3 edges
            return_aux_logits=False,
            mds_iters=5,
            use_se3_transformer=True,  # uses SE3 Transformer - but if set to false, will use the new E(n)-Transformer
            structure_module_dim=4,
            structure_module_depth=4,
            structure_module_heads=1,
            structure_module_dim_head=4,
            structure_module_refinement_iters=2,
            structure_module_knn=0,
            structure_module_adj_neighbors=2):
        super().__init__()
        assert num_backbone_atoms in {
            1, 3, 4
        }, 'must be either residue level, or reconstitute to atomic coordinates of 3 for the C, Ca, N of backbone, or 4 of C-beta as well'

        layers_sparse_attn = cast_tuple(sparse_self_attn, depth)

        self.token_emb = nn.Embedding(num_tokens, dim)

        # template embedding

        self.template_dist_emb = nn.Embedding(constants.DISTOGRAM_BUCKETS, dim)
        self.template_num_pos_emb = nn.Embedding(max_num_templates, dim)

        # projection for angles, if needed

        self.predict_angles = predict_angles
        self.symmetrize_omega = symmetrize_omega

        if predict_angles:
            self.to_prob_theta = nn.Linear(dim, constants.THETA_BUCKETS)
            self.to_prob_phi = nn.Linear(dim, constants.PHI_BUCKETS)
            self.to_prob_omega = nn.Linear(dim, constants.OMEGA_BUCKETS)

        # when predicting the coordinates, whether to return the other logits, distogram (and optionally, angles)

        self.return_aux_logits = return_aux_logits

        # template sidechain encoding

        self.use_se3_transformer = use_se3_transformer

        if use_se3_transformer:
            self.template_sidechain_emb = SE3TemplateEmbedder(dim=dim,
                                                              dim_head=dim,
                                                              heads=1,
                                                              num_neighbors=12,
                                                              depth=4,
                                                              input_degrees=2,
                                                              num_degrees=2,
                                                              output_degrees=1,
                                                              reversible=True)
        else:
            self.template_sidechain_emb = EnTransformer(
                dim=dim,
                dim_head=dim,
                heads=1,
                num_nearest_neighbors=32,
                depth=4)

        # custom embedding projection

        self.embedd_project = nn.Linear(num_embedds, dim)

        # main trunk modules

        prenorm = partial(PreNorm, dim)
        prenorm_cross = partial(PreNormCross, dim)

        layers = nn.ModuleList([])
        attn_types = islice(cycle(attn_types), depth)

        for ind, layer_sparse_attn, attn_type in zip(range(depth),
                                                     layers_sparse_attn,
                                                     attn_types):

            # alternate between row and column attention to save memory each layer

            row_attn = ind % 2 == 0
            col_attn = ind % 2 == 1

            # self attention, for main sequence, msa, and optionally, templates

            if attn_type == 'full':
                tensor_slice = None
                template_axial_attn = True
            elif attn_type == 'intra_attn':
                tensor_slice = None
                template_axial_attn = False
            elif attn_type == 'seq_only':
                tensor_slice = (slice(None), slice(0, 1))
                template_axial_attn = False
            else:
                raise ValueError(f'cannot find attention type {attn_type}')

            layers.append(
                nn.ModuleList([
                    prenorm(
                        InterceptAxialAttention(
                            tensor_slice,
                            AxialAttention(
                                dim=dim,
                                template_axial_attn=template_axial_attn,
                                seq_len=max_seq_len,
                                heads=heads,
                                dim_head=dim_head,
                                dropout=attn_dropout,
                                sparse_attn=sparse_self_attn,
                                row_attn=row_attn,
                                col_attn=col_attn,
                                rotary_rpe=True))),
                    prenorm(
                        InterceptFeedForward(
                            tensor_slice,
                            ff=FeedForward(dim=dim, dropout=ff_dropout))),
                    prenorm(
                        AxialAttention(dim=dim,
                                       seq_len=max_seq_len,
                                       heads=heads,
                                       dim_head=dim_head,
                                       dropout=attn_dropout,
                                       tie_row_attn=msa_tie_row_attn,
                                       row_attn=row_attn,
                                       col_attn=col_attn,
                                       rotary_rpe=True)),
                    prenorm(FeedForward(dim=dim, dropout=ff_dropout)),
                ]))

            # cross attention, for main sequence -> msa and then msa -> sequence

            intercept_fn = partial(InterceptAttention,
                                   (slice(None), slice(0, 1)))

            layers.append(
                nn.ModuleList([
                    intercept_fn(
                        context=False,
                        attn=prenorm_cross(
                            Attention(
                                dim=dim,
                                seq_len=max_seq_len,
                                heads=heads,
                                dim_head=dim_head,
                                dropout=attn_dropout,
                                compress_ratio=cross_attn_compress_ratio))),
                    prenorm(FeedForward(dim=dim, dropout=ff_dropout)),
                    intercept_fn(
                        context=True,
                        attn=prenorm_cross(
                            Attention(
                                dim=dim,
                                seq_len=max_seq_len,
                                heads=heads,
                                dim_head=dim_head,
                                dropout=attn_dropout,
                                compress_ratio=cross_attn_compress_ratio))),
                    prenorm(FeedForward(dim=dim, dropout=ff_dropout)),
                ]))

        if not reversible:
            layers = nn.ModuleList(list(
                map(lambda t: t[:3],
                    layers)))  # remove last feed forward if not reversible

        trunk_class = SequentialSequence if not reversible else ReversibleSequence
        self.net = trunk_class(layers)

        # to distogram output

        self.num_backbone_atoms = num_backbone_atoms
        needs_upsample = num_backbone_atoms > 1

        self.predict_real_value_distances = predict_real_value_distances
        dim_distance_pred = constants.DISTOGRAM_BUCKETS if not predict_real_value_distances else 2  # 2 for predicting mean and standard deviation values of real-value distance

        self.to_distogram_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Sequential(nn.Linear(dim, dim * (num_backbone_atoms**2)),
                          Rearrange('b h w c -> b c h w'),
                          nn.PixelShuffle(num_backbone_atoms),
                          Rearrange('b c h w -> b h w c'))
            if needs_upsample else nn.Identity(),
            nn.Linear(dim, dim_distance_pred))

        # to coordinate output

        self.predict_coords = predict_coords
        self.mds_iters = mds_iters
        self.structure_module_refinement_iters = structure_module_refinement_iters

        self.trunk_to_structure_dim = nn.Linear(dim, structure_module_dim)

        self.trunk_embeds_to_se3_edges = trunk_embeds_to_se3_edges
        self.se3_edges_fourier_encodings = se3_edges_fourier_encodings
        edge_dim = (
            (2 * se3_edges_fourier_encodings) + 1) * trunk_embeds_to_se3_edges

        self.to_equivariant_net_edges = nn.Linear(
            dim, trunk_embeds_to_se3_edges
        ) if trunk_embeds_to_se3_edges > 0 else None

        with torch_default_dtype(torch.float64):
            self.structure_module_embeds = nn.Embedding(
                num_tokens, structure_module_dim)
            self.atom_tokens_embed = nn.Embedding(len(ATOM_IDS),
                                                  structure_module_dim)

            if use_se3_transformer:
                self.structure_module = SE3TransformerWrapper(
                    dim=structure_module_dim,
                    depth=structure_module_depth,
                    input_degrees=1,
                    num_degrees=3,
                    output_degrees=2,
                    heads=structure_module_heads,
                    differentiable_coors=True,
                    num_neighbors=0,  # use only bonded neighbors for now
                    attend_sparse_neighbors=True,
                    edge_dim=edge_dim,
                    num_adj_degrees=structure_module_adj_neighbors,
                    adj_dim=4,
                )
            else:
                self.structure_module = EnTransformer(
                    dim=structure_module_dim,
                    depth=structure_module_depth,
                    heads=structure_module_heads,
                    fourier_features=2,
                    num_nearest_neighbors=0,
                    only_sparse_neighbors=True,
                    edge_dim=edge_dim,
                    num_adj_degrees=structure_module_adj_neighbors,
                    adj_dim=4)

        # aux confidence measure
        self.lddt_linear = nn.Linear(structure_module_dim, 1)
Example #21
    def __init__(
            self,
            num_channels=3,  # Overridden based on dataset.
            resolution=32,  # Overridden based on dataset.
            first_resolution=4,  # Overridden based on dataset.
            label_size=0,  # Overridden based on dataset.
            fmap_base=4096,
            fmap_decay=1.0,
            fmap_max=256,
            dim=3072,  # tunable parameter
            base_size=2,
            max_patch=32,
            channel=512,
            heads=8,
            dim_head=64,
            mlp_dim=2048,  # tunable
            dropout=0.0,
            latent_size=None,
            normalize_latents=True,
            use_wscale=True,
            use_pixelnorm=True,
            use_leakyrelu=True,
            use_batchnorm=False,
            tanh_at_end=None):
        super(Generator, self).__init__()
        self.num_channels = num_channels
        self.resolution = resolution
        self.label_size = label_size
        self.fmap_base = fmap_base
        self.fmap_decay = fmap_decay
        self.fmap_max = fmap_max
        self.latent_size = latent_size
        self.normalize_latents = normalize_latents
        self.use_wscale = use_wscale
        self.use_pixelnorm = use_pixelnorm
        self.use_leakyrelu = use_leakyrelu
        self.use_batchnorm = use_batchnorm
        self.tanh_at_end = tanh_at_end
        self.curr_resol = first_resolution

        R = int(np.log2(resolution))
        assert resolution == 2**R and resolution >= 4
        if latent_size is None:
            latent_size = self.get_nf(0)

        negative_slope = 0.2
        act = nn.LeakyReLU(negative_slope=negative_slope
                           ) if self.use_leakyrelu else nn.ReLU()
        iact = 'leaky_relu' if self.use_leakyrelu else 'relu'
        output_act = nn.Tanh() if self.tanh_at_end else 'linear'
        output_iact = 'tanh' if self.tanh_at_end else 'linear'

        pre = None
        lods = nn.ModuleList()
        nins = nn.ModuleList()
        layers = []

        if self.normalize_latents:
            pre = PixelNormLayer()

        if self.label_size:
            layers += [ConcatLayer()]

        # first block: the original conv block is replaced with an MLP, 512 -> 4096
        layers += [nn.Linear(latent_size, 4096)]  # b*4096 --> b*256*4*4
        # input reshape: b * dim --> b c h w
        curr_patchsize, curr_dim, curr_num = self.get_patch_dim(
            1, base_size, max_patch, latent_size // 2, first_resolution)
        print('curr_dim {}, curr_patchsize {}'.format(curr_dim,
                                                      curr_patchsize))
        ## b * 4096 -->b*256*4*4-->b*(2*2)*(256*2*2)
        layers += [
            Rearrange('b (c p1 h p2 w) -> b c (p1 h) (p2 w)',
                      p1=curr_patchsize,
                      p2=curr_patchsize,
                      c=latent_size // 2,
                      h=curr_num,
                      w=curr_num)
        ]
        net = G_transformer(layers, dim, heads, dim_head, dropout, mlp_dim,
                            curr_patchsize, latent_size // 2, curr_dim,
                            curr_num, num_channels,
                            num_channels * curr_patchsize * curr_patchsize)

        # lods.append(to_patch_embedding)
        # net =
        lods.append(net)
        # nins.append(NINLayer([], self.get_nf(1), self.num_channels, output_act, output_iact, None, True, self.use_wscale))  # to_rgb layer
        nins.append(
            NIN_transformer([], dim, heads, dim_head, dropout, mlp_dim,
                            curr_patchsize, num_channels, curr_num, curr_dim,
                            num_channels * curr_patchsize * curr_patchsize))

        # nins.append(to_patch_image)
        for I in range(2, R):  # following blocks
            # ic, oc = self.get_nf(I-1), self.get_nf(I)
            curr_patchsize, curr_dim, curr_num = self.get_patch_dim(
                I, base_size, max_patch, num_channels, first_resolution)
            print('following curr_dim {}, curr_patchsize {}'.format(
                curr_dim, curr_patchsize))
            layers = [nn.Upsample(scale_factor=2, mode='nearest')]  # upsample

            layers = G_transformer(layers, dim, heads, dim_head, dropout,
                                   mlp_dim, curr_patchsize, num_channels,
                                   curr_dim, curr_num, num_channels, curr_dim)
            # layers = G_conv(layers, ic, oc, 3, 1, act, iact, negative_slope, False, self.use_wscale, self.use_batchnorm, self.use_pixelnorm)

            # net = G_conv(layers, oc, oc, 3, 1, act, iact, negative_slope, True, self.use_wscale, self.use_batchnorm, self.use_pixelnorm)
            net = layers
            lods.append(net)
            nins.append(
                NIN_transformer([], dim, heads, dim_head, dropout, mlp_dim,
                                curr_patchsize, num_channels, curr_num,
                                curr_dim, curr_dim))
            # nins.append(NINLayer([], oc, self.num_channels, output_act, output_iact, None, True, self.use_wscale))  # to_rgb layer

        self.output_layer = GSelectLayer(pre, lods, nins)
Example #22
    def __init__(self, dev):
        super(scattering, self).__init__()

        self.rearange0 = Rearrange(' a b c d -> a c (b d)')
        self.dev = dev
Example #23
small_pic_batch = torch.ones((1, 1, 6, 6))
small_pic_batch = small_pic_batch.view((1, 1, 6 * 6))

w_query = nn.Linear(6 * 6, 6 * 6, bias=False)
w_key = nn.Linear(6 * 6, 6 * 6, bias=False)
w_value = nn.Linear(6 * 6, 6 * 6, bias=False)

w_query.eval()
w_key.eval()
w_value.eval()

# att
soft = nn.Softmax(dim=4)

mem_blocks_divider = Rearrange('b c (h p1) (w p2) -> b c (h w) (p1 p2)',
                               p1=2,
                               p2=2)

# TODO Optimize
query = mem_blocks_divider(w_query(small_pic_batch).view((1, 1, 6, 6))).view(
    (1, 1, 9, 4, 1))
key = mem_blocks_divider(w_key(small_pic_batch).view((1, 1, 6, 6)))
value = mem_blocks_divider(w_value(small_pic_batch).view((1, 1, 6, 6))).view(
    (1, 1, 9, 4, 1))

# q(i,j) * k(a,b) where i,j are pixel col and row AND a,b are memory block size
att_score = soft(torch.einsum('b c n g q, b c n p ->b c n g p', query, key))

# print(torch.matmul(att_score, value).view(1,1,9*4))

# avg pooling
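For clarity, the mem_blocks_divider pattern used above groups the 6x6 image into nine 2x2 memory blocks; a standalone check with the same shapes as the snippet:

import torch
from einops.layers.torch import Rearrange

divider = Rearrange('b c (h p1) (w p2) -> b c (h w) (p1 p2)', p1=2, p2=2)
x = torch.ones(1, 1, 6, 6)
assert divider(x).shape == (1, 1, 9, 4)   # 9 blocks of 2*2 = 4 pixels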
Example #24
    def __init__(self,
                 image_size,
                 channels,
                 num_classes,
                 patch_size_small=14,
                 patch_size_large=16,
                 small_dim=96,
                 large_dim=192,
                 small_depth=1,
                 large_depth=4,
                 cross_attn_depth=1,
                 multi_scale_enc_depth=3,
                 heads=3,
                 pool='cls',
                 dropout=0.,
                 emb_dropout=0.,
                 scale_dim=4):
        super().__init__()

        assert image_size % patch_size_small == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches_small = (image_size // patch_size_small)**2
        patch_dim_small = channels * patch_size_small**2

        assert image_size % patch_size_large == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches_large = (image_size // patch_size_large)**2
        patch_dim_large = channels * patch_size_large**2
        assert pool in {
            'cls', 'mean'
        }, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding_small = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
                      p1=patch_size_small,
                      p2=patch_size_small),
            nn.Linear(patch_dim_small, small_dim),
        )

        self.to_patch_embedding_large = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
                      p1=patch_size_large,
                      p2=patch_size_large),
            nn.Linear(patch_dim_large, large_dim),
        )

        self.pos_embedding_small = nn.Parameter(
            torch.randn(1, num_patches_small + 1, small_dim))
        self.cls_token_small = nn.Parameter(torch.randn(1, 1, small_dim))
        self.dropout_small = nn.Dropout(emb_dropout)

        self.pos_embedding_large = nn.Parameter(
            torch.randn(1, num_patches_large + 1, large_dim))
        self.cls_token_large = nn.Parameter(torch.randn(1, 1, large_dim))
        self.dropout_large = nn.Dropout(emb_dropout)

        self.multi_scale_transformers = nn.ModuleList([])
        for _ in range(multi_scale_enc_depth):
            self.multi_scale_transformers.append(
                MultiScaleTransformerEncoder(
                    small_dim=small_dim,
                    small_depth=small_depth,
                    small_heads=heads,
                    small_dim_head=small_dim // heads,
                    small_mlp_dim=small_dim * scale_dim,
                    large_dim=large_dim,
                    large_depth=large_depth,
                    large_heads=heads,
                    large_dim_head=large_dim // heads,
                    large_mlp_dim=large_dim * scale_dim,
                    cross_attn_depth=cross_attn_depth,
                    cross_attn_heads=heads,
                    dropout=dropout))

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head_small = nn.Sequential(nn.LayerNorm(small_dim),
                                            nn.Linear(small_dim, num_classes))

        self.mlp_head_large = nn.Sequential(nn.LayerNorm(large_dim),
                                            nn.Linear(large_dim, num_classes))
Example #25
                              x_val,
                              y_val,
                              clf=dummy_clf,
                              dataset_name=dataset_name))

        print("DONE:", data_dir)
        return results
    except:
        return []
    # print("\t", pd.DataFrame(results[2*i:2*i+2]))


if __name__ == "__main__":
    data_dirs = sorted(list(glob.glob('datasets/*')))

    transform = tfms.Compose([
        tfms.Grayscale(),
        tfms.Resize(128, interpolation=2),
        tfms.RandomCrop(112),
        tfms.ToTensor(),
        Rearrange('h w c -> (h w c)'),
    ])

    print(len(data_dirs))
    result = Parallel(n_jobs=5)(delayed(parfunc)(data_dir)
                                for data_dir in data_dirs)

    with open("svm_dummy.txt", "wb") as fp:
        pickle.dump(result, fp)

    pd.DataFrame(result).to_csv('svm_dummy_results.csv')
Example #26
    def __init__(self,
                 *,
                 num_classes,
                 s1_emb_dim=64,
                 s1_patch_size=4,
                 s1_local_patch_size=7,
                 s1_global_k=7,
                 s1_depth=1,
                 s2_emb_dim=128,
                 s2_patch_size=2,
                 s2_local_patch_size=7,
                 s2_global_k=7,
                 s2_depth=1,
                 s3_emb_dim=256,
                 s3_patch_size=2,
                 s3_local_patch_size=7,
                 s3_global_k=7,
                 s3_depth=5,
                 s4_emb_dim=512,
                 s4_patch_size=2,
                 s4_local_patch_size=7,
                 s4_global_k=7,
                 s4_depth=4,
                 peg_kernel_size=3,
                 dropout=0.):
        super().__init__()
        kwargs = dict(locals())

        dim = 3
        layers = []

        for prefix in ('s1', 's2', 's3', 's4'):
            config, kwargs = group_by_key_prefix_and_remove_prefix(
                f'{prefix}_', kwargs)
            is_last = prefix == 's4'

            dim_next = config['emb_dim']

            layers.append(
                nn.Sequential(
                    PatchEmbedding(dim=dim,
                                   dim_out=dim_next,
                                   patch_size=config['patch_size']),
                    Transformer(dim=dim_next,
                                depth=1,
                                local_patch_size=config['local_patch_size'],
                                global_k=config['global_k'],
                                dropout=dropout,
                                has_local=not is_last),
                    PEG(dim=dim_next, kernel_size=peg_kernel_size),
                    Transformer(dim=dim_next,
                                depth=config['depth'],
                                local_patch_size=config['local_patch_size'],
                                global_k=config['global_k'],
                                dropout=dropout,
                                has_local=not is_last)))

            dim = dim_next

        self.layers = nn.Sequential(*layers, nn.AdaptiveAvgPool2d(1),
                                    Rearrange('... () () -> ...'),
                                    nn.Linear(dim, num_classes))
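The final Rearrange('... () () -> ...') above removes the two singleton spatial axes left by nn.AdaptiveAvgPool2d(1); a minimal sketch with assumed feature sizes:

import torch
from torch import nn
from einops.layers.torch import Rearrange

head = nn.Sequential(nn.AdaptiveAvgPool2d(1), Rearrange('... () () -> ...'))
feats = torch.rand(2, 512, 7, 7)
assert head(feats).shape == (2, 512)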
Example #27
    def __init__(
        self,
        *,
        dim=(64, 128, 256, 512),
        depth=(2, 2, 8, 2),
        window_size=7,
        num_classes=1000,
        tokenize_local_3_conv=False,
        local_patch_size=4,
        use_peg=False,
        attn_dropout=0.,
        ff_dropout=0.,
        channels=3,
    ):
        super().__init__()
        dim = cast_tuple(dim, 4)
        depth = cast_tuple(depth, 4)
        assert len(
            dim) == 4, 'dim needs to be a single value or a tuple of length 4'
        assert len(
            depth
        ) == 4, 'depth needs to be a single value or a tuple of length 4'

        self.local_patch_size = local_patch_size

        region_patch_size = local_patch_size * window_size
        self.region_patch_size = local_patch_size * window_size

        init_dim, *_, last_dim = dim

        # local and region encoders

        if tokenize_local_3_conv:
            self.local_encoder = nn.Sequential(
                nn.Conv2d(3, init_dim, 3, 2, 1), nn.LayerNorm(init_dim),
                nn.GELU(), nn.Conv2d(init_dim, init_dim, 3, 2, 1),
                nn.LayerNorm(init_dim), nn.GELU(),
                nn.Conv2d(init_dim, init_dim, 3, 1, 1))
        else:
            self.local_encoder = nn.Conv2d(3, init_dim, 8, 4, 3)

        self.region_encoder = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w',
                      p1=region_patch_size,
                      p2=region_patch_size),
            nn.Conv2d((region_patch_size**2) * channels, init_dim, 1))

        # layers

        current_dim = init_dim
        self.layers = nn.ModuleList([])

        for ind, dim, num_layers in zip(range(4), dim, depth):
            not_first = ind != 0
            need_downsample = not_first
            need_peg = not_first and use_peg

            self.layers.append(
                nn.ModuleList([
                    Downsample(current_dim, dim)
                    if need_downsample else nn.Identity(),
                    PEG(dim) if need_peg else nn.Identity(),
                    R2LTransformer(dim,
                                   depth=num_layers,
                                   window_size=window_size,
                                   attn_dropout=attn_dropout,
                                   ff_dropout=ff_dropout)
                ]))

            current_dim = dim

        # final logits

        self.to_logits = nn.Sequential(Reduce('b c h w -> b c', 'mean'),
                                       nn.LayerNorm(last_dim),
                                       nn.Linear(last_dim, num_classes))
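A quick shape check for the region encoder's Rearrange; with the defaults above, region_patch_size = local_patch_size * window_size = 4 * 7 = 28, and the 224x224 input is just an illustrative resolution that happens to be divisible by 28.

import torch
from einops.layers.torch import Rearrange

to_regions = Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1=28, p2=28)
x = torch.randn(1, 3, 224, 224)
# Each 28x28 region is folded into the channel axis: 3 * 28 * 28 = 2352 channels on an 8x8 grid.
print(to_regions(x).shape)  # torch.Size([1, 2352, 8, 8])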
Exemplo n.º 28
0
    def __init__(self,
                 *,
                 image_size,
                 num_classes,
                 dim,
                 depth=None,
                 heads=None,
                 mlp_dim=None,
                 pool='cls',
                 channels=3,
                 dim_head=64,
                 dropout=0.,
                 emb_dropout=0.,
                 transformer=None,
                 t2t_layers=((7, 4), (3, 2), (3, 2))):
        super().__init__()
        assert pool in {
            'cls', 'mean'
        }, 'pool type must be either cls (cls token) or mean (mean pooling)'

        layers = []
        layer_dim = channels
        output_image_size = image_size

        for i, (kernel_size, stride) in enumerate(t2t_layers):
            layer_dim *= kernel_size**2
            is_first = i == 0
            is_last = i == (len(t2t_layers) - 1)
            output_image_size = conv_output_size(output_image_size,
                                                 kernel_size, stride,
                                                 stride // 2)

            layers.extend([
                RearrangeImage() if not is_first else nn.Identity(),
                nn.Unfold(kernel_size=kernel_size,
                          stride=stride,
                          padding=stride // 2),
                Rearrange('b c n -> b n c'),
                Transformer(dim=layer_dim,
                            heads=1,
                            depth=1,
                            dim_head=layer_dim,
                            mlp_dim=layer_dim,
                            dropout=dropout) if not is_last else nn.Identity(),
            ])

        layers.append(nn.Linear(layer_dim, dim))
        self.to_patch_embedding = nn.Sequential(*layers)

        self.pos_embedding = nn.Parameter(
            torch.randn(1, output_image_size**2 + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        if not exists(transformer):
            assert all([exists(depth),
                        exists(heads),
                        exists(mlp_dim)
                        ]), 'depth, heads, and mlp_dim must be supplied'
            self.transformer = Transformer(dim, depth, heads, dim_head,
                                           mlp_dim, dropout)
        else:
            self.transformer = transformer

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(nn.LayerNorm(dim),
                                      nn.Linear(dim, num_classes))
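To make the pos_embedding size concrete, here is a sketch of the output_image_size bookkeeping; conv_output_size is re-implemented with the standard convolution arithmetic, which is assumed to match the repository helper, and image_size=224 is an illustrative input.

def conv_output_size(size, kernel_size, stride, padding=0):
    # Standard convolution/unfold output-size formula.
    return (size + 2 * padding - kernel_size) // stride + 1

size = 224
for kernel_size, stride in ((7, 4), (3, 2), (3, 2)):
    size = conv_output_size(size, kernel_size, stride, stride // 2)
    print(size)  # 56, then 28, then 14

# pos_embedding therefore holds output_image_size**2 + 1 = 14**2 + 1 = 197 positions.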
Exemplo n.º 29
0
    def __init__(
            self,
            image_size,
            patch_size,
            num_classes,  # total number of classes
            Dim,
            depth,
            heads,
            mlp_dim,
            pool='cls',
            channels=3,
            dim_head=64,
            dropout=0.,
            emb_dropout=0.1):
        super(VisionTransformer, self).__init__()
        image_height, image_width = pair(
            image_size)  #image_size=256 -> image_height, image_width = 256
        patch_height, patch_width = pair(
            patch_size)  #patch_size=32 -> patch_height, patch_width = 32

        # the image size must be divisible by the patch size, otherwise raise an error
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        # how many patches one image is split into; here: 8*8 = 64 patches per image
        num_patches = (image_height // patch_height) * (image_width //
                                                        patch_width)
        patch_dim = channels * patch_height * patch_width  # 3*32*32 = 3072, the length of each patch once flattened to 1-D
        #print(patch_dim)

        assert pool in {
            'cls', 'mean'
        }, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # turn the whole image into patch embeddings
        self.to_patch_embedding = nn.Sequential(
            #https://blog.csdn.net/csdn_yi_e/article/details/109143580
            # rearrange the tensor according to the given pattern; the letters in the
            # pattern are only placeholders and carry no intrinsic meaning
            Rearrange(
                'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
                p1=patch_height,
                p2=patch_width
            ),  # torch.Size([B, 64, 3072]): the image is split into 64 patches, each flattened to 1-D
            nn.Linear(
                patch_dim, Dim
            ),  # torch.Size([B, 64, 1024]): linear projection of each patch down to Dim=1024, giving the patch embeddings
        )

        self.cls_token = nn.Parameter(torch.randn(
            1, 1, Dim))  #torch.Size([1, 1, 1024])
        #print(self.cls_token.shape)
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Encoder(
            num_layers=depth,  #6
            mlp_dim=mlp_dim,  #2048
            dropout_rate=dropout,  #0.1
            heads=heads,  #16
            dim_head=dim_head,  #64
            Dim=Dim  #1024
        )

        self.pool = pool
        self.to_latent = nn.Identity(
        )  #https://blog.csdn.net/artistkeepmonkey/article/details/115067356

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(Dim),
            nn.Linear(Dim, num_classes)  # torch.Size([B, Dim]) -> torch.Size([B, num_classes])
        )
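A self-contained check of the patch-embedding shapes described in the comments above (image_size=256, patch_size=32, Dim=1024); only the Rearrange and Linear layers are needed for this.

import torch
from torch import nn
from einops.layers.torch import Rearrange

to_patch_embedding = nn.Sequential(
    Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=32, p2=32),
    nn.Linear(3 * 32 * 32, 1024),
)
x = torch.randn(2, 3, 256, 256)
# 256/32 = 8 patches per side -> 64 patches, each projected from 3072-D down to 1024-D.
print(to_patch_embedding(x).shape)  # torch.Size([2, 64, 1024])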
    def __init__(self, in_channels: int, patch_size: tuple, embbeding_dim: int):
        super(LinePatchEmbbeding, self).__init__()
        self.rearrange = Rearrange('b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=patch_size[0], s2=patch_size[1])
        self.linear = nn.Linear(patch_size[0] * patch_size[1] * in_channels, embbeding_dim)