import torch
import torch.nn as nn

# Transformer encoder built from `depth` stacked pairs of conv-attention and
# feed-forward blocks, each wrapped in pre-normalization. ConvAttention comes
# from elsewhere in the codebase; a sketch of PreNorm and FeedForward follows
# below.
class Transformer(nn.Module):
    def __init__(self, depth, dim, heads, dim_head, scale, dropout):
        super(Transformer, self).__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, ConvAttention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
                PreNorm(dim, FeedForward(dim, dim * scale, dropout=dropout)),
            ]))
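# PreNorm and FeedForward are referenced throughout but not defined in this
# excerpt. Below is a minimal sketch of the standard ViT-style versions these
# call sites appear to assume; the source's actual definitions may differ.
class PreNorm(nn.Module):
    """Apply LayerNorm to the input before the wrapped module."""
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class FeedForward(nn.Module):
    """Two-layer MLP with GELU activation and dropout."""
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)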
# Single attention/feed-forward pair built around LCAttention (class-token
# attention). Note there is no depth loop: exactly one pair is appended.
# The enclosing class is not shown in this excerpt.
def __init__(self, dim, heads, dim_head, mlp_dim, dropout=0.):
    super().__init__()
    self.layers = nn.ModuleList([])
    self.layers.append(nn.ModuleList([
        PreNorm(dim, LCAttention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
        PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)),
    ]))
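# A plausible forward pass for the pair above, assuming (as in CrossViT-style
# implementations) that LCAttention uses only the class token x[:, 0:1] as the
# query, so each update touches the class token alone. This is a hedged
# sketch, not the source's own forward method.
def forward(self, x):
    # x: (batch, 1 + num_patches, dim); x[:, 0:1] is the class token
    for attn, ff in self.layers:
        cls_token = attn(x) + x[:, 0:1]        # class token attends to all tokens
        cls_token = ff(cls_token) + cls_token  # FFN on the class token only
    return cls_token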
# Standard pre-norm Transformer encoder: `depth` stacked attention/FFN pairs
# with a final LayerNorm applied after the last layer. The enclosing class is
# not shown in this excerpt.
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
    super().__init__()
    self.layers = nn.ModuleList([])
    self.norm = nn.LayerNorm(dim)
    for _ in range(depth):
        self.layers.append(nn.ModuleList([
            PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
            PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)),
        ]))
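# A minimal forward pass matching the constructor above: residual
# attention/FFN pairs followed by the final LayerNorm. Again a sketch; the
# source's forward method is not shown in this excerpt.
def forward(self, x):
    for attn, ff in self.layers:
        x = attn(x) + x  # pre-norm attention with residual connection
        x = ff(x) + x    # pre-norm feed-forward with residual connection
    return self.norm(x)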
# CoaT backbone: a stack of serial blocks that progressively downsample the
# feature map, optionally followed by parallel blocks that exchange
# information across scales. SerialBlock and ParallelBlock are defined
# elsewhere in the codebase.
class CoaT(nn.Module):
    def __init__(self, in_channels, image_size, num_classes,
                 out_channels=[64, 128, 256, 320], depths=[2, 2, 2, 2],
                 heads=8, scales=[8, 8, 4, 4], downscales=[4, 2, 2, 2],
                 kernels=[7, 3, 3, 3], use_parallel=False, parallel_depth=6,
                 parallel_channels=152, dropout=0.):
        super(CoaT, self).__init__()
        assert len(out_channels) == len(depths) == len(scales) == len(downscales) == len(kernels)

        feature_size = image_size
        self.cls_token = nn.Parameter(torch.randn(1, 1, in_channels))

        # Serial stage: each block reduces the spatial resolution by
        # `downscale` and changes the channel width to `out_channel`.
        self.serial_layers = nn.ModuleList([])
        for out_channel, depth, scale, downscale, kernel in zip(
                out_channels, depths, scales, downscales, kernels):
            feature_size = feature_size // downscale
            self.serial_layers.append(
                SerialBlock(feature_size, in_channels, out_channel, depth,
                            heads, scale, kernel, downscale, dropout)
            )
            in_channels = out_channel

        # Optional parallel stage: cross-scale conv-attention followed by a
        # feed-forward block, repeated `parallel_depth` times.
        self.use_parallel = use_parallel
        if use_parallel:
            self.parallel_conv_attn = nn.ModuleList([])
            self.parallel_ffn = nn.ModuleList([])
            for _ in range(parallel_depth):
                self.parallel_conv_attn.append(
                    ParallelBlock(parallel_channels, heads, dropout)
                )
                self.parallel_ffn.append(
                    PreNorm(parallel_channels,
                            FeedForward(parallel_channels, parallel_channels * 4, dropout=dropout))
                )
            # Parallel head classifies from three concatenated class tokens.
            self.parallel_mlp_head = nn.Sequential(
                nn.LayerNorm(in_channels * 3),
                nn.Linear(in_channels * 3, num_classes)
            )

        # Serial head classifies from the last stage's class token.
        self.serial_mlp_head = nn.Sequential(
            nn.LayerNorm(in_channels),
            nn.Linear(in_channels, num_classes)
        )
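# Illustrative instantiation (parameter values are examples, not prescribed by
# the source). With the default downscales [4, 2, 2, 2], a 224x224 input
# yields serial-stage feature maps of size 56, 28, 14, and 7. SerialBlock and
# ParallelBlock must be defined elsewhere for this to run.
if __name__ == "__main__":
    model = CoaT(
        in_channels=3,
        image_size=224,
        num_classes=1000,
        use_parallel=False,  # set True to build the parallel stage as well
    )
    # x = torch.randn(1, 3, 224, 224)
    # logits = model(x)  # the forward method is not shown in this excerpt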