def _show_batch(
    self,
    batch: List[torch.Tensor],
    sample_length: int,
    mean: Tuple[int, int, int] = DEFAULT_MEAN,
    std: Tuple[int, int, int] = DEFAULT_STD,
) -> None:
    """
    Display a batch of images.

    Args:
        batch: List of sample (clip) tensors
        sample_length: Number of frames to show for each sample
        mean: Normalization mean
        std: Normalization std-dev
    """
    batch_size = len(batch)
    # squeeze=False keeps axs a 2-D array even when batch_size or
    # sample_length is 1, so the indexing below stays uniform.
    fig, axs = plt.subplots(
        batch_size,
        sample_length,
        figsize=(4 * sample_length, 3 * batch_size),
        squeeze=False,
    )

    for i, row in enumerate(axs):
        clip = Rearrange("c t h w -> t c h w")(batch[i])
        for j, a in enumerate(row):
            a.axis("off")
            a.imshow(
                np.moveaxis(denormalize(clip[j], mean, std).numpy(), 0, -1)
            )
    fig.tight_layout()
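# A minimal sketch of the `denormalize` helper used above (an assumption:
# it inverts torchvision-style Normalize, i.e. x * std + mean per channel);
# the project's real helper may differ.
from typing import Tuple

import torch


def denormalize(frame: torch.Tensor,
                mean: Tuple[float, float, float],
                std: Tuple[float, float, float]) -> torch.Tensor:
    mean_t = torch.tensor(mean, dtype=frame.dtype).view(-1, 1, 1)
    std_t = torch.tensor(std, dtype=frame.dtype).view(-1, 1, 1)
    return frame * std_t + mean_t  # undo (x - mean) / std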
def __init__(self,
             *,
             image_size,
             patch_size,
             num_classes,
             dim,
             depth,
             ff_mult=4,
             channels=3,
             attn_dim=None,
             prob_survival=1.):
    super().__init__()
    assert (image_size % patch_size) == 0, 'image size must be divisible by the patch size'
    dim_ff = dim * ff_mult
    num_patches = (image_size // patch_size) ** 2

    self.to_patch_embed = nn.Sequential(
        Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)',
                  p1=patch_size, p2=patch_size),
        nn.Linear(channels * patch_size ** 2, dim))

    self.prob_survival = prob_survival

    self.layers = nn.ModuleList([
        Residual(PreNorm(dim, gMLPBlock(dim=dim,
                                        dim_ff=dim_ff,
                                        seq_len=num_patches,
                                        attn_dim=attn_dim)))
        for _ in range(depth)
    ])

    self.to_logits = nn.Sequential(nn.LayerNorm(dim),
                                   Reduce('b n d -> b d', 'mean'),
                                   nn.Linear(dim, num_classes))
def __init__(self,
             *,
             image_size,
             patch_size,
             num_classes,
             dim,
             depth,
             heads,
             mlp_dim,
             pool='cls',
             channels=3,
             dim_head=64,
             dropout=0.,
             emb_dropout=0.,
             t2t_layers=((7, 4), (3, 2), (3, 2))):
    super().__init__()
    assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
    num_patches = (image_size // patch_size) ** 2
    patch_dim = channels * patch_size ** 2
    assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

    layers = []
    layer_dim = channels

    for i, (kernel_size, stride) in enumerate(t2t_layers):
        layer_dim *= kernel_size ** 2
        is_first = i == 0
        layers.extend([
            RearrangeImage() if not is_first else nn.Identity(),
            nn.Unfold(kernel_size=kernel_size, stride=stride, padding=stride // 2),
            Rearrange('b c n -> b n c'),
            Transformer(dim=layer_dim, heads=1, depth=1, dim_head=layer_dim,
                        mlp_dim=layer_dim, dropout=dropout),
        ])

    layers.append(nn.Linear(layer_dim, dim))
    self.to_patch_embedding = nn.Sequential(*layers)

    self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
    self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
    self.dropout = nn.Dropout(emb_dropout)

    self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

    self.pool = pool
    self.to_latent = nn.Identity()

    self.mlp_head = nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, num_classes)
    )
def __init__(
    self,
    config: Config,
    forward_model: Optional[ForwardModel] = None,
) -> None:
    super().__init__()
    # self.save_hyperparameters()
    self.config = config
    # self.config["num_wavelens"] = len(
    #     torch.load(Path("/data-new/alok/laser/data.pt"))["interpolated_wavelength"][0]
    # )

    self.forward_model = forward_model
    if self.forward_model is not None:
        self.forward_model.freeze()

    self.trunk = nn.Sequential(
        Rearrange("b c -> b c 1 1"),
        MLPMixer(
            in_channels=self.config["num_wavelens"],
            image_size=1,
            patch_size=1,
            num_classes=1_000,
            dim=512,
            depth=8,
            token_dim=256,
            channel_dim=2048,
            dropout=0.5,
        ),
        nn.Flatten(),
    )
    self.continuous_head = nn.LazyLinear(2)
    self.discrete_head = nn.LazyLinear(12)

    # XXX This call *must* happen to initialize the lazy layers.
    _dummy_input = torch.rand(2, self.config["num_wavelens"])
    self.forward(_dummy_input)
def __init__(self,
             *,
             image_size,
             patch_size,
             num_classes,
             dim,
             depth,
             heads,
             mlp_dim,
             pool='cls',
             channels=3,
             dim_head=64,
             dropout=0.,
             emb_dropout=0.):
    super().__init__()
    assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
    num_patches = (image_size // patch_size) ** 2
    patch_dim = channels * patch_size ** 2
    assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

    self.to_patch_embedding = nn.Sequential(
        Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
                  p1=patch_size, p2=patch_size),
        nn.Linear(patch_dim, dim),
    )

    self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
    self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
    self.dropout = nn.Dropout(emb_dropout)

    self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

    self.pool = pool
    self.to_latent = nn.Identity()
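# A hedged, self-contained sanity check of the patch embedding above: the
# Rearrange pattern cuts a (b, c, H, W) image into (H/p1)*(W/p2) flattened
# patches before the Linear projection. The numbers are illustrative only.
import torch
from torch import nn
from einops.layers.torch import Rearrange

_to_patch_embedding = nn.Sequential(
    Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=16, p2=16),
    nn.Linear(3 * 16 ** 2, 1024),
)
_x = torch.randn(2, 3, 256, 256)
assert _to_patch_embedding(_x).shape == (2, (256 // 16) ** 2, 1024)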
def __init__(self,
             dim,
             heads,
             row_attn=True,
             col_attn=True,
             accept_edges=False,
             global_query_attn=False,
             **kwargs):
    super().__init__()
    assert not (not row_attn and not col_attn), 'row or column attention must be turned on'

    self.row_attn = row_attn
    self.col_attn = col_attn
    self.global_query_attn = global_query_attn

    self.norm = nn.LayerNorm(dim)
    self.attn = Attention(dim=dim, heads=heads, **kwargs)

    self.edges_to_attn_bias = nn.Sequential(
        nn.Linear(dim, heads, bias=False),
        Rearrange('b i j h -> b h i j')) if accept_edges else None
def __init__(self,
             size,
             patch_size,
             in_channels,
             embedding_dim,
             depth=1,
             heads=8,
             dim_head=8,
             dim_mlp=64):
    super(ConvTokenConvTransformer, self).__init__()
    self.size = size
    self.patch_size = patch_size
    self.in_channels = in_channels
    self.embedding_dim = embedding_dim
    self.depth = depth
    self.heads = heads
    self.dim_head = dim_head
    self.dim_mlp = dim_mlp

    self.patch_height = size[0] // patch_size[0]
    self.patch_width = size[1] // patch_size[1]

    self.embedding = ConvPatchEmbbeding(in_channels=self.in_channels,
                                        patch_size=self.patch_size,
                                        embbeding_dim=self.embedding_dim)
    self.layernorm = nn.LayerNorm(self.embedding_dim)
    self.rearrange = Rearrange('b (h w) d -> b d h w',
                               h=self.patch_height,
                               w=self.patch_width)
    self.attention = SelfConvTransfomer(size=(self.patch_height, self.patch_width),
                                        in_channels=self.embedding_dim,
                                        depth=self.depth,
                                        heads=self.heads,
                                        dim_head=self.dim_head,
                                        dim_mlp=self.dim_mlp)
def __init__(self,
             image_size=224,
             tokens_type='performer',
             in_chans=3,
             embed_dim=768,
             dropout=0.1,
             t2t_layers=((7, 4), (3, 2), (3, 2))):
    super().__init__()
    layers = []
    layer_dim = in_chans
    output_image_size = image_size

    for i, (kernel_size, stride) in enumerate(t2t_layers):
        layer_dim *= kernel_size ** 2
        is_first = i == 0
        output_image_size = conv_output_size(output_image_size, kernel_size,
                                             stride, stride // 2)
        layers.extend([
            RearrangeImage() if not is_first else nn.Identity(),
            nn.Unfold(kernel_size=kernel_size, stride=stride, padding=stride // 2),
            Rearrange('b c n -> b n c'),
            Transformer(dim=layer_dim, heads=1, depth=1, dim_head=layer_dim,
                        mlp_dim=layer_dim, dropout=dropout)
            if tokens_type == 'transformer' else
            Performer(dim=layer_dim, inner_dim=layer_dim, kernel_ratio=0.5),
        ])

    layers.append(nn.Linear(layer_dim, embed_dim))
    self.output_image_size = output_image_size
    self.to_patch_embedding = nn.Sequential(*layers)
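# `conv_output_size` is used above but not defined in this snippet; a minimal
# sketch under the assumption that it is the standard convolution output-size
# arithmetic (the same formula nn.Unfold follows):
def conv_output_size(image_size, kernel_size, stride, padding=0):
    return int(((image_size - kernel_size + (2 * padding)) / stride) + 1)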
def __init__(self, *, image_size, patch_size, num_classes, dim, transformer,
             pool='cls', channels=3):
    super().__init__()
    image_size_h, image_size_w = pair(image_size)
    assert image_size_h % patch_size == 0 and image_size_w % patch_size == 0, \
        'image dimensions must be divisible by the patch size'
    assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

    num_patches = (image_size_h // patch_size) * (image_size_w // patch_size)
    patch_dim = channels * patch_size ** 2

    self.to_patch_embedding = nn.Sequential(
        Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
                  p1=patch_size, p2=patch_size),
        nn.Linear(patch_dim, dim),
    )

    self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
    self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
    self.transformer = transformer

    self.pool = pool
    self.to_latent = nn.Identity()

    self.mlp_head = nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, num_classes)
    )
def __init__(self,
             *,
             image_size,
             patch_size,
             num_classes,
             dim,
             depth,
             heads,
             mlp_dim,
             channels=3,
             dim_head=64,
             dropout=0.,
             emb_dropout=0.,
             use_rotary=True,
             use_ds_conv=True,
             use_glu=True):
    super().__init__()
    assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
    num_patches = (image_size // patch_size) ** 2
    patch_dim = channels * patch_size ** 2

    self.patch_size = patch_size

    self.to_patch_embedding = nn.Sequential(
        Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
                  p1=patch_size, p2=patch_size),
        nn.Linear(patch_dim, dim),
    )

    self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
    self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim,
                                   dropout, use_rotary, use_ds_conv, use_glu)

    self.mlp_head = nn.Sequential(nn.LayerNorm(dim),
                                  nn.Linear(dim, num_classes))
def G_transformer(incoming,
                  dim,
                  heads,
                  dim_head,
                  dropout,
                  mlp_dim,
                  curr_patchsize,
                  channel,
                  curr_dim,
                  curr_num,
                  num_channels,
                  to_dim,
                  to_sequential=True,
                  use_wscale=True,
                  use_pixelnorm=True):
    layers = incoming
    layers += [
        # attention input: b (h w) (p1 p2 c)
        Rearrange('b c (p1 h) (p2 w) -> b (h w) (p1 p2 c)',
                  p1=curr_patchsize, p2=curr_patchsize,
                  c=channel, h=curr_num, w=curr_num),
        # Residual(...) wrappers around the two blocks below are a
        # commented-out alternative, as is a LeakyReLU/GELU activation.
        Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout),
        FeedForward(dim, mlp_dim, dropout=dropout),
        # output: b c h w
        Rearrange('b (h w) (p1 p2 c) -> b c (p1 h) (p2 w)',
                  p1=curr_patchsize, p2=curr_patchsize,
                  c=channel, h=curr_num, w=curr_num)
    ]

    if use_pixelnorm:
        layers += [PixelNormLayer()]

    if to_sequential:
        return nn.Sequential(*layers)
    return layers
def __init__(self,
             *,
             image_size,
             patch_dim,
             pixel_dim,
             patch_size,
             pixel_size,
             depth,
             num_classes,
             heads=8,
             dim_head=64,
             ff_dropout=0.,
             attn_dropout=0.,
             unfold_args=None):
    super().__init__()
    assert divisible_by(image_size, patch_size), 'image size must be divisible by patch size'
    assert divisible_by(patch_size, pixel_size), 'patch size must be divisible by pixel size for now'

    num_patch_tokens = (image_size // patch_size) ** 2

    self.image_size = image_size
    self.patch_size = patch_size
    self.patch_tokens = nn.Parameter(torch.randn(num_patch_tokens + 1, patch_dim))

    unfold_args = default(unfold_args, (pixel_size, pixel_size, 0))
    unfold_args = (*unfold_args, 0) if len(unfold_args) == 2 else unfold_args
    kernel_size, stride, padding = unfold_args

    pixel_width = unfold_output_size(patch_size, kernel_size, stride, padding)
    num_pixels = pixel_width ** 2

    self.to_pixel_tokens = nn.Sequential(
        Rearrange('b c (h p1) (w p2) -> (b h w) c p1 p2',
                  p1=patch_size, p2=patch_size),
        nn.Unfold(kernel_size=kernel_size, stride=stride, padding=padding),
        Rearrange('... c n -> ... n c'),
        nn.Linear(3 * kernel_size ** 2, pixel_dim))

    self.patch_pos_emb = nn.Parameter(torch.randn(num_patch_tokens + 1, patch_dim))
    self.pixel_pos_emb = nn.Parameter(torch.randn(num_pixels, pixel_dim))

    layers = nn.ModuleList([])
    for _ in range(depth):
        pixel_to_patch = nn.Sequential(
            RMSNorm(pixel_dim),
            Rearrange('... n d -> ... (n d)'),
            nn.Linear(pixel_dim * num_pixels, patch_dim),
        )

        layers.append(
            nn.ModuleList([
                PreNorm(pixel_dim, Attention(dim=pixel_dim, heads=heads,
                                             dim_head=dim_head,
                                             dropout=attn_dropout)),
                PreNorm(pixel_dim, FeedForward(dim=pixel_dim, dropout=ff_dropout)),
                pixel_to_patch,
                PreNorm(patch_dim, Attention(dim=patch_dim, heads=heads,
                                             dim_head=dim_head,
                                             dropout=attn_dropout)),
                PreNorm(patch_dim, FeedForward(dim=patch_dim, dropout=ff_dropout)),
            ]))

    self.layers = layers
    self.mlp_head = nn.Sequential(RMSNorm(patch_dim),
                                  nn.Linear(patch_dim, num_classes))
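# `unfold_output_size` above is likewise undefined in this snippet; a minimal
# sketch, assuming it is the standard unfold/convolution arithmetic (this is
# an assumption, not the project's own helper):
def unfold_output_size(size, kernel_size, stride, padding=0):
    return int(((size - kernel_size + (2 * padding)) / stride) + 1)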
def __init__(
    self,
    image_size=224,
    patch_size=16,
    channels=3,
    embedding_dim=768,
    hidden_dims=None,
    conv_patch=False,
    linear_patch=False,
    conv_stem=True,
    conv_stem_original=True,
    conv_stem_scaled_relu=False,
    position_embedding_dropout=None,
    cls_head=True,
):
    super(EmbeddingStem, self).__init__()

    assert sum([conv_patch, conv_stem, linear_patch]) == 1, \
        "Only one of three modes should be active"

    image_height, image_width = pair(image_size)
    patch_height, patch_width = pair(patch_size)

    assert (image_height % patch_height == 0
            and image_width % patch_width == 0), \
        "Image dimensions must be divisible by the patch size."

    assert not (conv_stem and cls_head), \
        "Cannot use [CLS] token approach with full conv stems for ViT"

    if linear_patch or conv_patch:
        self.grid_size = (
            image_height // patch_height,
            image_width // patch_width,
        )
        num_patches = self.grid_size[0] * self.grid_size[1]

        if cls_head:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embedding_dim))
            num_patches += 1

        # positional embedding; guard against the None default, which
        # nn.Dropout would otherwise reject
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches, embedding_dim))
        self.pos_drop = nn.Dropout(p=position_embedding_dropout or 0.0)

    if conv_patch:
        self.projection = nn.Sequential(
            nn.Conv2d(
                channels,
                embedding_dim,
                kernel_size=patch_size,
                stride=patch_size,
            ),
        )
    elif linear_patch:
        patch_dim = channels * patch_height * patch_width
        self.projection = nn.Sequential(
            Rearrange(
                'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
                p1=patch_height,
                p2=patch_width,
            ),
            nn.Linear(patch_dim, embedding_dim),
        )
    elif conv_stem:
        assert conv_stem_scaled_relu ^ conv_stem_original, \
            "Can use either the original or the scaled relu stem"

        if not isinstance(hidden_dims, list):
            raise ValueError("Cannot create stem without list of sizes")

        if conv_stem_original:
            # Conv stem from https://arxiv.org/pdf/2106.14881.pdf
            hidden_dims.insert(0, channels)
            modules = []
            for in_ch, out_ch in zip(hidden_dims[:-1], hidden_dims[1:]):
                modules.append(
                    nn.Conv2d(
                        in_ch,
                        out_ch,
                        kernel_size=3,
                        stride=2 if in_ch != out_ch else 1,
                        padding=1,
                        bias=False,
                    ))
                modules.append(nn.BatchNorm2d(out_ch))
                modules.append(nn.ReLU(inplace=True))

            modules.append(
                nn.Conv2d(
                    hidden_dims[-1],
                    embedding_dim,
                    kernel_size=1,
                    stride=1,
                ))
            self.projection = nn.Sequential(*modules)

        elif conv_stem_scaled_relu:
            # Conv stem from https://arxiv.org/pdf/2109.03810.pdf
            assert len(hidden_dims) == 1, \
                "Only one value for hidden_dim is allowed"
            mid_ch = hidden_dims[0]

            # fmt: off
            self.projection = nn.Sequential(
                nn.Conv2d(channels, mid_ch, kernel_size=7, stride=2,
                          padding=3, bias=False),
                nn.BatchNorm2d(mid_ch),
                nn.ReLU(inplace=True),
                nn.Conv2d(mid_ch, mid_ch, kernel_size=3, stride=1,
                          padding=1, bias=False),
                nn.BatchNorm2d(mid_ch),
                nn.ReLU(inplace=True),
                nn.Conv2d(mid_ch, mid_ch, kernel_size=3, stride=1,
                          padding=1, bias=False),
                nn.BatchNorm2d(mid_ch),
                nn.ReLU(inplace=True),
                nn.Conv2d(mid_ch, embedding_dim,
                          kernel_size=patch_size // 2,
                          stride=patch_size // 2),
            )
            # fmt: on
        else:
            raise ValueError("Undefined convolutional stem type")

    self.conv_stem = conv_stem
    self.conv_patch = conv_patch
    self.linear_patch = linear_patch
    self.cls_head = cls_head

    self._init_weights()
def __init__(self,
             *,
             image_size,
             patch_size,
             num_classes,
             dim,
             heads,
             num_hierarchies,
             block_repeats,
             mlp_mult=4,
             channels=3,
             dim_head=64,
             dropout=0.):
    super().__init__()
    assert (image_size % patch_size) == 0, 'Image dimensions must be divisible by the patch size.'
    num_patches = (image_size // patch_size) ** 2
    patch_dim = channels * patch_size ** 2
    fmap_size = image_size // patch_size
    blocks = 2 ** (num_hierarchies - 1)

    # sequence length is held constant across the hierarchy
    seq_len = (fmap_size // blocks) ** 2
    hierarchies = list(reversed(range(num_hierarchies)))
    mults = [2 ** i for i in reversed(hierarchies)]

    layer_heads = list(map(lambda t: t * heads, mults))
    layer_dims = list(map(lambda t: t * dim, mults))
    last_dim = layer_dims[-1]

    layer_dims = [*layer_dims, layer_dims[-1]]
    dim_pairs = zip(layer_dims[:-1], layer_dims[1:])

    self.to_patch_embedding = nn.Sequential(
        Rearrange('b c (h p1) (w p2) -> b (p1 p2 c) h w',
                  p1=patch_size, p2=patch_size),
        nn.Conv2d(patch_dim, layer_dims[0], 1),
    )

    block_repeats = cast_tuple(block_repeats, num_hierarchies)

    self.layers = nn.ModuleList([])

    for level, heads, (dim_in, dim_out), block_repeat in zip(
            hierarchies, layer_heads, dim_pairs, block_repeats):
        is_last = level == 0
        depth = block_repeat

        self.layers.append(
            nn.ModuleList([
                Transformer(dim_in, seq_len, depth, heads, mlp_mult, dropout),
                Aggregate(dim_in, dim_out) if not is_last else nn.Identity()
            ]))

    self.mlp_head = nn.Sequential(LayerNorm(last_dim),
                                  Reduce('b c h w -> b c', 'mean'),
                                  nn.Linear(last_dim, num_classes))
def DynamicPositionBias(dim):
    return nn.Sequential(
        nn.Linear(2, dim),
        nn.LayerNorm(dim),
        nn.ReLU(),
        nn.Linear(dim, dim),
        nn.LayerNorm(dim),
        nn.ReLU(),
        nn.Linear(dim, dim),
        nn.LayerNorm(dim),
        nn.ReLU(),
        nn.Linear(dim, 1),
        Rearrange('... () -> ...'))
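# A hedged usage sketch: DynamicPositionBias maps 2-D relative offsets
# (dy, dx) to a scalar bias per position pair, and the final Rearrange drops
# the trailing singleton axis. Offsets-as-input is an assumption based on
# the Linear(2, dim) first layer.
import torch

_bias_mlp = DynamicPositionBias(64)
_coords = torch.arange(-3, 4).float()
_rel_pos = torch.cartesian_prod(_coords, _coords)  # (49, 2) relative offsets
_bias = _bias_mlp(_rel_pos)                        # (49,) scalar biases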
# --------------------------------------------------------
import random
import math
import numpy as np
import os
import torchvision.transforms as transforms
from einops.layers.torch import Rearrange
import torch
from PIL import Image

input_size = 224
patch_height, patch_width = 16, 16
h, w = input_size // patch_height, input_size // patch_width

to_patch_embedding_2d = Rearrange(
    "(h p1) (w p2) -> (h w) (p1 p2)", p1=patch_height, p2=patch_width
)

_FOC_MASK_PATH = os.path.join("/disk1/data", "mask", "foc")


def haved_masked_lst(fname):
    mask_fname = os.path.join(_FOC_MASK_PATH, fname)
    image = Image.open(mask_fname)
    image = transforms.Resize([input_size, input_size])(image)
    image_arr = np.array(image)
    patch_image = to_patch_embedding_2d(torch.tensor(image_arr, dtype=torch.float32))
    image.close()

    index_lst = []
    for i in range(h * w):
        x = np.count_nonzero((patch_image[i, :] > 100))  # non-black elements
def __init__(self, size):
    self.rearrange = Rearrange("c (h p1) (w p2) -> (h w) (p1 p2 c)",
                               p1=size, p2=size)
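# A hedged usage sketch of the transform above: it patchifies a single
# (c, H, W) image (note: no batch axis) into (num_patches, patch_dim) rows.
import torch
from einops.layers.torch import Rearrange

_t = Rearrange("c (h p1) (w p2) -> (h w) (p1 p2 c)", p1=8, p2=8)
assert _t(torch.randn(3, 32, 32)).shape == (16, 192)  # (32/8)**2 patches of 8*8*3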
def __init__(self,
             *,
             image_size,
             patch_dim,
             pixel_dim,
             patch_size,
             pixel_size,
             depth,
             num_classes,
             heads=8,
             dim_head=64,
             ff_dropout=0.,
             attn_dropout=0.):
    super().__init__()
    assert image_size % patch_size == 0, 'image size must be divisible by patch size'
    assert patch_size % pixel_size == 0, 'patch size must be divisible by pixel size for now'

    num_patch_tokens = (image_size // patch_size) ** 2
    pixel_width = patch_size // pixel_size
    num_pixels = pixel_width ** 2

    self.image_size = image_size
    self.patch_size = patch_size
    self.patch_tokens = nn.Parameter(torch.randn(num_patch_tokens + 1, patch_dim))

    self.to_pixel_tokens = nn.Sequential(
        Rearrange('b c (h p1) (w p2) -> (b h w) c p1 p2',
                  p1=patch_size, p2=patch_size),
        nn.Unfold(pixel_size, stride=pixel_size),
        Rearrange('... c n -> ... n c'),
        # NOTE: the fan-in implies 4-channel input, since Unfold emits
        # c * pixel_size**2 values per pixel token
        nn.Linear(4 * pixel_size ** 2, pixel_dim))

    self.patch_pos_emb = nn.Parameter(torch.randn(num_patch_tokens + 1, patch_dim))
    self.pixel_pos_emb = nn.Parameter(torch.randn(num_pixels, pixel_dim))

    layers = nn.ModuleList([])
    for _ in range(depth):
        pixel_to_patch = nn.Sequential(
            nn.LayerNorm(pixel_dim),
            Rearrange('... n d -> ... (n d)'),
            nn.Linear(pixel_dim * num_pixels, patch_dim),
        )

        layers.append(
            nn.ModuleList([
                PreNorm(pixel_dim, Attention(dim=pixel_dim, heads=heads,
                                             dim_head=dim_head,
                                             dropout=attn_dropout)),
                PreNorm(pixel_dim, FeedForward(dim=pixel_dim, dropout=ff_dropout)),
                pixel_to_patch,
                PreNorm(patch_dim, Attention(dim=patch_dim, heads=heads,
                                             dim_head=dim_head,
                                             dropout=attn_dropout)),
                PreNorm(patch_dim, FeedForward(dim=patch_dim, dropout=ff_dropout)),
            ]))

    self.layers = layers
    self.mlp_head = nn.Sequential(nn.LayerNorm(patch_dim),
                                  nn.Linear(patch_dim, num_classes))
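# Hedged shape check for to_pixel_tokens above; the numbers are illustrative,
# with a 4-channel input to match the Linear's 4 * pixel_size**2 fan-in.
import torch
from torch import nn
from einops.layers.torch import Rearrange

_pixel_size, _patch_size = 4, 16
_to_pixel = nn.Sequential(
    Rearrange('b c (h p1) (w p2) -> (b h w) c p1 p2',
              p1=_patch_size, p2=_patch_size),
    nn.Unfold(_pixel_size, stride=_pixel_size),
    Rearrange('... c n -> ... n c'),
    nn.Linear(4 * _pixel_size ** 2, 100))
_x = torch.randn(2, 4, 64, 64)
# (2*4*4) patches, each with (16/4)**2 = 16 pixel tokens of dim 100
assert _to_pixel(_x).shape == (32, 16, 100)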
def dataset_with_indices(cls):
    """
    Returns a modified class cls, which returns tuples like
    (X, y, indices) instead of just (X, y).
    """

    def __getitem__(self, index):
        data, target = cls.__getitem__(self, index)
        return data, target, index

    return type(cls.__name__, (cls, ), {"__getitem__": __getitem__})


FashionMNIST = dataset_with_indices(FashionMNIST)

model = nn.Sequential(Rearrange("b () h w -> b (h w)"),
                      nn.Linear(28 ** 2, 10))

DATASET_PATH = "/mnt/hdd_1tb/datasets/fashionmnist"
dataset = FashionMNIST(DATASET_PATH,
                       transform=Compose(
                           (ToTensor(), Normalize((0.286, ), (0.353, )))))

TRAIN_SIZE = 50000
BATCH_SIZE = 512
DEVICE = torch.device("cuda")

train_dataloader = DataLoader(
    dataset,
    BATCH_SIZE,
    sampler=SubsetRandomSampler(range(TRAIN_SIZE)),
    pin_memory=(DEVICE.type == "cuda"),
    drop_last=True,
)
# NOTE: the original snippet breaks off mid-call here; the arguments below
# are a plausible completion (an assumption: validate on the held-out tail
# of the dataset), not the original code.
val_dataloader = DataLoader(
    dataset,
    BATCH_SIZE,
    sampler=SubsetRandomSampler(range(TRAIN_SIZE, len(dataset))),
    pin_memory=(DEVICE.type == "cuda"),
)
def __init__(
        self,
        *,
        dim,
        max_seq_len=2048,
        depth=6,
        heads=8,
        dim_head=64,
        attn_types=('full', ),
        num_tokens=constants.NUM_AMINO_ACIDS,
        num_embedds=constants.NUM_EMBEDDS_TR,
        max_num_msas=constants.MAX_NUM_MSA,
        max_num_templates=constants.MAX_NUM_TEMPLATES,
        attn_dropout=0.,
        ff_dropout=0.,
        reversible=False,
        sparse_self_attn=False,
        cross_attn_compress_ratio=1,
        msa_tie_row_attn=False,
        template_attn_depth=2,
        num_backbone_atoms=1,  # number of atoms to reconstitute each residue to (set to 3 for C, C-alpha, N)
        predict_angles=False,
        symmetrize_omega=False,
        predict_coords=False,
        # structure module related keyword arguments below
        predict_real_value_distances=False,
        trunk_embeds_to_se3_edges=0,  # feeds pairwise projected logits from the trunk embeddings into the equivariant transformer as edges
        se3_edges_fourier_encodings=4,  # number of fourier encodings for se3 edges
        return_aux_logits=False,
        mds_iters=5,
        use_se3_transformer=True,  # uses SE3 Transformer - but if set to false, will use the new E(n)-Transformer
        structure_module_dim=4,
        structure_module_depth=4,
        structure_module_heads=1,
        structure_module_dim_head=4,
        structure_module_refinement_iters=2,
        structure_module_knn=0,
        structure_module_adj_neighbors=2):
    super().__init__()
    assert num_backbone_atoms in {1, 3, 4}, \
        'must be either residue level, or reconstitute to atomic coordinates of 3 for the C, Ca, N of backbone, or 4 of C-beta as well'
    layers_sparse_attn = cast_tuple(sparse_self_attn, depth)

    self.token_emb = nn.Embedding(num_tokens, dim)

    # template embedding
    self.template_dist_emb = nn.Embedding(constants.DISTOGRAM_BUCKETS, dim)
    self.template_num_pos_emb = nn.Embedding(max_num_templates, dim)

    # projection for angles, if needed
    self.predict_angles = predict_angles
    self.symmetrize_omega = symmetrize_omega

    if predict_angles:
        self.to_prob_theta = nn.Linear(dim, constants.THETA_BUCKETS)
        self.to_prob_phi = nn.Linear(dim, constants.PHI_BUCKETS)
        self.to_prob_omega = nn.Linear(dim, constants.OMEGA_BUCKETS)

    # when predicting the coordinates, whether to return the other logits, distogram (and optionally, angles)
    self.return_aux_logits = return_aux_logits

    # template sidechain encoding
    self.use_se3_transformer = use_se3_transformer

    if use_se3_transformer:
        self.template_sidechain_emb = SE3TemplateEmbedder(dim=dim,
                                                          dim_head=dim,
                                                          heads=1,
                                                          num_neighbors=12,
                                                          depth=4,
                                                          input_degrees=2,
                                                          num_degrees=2,
                                                          output_degrees=1,
                                                          reversible=True)
    else:
        self.template_sidechain_emb = EnTransformer(dim=dim,
                                                    dim_head=dim,
                                                    heads=1,
                                                    num_nearest_neighbors=32,
                                                    depth=4)

    # custom embedding projection
    self.embedd_project = nn.Linear(num_embedds, dim)

    # main trunk modules
    prenorm = partial(PreNorm, dim)
    prenorm_cross = partial(PreNormCross, dim)

    layers = nn.ModuleList([])
    attn_types = islice(cycle(attn_types), depth)

    for ind, layer_sparse_attn, attn_type in zip(range(depth),
                                                 layers_sparse_attn,
                                                 attn_types):
        # alternate between row and column attention to save memory each layer
        row_attn = ind % 2 == 0
        col_attn = ind % 2 == 1

        # self attention, for main sequence, msa, and optionally, templates
        if attn_type == 'full':
            tensor_slice = None
            template_axial_attn = True
        elif attn_type == 'intra_attn':
            tensor_slice = None
            template_axial_attn = False
        elif attn_type == 'seq_only':
            tensor_slice = (slice(None), slice(0, 1))
            template_axial_attn = False
        else:
            raise ValueError(f'cannot find attention type {attn_type}')

        layers.append(
            nn.ModuleList([
                prenorm(
                    InterceptAxialAttention(
                        tensor_slice,
                        AxialAttention(dim=dim,
                                       template_axial_attn=template_axial_attn,
                                       seq_len=max_seq_len,
                                       heads=heads,
                                       dim_head=dim_head,
                                       dropout=attn_dropout,
                                       sparse_attn=sparse_self_attn,
                                       row_attn=row_attn,
                                       col_attn=col_attn,
                                       rotary_rpe=True))),
                prenorm(
                    InterceptFeedForward(
                        tensor_slice,
                        ff=FeedForward(dim=dim, dropout=ff_dropout))),
                prenorm(
                    AxialAttention(dim=dim,
                                   seq_len=max_seq_len,
                                   heads=heads,
                                   dim_head=dim_head,
                                   dropout=attn_dropout,
                                   tie_row_attn=msa_tie_row_attn,
                                   row_attn=row_attn,
                                   col_attn=col_attn,
                                   rotary_rpe=True)),
                prenorm(FeedForward(dim=dim, dropout=ff_dropout)),
            ]))

        # cross attention, for main sequence -> msa and then msa -> sequence
        intercept_fn = partial(InterceptAttention, (slice(None), slice(0, 1)))

        layers.append(
            nn.ModuleList([
                intercept_fn(
                    context=False,
                    attn=prenorm_cross(
                        Attention(dim=dim,
                                  seq_len=max_seq_len,
                                  heads=heads,
                                  dim_head=dim_head,
                                  dropout=attn_dropout,
                                  compress_ratio=cross_attn_compress_ratio))),
                prenorm(FeedForward(dim=dim, dropout=ff_dropout)),
                intercept_fn(
                    context=True,
                    attn=prenorm_cross(
                        Attention(dim=dim,
                                  seq_len=max_seq_len,
                                  heads=heads,
                                  dim_head=dim_head,
                                  dropout=attn_dropout,
                                  compress_ratio=cross_attn_compress_ratio))),
                prenorm(FeedForward(dim=dim, dropout=ff_dropout)),
            ]))

    if not reversible:
        # remove last feed forward if not reversible
        layers = nn.ModuleList(list(map(lambda t: t[:3], layers)))

    trunk_class = SequentialSequence if not reversible else ReversibleSequence
    self.net = trunk_class(layers)

    # to distogram output
    self.num_backbone_atoms = num_backbone_atoms
    needs_upsample = num_backbone_atoms > 1

    self.predict_real_value_distances = predict_real_value_distances
    # 2 for predicting mean and standard deviation values of real-value distance
    dim_distance_pred = constants.DISTOGRAM_BUCKETS if not predict_real_value_distances else 2

    self.to_distogram_logits = nn.Sequential(
        nn.LayerNorm(dim),
        nn.Sequential(nn.Linear(dim, dim * (num_backbone_atoms ** 2)),
                      Rearrange('b h w c -> b c h w'),
                      nn.PixelShuffle(num_backbone_atoms),
                      Rearrange('b c h w -> b h w c'))
        if needs_upsample else nn.Identity(),
        nn.Linear(dim, dim_distance_pred))

    # to coordinate output
    self.predict_coords = predict_coords
    self.mds_iters = mds_iters
    self.structure_module_refinement_iters = structure_module_refinement_iters

    self.trunk_to_structure_dim = nn.Linear(dim, structure_module_dim)

    self.trunk_embeds_to_se3_edges = trunk_embeds_to_se3_edges
    self.se3_edges_fourier_encodings = se3_edges_fourier_encodings
    edge_dim = ((2 * se3_edges_fourier_encodings) + 1) * trunk_embeds_to_se3_edges

    self.to_equivariant_net_edges = nn.Linear(
        dim, trunk_embeds_to_se3_edges) if trunk_embeds_to_se3_edges > 0 else None

    with torch_default_dtype(torch.float64):
        self.structure_module_embeds = nn.Embedding(num_tokens,
                                                    structure_module_dim)
        self.atom_tokens_embed = nn.Embedding(len(ATOM_IDS),
                                              structure_module_dim)

        if use_se3_transformer:
            self.structure_module = SE3TransformerWrapper(
                dim=structure_module_dim,
                depth=structure_module_depth,
                input_degrees=1,
                num_degrees=3,
                output_degrees=2,
                heads=structure_module_heads,
                differentiable_coors=True,
                num_neighbors=0,  # use only bonded neighbors for now
                attend_sparse_neighbors=True,
                edge_dim=edge_dim,
                num_adj_degrees=structure_module_adj_neighbors,
                adj_dim=4,
            )
        else:
            self.structure_module = EnTransformer(
                dim=structure_module_dim,
                depth=structure_module_depth,
                heads=structure_module_heads,
                fourier_features=2,
                num_nearest_neighbors=0,
                only_sparse_neighbors=True,
                edge_dim=edge_dim,
                num_adj_degrees=structure_module_adj_neighbors,
                adj_dim=4)

    # aux confidence measure
    self.lddt_linear = nn.Linear(structure_module_dim, 1)
def __init__(
        self,
        num_channels=3,        # Overridden based on dataset.
        resolution=32,         # Overridden based on dataset.
        first_resolution=4,    # Overridden based on dataset.
        label_size=0,          # Overridden based on dataset.
        fmap_base=4096,
        fmap_decay=1.0,
        fmap_max=256,
        dim=3072,              # tunable parameter
        base_size=2,
        max_patch=32,
        channel=512,
        heads=8,
        dim_head=64,
        mlp_dim=2048,          # tunable
        dropout=0.0,
        latent_size=None,
        normalize_latents=True,
        use_wscale=True,
        use_pixelnorm=True,
        use_leakyrelu=True,
        use_batchnorm=False,
        tanh_at_end=None):
    super(Generator, self).__init__()
    self.num_channels = num_channels
    self.resolution = resolution
    self.label_size = label_size
    self.fmap_base = fmap_base
    self.fmap_decay = fmap_decay
    self.fmap_max = fmap_max
    self.latent_size = latent_size
    self.normalize_latents = normalize_latents
    self.use_wscale = use_wscale
    self.use_pixelnorm = use_pixelnorm
    self.use_leakyrelu = use_leakyrelu
    self.use_batchnorm = use_batchnorm
    self.tanh_at_end = tanh_at_end
    self.curr_resol = first_resolution

    R = int(np.log2(resolution))
    assert resolution == 2 ** R and resolution >= 4
    if latent_size is None:
        latent_size = self.get_nf(0)

    negative_slope = 0.2
    act = nn.LeakyReLU(negative_slope=negative_slope) if self.use_leakyrelu else nn.ReLU()
    iact = 'leaky_relu' if self.use_leakyrelu else 'relu'
    output_act = nn.Tanh() if self.tanh_at_end else 'linear'
    output_iact = 'tanh' if self.tanh_at_end else 'linear'

    pre = None
    lods = nn.ModuleList()
    nins = nn.ModuleList()
    layers = []

    if self.normalize_latents:
        pre = PixelNormLayer()

    if self.label_size:
        layers += [ConcatLayer()]

    # replace the first layer (formerly a conv block) with an MLP: 512 -> 4096
    layers += [nn.Linear(latent_size, 4096)]

    # input reshape: b * 4096 -> b*256*4*4 -> b (2*2) (256*2*2)
    curr_patchsize, curr_dim, curr_num = self.get_patch_dim(
        1, base_size, max_patch, latent_size // 2, first_resolution)
    print('curr_dim {}, curr_patchsize {}'.format(curr_dim, curr_patchsize))
    layers += [
        Rearrange('b (c p1 h p2 w) -> b c (p1 h) (p2 w)',
                  p1=curr_patchsize, p2=curr_patchsize,
                  c=latent_size // 2, h=curr_num, w=curr_num)
    ]

    # first block
    net = G_transformer(layers, dim, heads, dim_head, dropout, mlp_dim,
                        curr_patchsize, latent_size // 2, curr_dim, curr_num,
                        num_channels,
                        num_channels * curr_patchsize * curr_patchsize)
    lods.append(net)

    # to_rgb layer
    nins.append(
        NIN_transformer([], dim, heads, dim_head, dropout, mlp_dim,
                        curr_patchsize, num_channels, curr_num, curr_dim,
                        num_channels * curr_patchsize * curr_patchsize))

    for I in range(2, R):  # following blocks
        curr_patchsize, curr_dim, curr_num = self.get_patch_dim(
            I, base_size, max_patch, num_channels, first_resolution)
        print('following curr_dim {}, curr_patchsize {}'.format(
            curr_dim, curr_patchsize))

        layers = [nn.Upsample(scale_factor=2, mode='nearest')]  # upsample
        layers = G_transformer(layers, dim, heads, dim_head, dropout, mlp_dim,
                               curr_patchsize, num_channels, curr_dim,
                               curr_num, num_channels, curr_dim)
        net = layers
        lods.append(net)

        # to_rgb layer
        nins.append(
            NIN_transformer([], dim, heads, dim_head, dropout, mlp_dim,
                            curr_patchsize, num_channels, curr_num, curr_dim,
                            curr_dim))

    self.output_layer = GSelectLayer(pre, lods, nins)
def __init__(self, dev):
    super(scattering, self).__init__()
    self.rearange0 = Rearrange('a b c d -> a c (b d)')
    self.dev = dev
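# Hedged example of rearange0 above: it swaps the middle axes and folds the
# old second axis into the last, e.g. (a, b, c, d) -> (a, c, b*d).
import torch
from einops.layers.torch import Rearrange

assert Rearrange('a b c d -> a c (b d)')(torch.randn(2, 3, 5, 7)).shape == (2, 5, 21)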
small_pic_batch = torch.ones((1, 1, 6, 6))
small_pic_batch = small_pic_batch.view((1, 1, 6 * 6))

w_query = nn.Linear(6 * 6, 6 * 6, bias=False)
w_key = nn.Linear(6 * 6, 6 * 6, bias=False)
w_value = nn.Linear(6 * 6, 6 * 6, bias=False)
w_query.eval()
w_key.eval()
w_value.eval()

# att
soft = nn.Softmax(dim=4)
mem_blocks_divider = Rearrange('b c (h p1) (w p2) -> b c (h w) (p1 p2)',
                               p1=2, p2=2)

# TODO Optimize
query = mem_blocks_divider(w_query(small_pic_batch).view((1, 1, 6, 6))).view(
    (1, 1, 9, 4, 1))
key = mem_blocks_divider(w_key(small_pic_batch).view((1, 1, 6, 6)))
value = mem_blocks_divider(w_value(small_pic_batch).view((1, 1, 6, 6))).view(
    (1, 1, 9, 4, 1))

# q(i,j) * k(a,b) where i,j are pixel col and row AND a,b are memory block size
att_score = soft(torch.einsum('b c n g q, b c n p -> b c n g p', query, key))
# print(torch.matmul(att_score, value).view(1, 1, 9 * 4))

# avg pooling
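# The '# avg pooling' step above is left unfinished in the original; a
# plausible hedged completion (an assumption, not the original code):
# apply the attention within each 2x2 memory block, then mean-pool each
# block's output down to a single value.
out = torch.matmul(att_score, value).view(1, 1, 9, 4)  # (b, c, blocks, block_px)
pooled = out.mean(dim=-1)                              # (1, 1, 9): one value per block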
def __init__(self,
             image_size,
             channels,
             num_classes,
             patch_size_small=14,
             patch_size_large=16,
             small_dim=96,
             large_dim=192,
             small_depth=1,
             large_depth=4,
             cross_attn_depth=1,
             multi_scale_enc_depth=3,
             heads=3,
             pool='cls',
             dropout=0.,
             emb_dropout=0.,
             scale_dim=4):
    super().__init__()

    assert image_size % patch_size_small == 0, 'Image dimensions must be divisible by the patch size.'
    num_patches_small = (image_size // patch_size_small) ** 2
    patch_dim_small = channels * patch_size_small ** 2

    assert image_size % patch_size_large == 0, 'Image dimensions must be divisible by the patch size.'
    num_patches_large = (image_size // patch_size_large) ** 2
    patch_dim_large = channels * patch_size_large ** 2

    assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

    self.to_patch_embedding_small = nn.Sequential(
        Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
                  p1=patch_size_small, p2=patch_size_small),
        nn.Linear(patch_dim_small, small_dim),
    )
    self.to_patch_embedding_large = nn.Sequential(
        Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
                  p1=patch_size_large, p2=patch_size_large),
        nn.Linear(patch_dim_large, large_dim),
    )

    self.pos_embedding_small = nn.Parameter(
        torch.randn(1, num_patches_small + 1, small_dim))
    self.cls_token_small = nn.Parameter(torch.randn(1, 1, small_dim))
    self.dropout_small = nn.Dropout(emb_dropout)

    self.pos_embedding_large = nn.Parameter(
        torch.randn(1, num_patches_large + 1, large_dim))
    self.cls_token_large = nn.Parameter(torch.randn(1, 1, large_dim))
    self.dropout_large = nn.Dropout(emb_dropout)

    self.multi_scale_transformers = nn.ModuleList([])
    for _ in range(multi_scale_enc_depth):
        self.multi_scale_transformers.append(
            MultiScaleTransformerEncoder(
                small_dim=small_dim,
                small_depth=small_depth,
                small_heads=heads,
                small_dim_head=small_dim // heads,
                small_mlp_dim=small_dim * scale_dim,
                large_dim=large_dim,
                large_depth=large_depth,
                large_heads=heads,
                large_dim_head=large_dim // heads,
                large_mlp_dim=large_dim * scale_dim,
                cross_attn_depth=cross_attn_depth,
                cross_attn_heads=heads,
                dropout=dropout))

    self.pool = pool
    self.to_latent = nn.Identity()

    self.mlp_head_small = nn.Sequential(nn.LayerNorm(small_dim),
                                        nn.Linear(small_dim, num_classes))
    self.mlp_head_large = nn.Sequential(nn.LayerNorm(large_dim),
                                        nn.Linear(large_dim, num_classes))
                                         x_val, y_val, clf=dummy_clf,
                                         dataset_name=dataset_name))
        print("DONE:", data_dir)
        return results
    except Exception:
        return []
        # print("\t", pd.DataFrame(results[2*i:2*i+2]))


if __name__ == "__main__":
    data_dirs = sorted(list(glob.glob('datasets/*')))

    transform = tfms.Compose([
        tfms.Grayscale(),
        tfms.Resize(128, interpolation=2),
        tfms.RandomCrop(112),
        tfms.ToTensor(),
        # flattens the tensor; the axis labels in the pattern are only names
        Rearrange('h w c -> (h w c)'),
    ])

    print(len(data_dirs))
    result = Parallel(n_jobs=5)(delayed(parfunc)(data_dir)
                                for data_dir in data_dirs)

    with open("svm_dummy.txt", "wb") as fp:
        pickle.dump(result, fp)
    pd.DataFrame(result).to_csv('svm_dummy_results.csv')
def __init__(self,
             *,
             num_classes,
             s1_emb_dim=64,
             s1_patch_size=4,
             s1_local_patch_size=7,
             s1_global_k=7,
             s1_depth=1,
             s2_emb_dim=128,
             s2_patch_size=2,
             s2_local_patch_size=7,
             s2_global_k=7,
             s2_depth=1,
             s3_emb_dim=256,
             s3_patch_size=2,
             s3_local_patch_size=7,
             s3_global_k=7,
             s3_depth=5,
             s4_emb_dim=512,
             s4_patch_size=2,
             s4_local_patch_size=7,
             s4_global_k=7,
             s4_depth=4,
             peg_kernel_size=3,
             dropout=0.):
    super().__init__()
    kwargs = dict(locals())

    dim = 3
    layers = []

    for prefix in ('s1', 's2', 's3', 's4'):
        config, kwargs = group_by_key_prefix_and_remove_prefix(f'{prefix}_', kwargs)
        is_last = prefix == 's4'
        dim_next = config['emb_dim']

        layers.append(
            nn.Sequential(
                PatchEmbedding(dim=dim, dim_out=dim_next,
                               patch_size=config['patch_size']),
                Transformer(dim=dim_next, depth=1,
                            local_patch_size=config['local_patch_size'],
                            global_k=config['global_k'],
                            dropout=dropout, has_local=not is_last),
                PEG(dim=dim_next, kernel_size=peg_kernel_size),
                Transformer(dim=dim_next, depth=config['depth'],
                            local_patch_size=config['local_patch_size'],
                            global_k=config['global_k'],
                            dropout=dropout, has_local=not is_last)))

        dim = dim_next

    self.layers = nn.Sequential(*layers,
                                nn.AdaptiveAvgPool2d(1),
                                Rearrange('... () () -> ...'),
                                nn.Linear(dim, num_classes))
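# Hedged note on the classifier head above: after nn.AdaptiveAvgPool2d(1)
# the features are (b, dim, 1, 1); Rearrange('... () () -> ...') squeezes
# the two trailing singleton axes so nn.Linear sees (b, dim).
import torch
from einops.layers.torch import Rearrange

assert Rearrange('... () () -> ...')(torch.randn(2, 512, 1, 1)).shape == (2, 512)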
def __init__(
    self,
    *,
    dim=(64, 128, 256, 512),
    depth=(2, 2, 8, 2),
    window_size=7,
    num_classes=1000,
    tokenize_local_3_conv=False,
    local_patch_size=4,
    use_peg=False,
    attn_dropout=0.,
    ff_dropout=0.,
    channels=3,
):
    super().__init__()
    dim = cast_tuple(dim, 4)
    depth = cast_tuple(depth, 4)
    assert len(dim) == 4, 'dim needs to be a single value or a tuple of length 4'
    assert len(depth) == 4, 'depth needs to be a single value or a tuple of length 4'

    self.local_patch_size = local_patch_size

    region_patch_size = local_patch_size * window_size
    self.region_patch_size = local_patch_size * window_size

    init_dim, *_, last_dim = dim

    # local and region encoders

    if tokenize_local_3_conv:
        self.local_encoder = nn.Sequential(
            nn.Conv2d(3, init_dim, 3, 2, 1),
            nn.LayerNorm(init_dim),
            nn.GELU(),
            nn.Conv2d(init_dim, init_dim, 3, 2, 1),
            nn.LayerNorm(init_dim),
            nn.GELU(),
            nn.Conv2d(init_dim, init_dim, 3, 1, 1))
    else:
        self.local_encoder = nn.Conv2d(3, init_dim, 8, 4, 3)

    self.region_encoder = nn.Sequential(
        Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w',
                  p1=region_patch_size, p2=region_patch_size),
        nn.Conv2d((region_patch_size ** 2) * channels, init_dim, 1))

    # layers

    current_dim = init_dim
    self.layers = nn.ModuleList([])

    for ind, dim, num_layers in zip(range(4), dim, depth):
        not_first = ind != 0
        need_downsample = not_first
        need_peg = not_first and use_peg

        self.layers.append(
            nn.ModuleList([
                Downsample(current_dim, dim) if need_downsample else nn.Identity(),
                PEG(dim) if need_peg else nn.Identity(),
                R2LTransformer(dim, depth=num_layers,
                               window_size=window_size,
                               attn_dropout=attn_dropout,
                               ff_dropout=ff_dropout)
            ]))

        current_dim = dim

    # final logits

    self.to_logits = nn.Sequential(Reduce('b c h w -> b c', 'mean'),
                                   nn.LayerNorm(last_dim),
                                   nn.Linear(last_dim, num_classes))
def __init__(self,
             *,
             image_size,
             num_classes,
             dim,
             depth=None,
             heads=None,
             mlp_dim=None,
             pool='cls',
             channels=3,
             dim_head=64,
             dropout=0.,
             emb_dropout=0.,
             transformer=None,
             t2t_layers=((7, 4), (3, 2), (3, 2))):
    super().__init__()
    assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

    layers = []
    layer_dim = channels
    output_image_size = image_size

    for i, (kernel_size, stride) in enumerate(t2t_layers):
        layer_dim *= kernel_size ** 2
        is_first = i == 0
        is_last = i == (len(t2t_layers) - 1)
        output_image_size = conv_output_size(output_image_size, kernel_size,
                                             stride, stride // 2)

        layers.extend([
            RearrangeImage() if not is_first else nn.Identity(),
            nn.Unfold(kernel_size=kernel_size, stride=stride, padding=stride // 2),
            Rearrange('b c n -> b n c'),
            Transformer(dim=layer_dim, heads=1, depth=1, dim_head=layer_dim,
                        mlp_dim=layer_dim, dropout=dropout)
            if not is_last else nn.Identity(),
        ])

    layers.append(nn.Linear(layer_dim, dim))
    self.to_patch_embedding = nn.Sequential(*layers)

    self.pos_embedding = nn.Parameter(
        torch.randn(1, output_image_size ** 2 + 1, dim))
    self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
    self.dropout = nn.Dropout(emb_dropout)

    if not exists(transformer):
        assert all([exists(depth), exists(heads), exists(mlp_dim)]), \
            'depth, heads, and mlp_dim must be supplied'
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim,
                                       dropout)
    else:
        self.transformer = transformer

    self.pool = pool
    self.to_latent = nn.Identity()

    self.mlp_head = nn.Sequential(nn.LayerNorm(dim),
                                  nn.Linear(dim, num_classes))
def __init__(
        self,
        image_size,
        patch_size,
        num_classes,   # total number of classes
        Dim,
        depth,
        heads,
        mlp_dim,
        pool='cls',
        channels=3,
        dim_head=64,
        dropout=0.,
        emb_dropout=0.1):
    super(VisionTransformer, self).__init__()
    # image_size=256 -> image_height = image_width = 256
    image_height, image_width = pair(image_size)
    # patch_size=32 -> patch_height = patch_width = 32
    patch_height, patch_width = pair(patch_size)

    # image size must be divisible by the patch size, otherwise raise an error
    assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

    # number of patches per image; here 8*8 = 64, i.e. each image is split into 64 patches
    num_patches = (image_height // patch_height) * (image_width // patch_width)
    # 3*32*32 = 3072: how many values each patch holds once flattened to 1-D
    patch_dim = channels * patch_height * patch_width
    # print(patch_dim)

    assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

    # turn the whole image into patch embeddings
    self.to_patch_embedding = nn.Sequential(
        # https://blog.csdn.net/csdn_yi_e/article/details/109143580
        # regroup the tensor according to the given pattern; the letters in
        # the pattern are just labels with no inherent meaning.
        # torch.Size([B, 64, 3072]): the image is split into 64 patches,
        # each flattened to 1-D
        Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
                  p1=patch_height, p2=patch_width),
        # torch.Size([B, 64, 1024]): linear projection of each patch down to
        # 1024-D, giving the patch embeddings
        nn.Linear(patch_dim, Dim),
    )

    self.cls_token = nn.Parameter(torch.randn(1, 1, Dim))  # torch.Size([1, 1, 1024])
    # print(self.cls_token.shape)
    self.dropout = nn.Dropout(emb_dropout)

    self.transformer = Encoder(
        num_layers=depth,        # 6
        mlp_dim=mlp_dim,         # 2048
        dropout_rate=dropout,    # 0.1
        heads=heads,             # 16
        dim_head=dim_head,       # 64
        Dim=Dim                  # 1024
    )

    self.pool = pool
    # https://blog.csdn.net/artistkeepmonkey/article/details/115067356
    self.to_latent = nn.Identity()

    self.mlp_head = nn.Sequential(
        nn.LayerNorm(Dim),
        nn.Linear(Dim, num_classes)  # torch.Size([B, 1024]) -> num_classes
    )
def __init__(self, in_channels: int, patch_size: tuple, embbeding_dim: int):
    super(LinePatchEmbbeding, self).__init__()
    self.rearrange = Rearrange('b c (h s1) (w s2) -> b (h w) (s1 s2 c)',
                               s1=patch_size[0], s2=patch_size[1])
    self.linear = nn.Linear(patch_size[0] * patch_size[1] * in_channels,
                            embbeding_dim)
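# Hedged usage check for LinePatchEmbbeding (illustrative numbers): the
# rearrange cuts the image into (H/s1)*(W/s2) flattened patches, which the
# linear layer then projects to embbeding_dim. Calling the two submodules
# directly is only for demonstration; the class's forward may differ.
import torch

_embed = LinePatchEmbbeding(in_channels=3, patch_size=(8, 8), embbeding_dim=128)
_x = torch.randn(2, 3, 32, 32)
assert _embed.linear(_embed.rearrange(_x)).shape == (2, 16, 128)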