Example #1
File: coat.py  Project: DeH40/CoaT-pytorch
 def __init__(self, depth, dim, heads, dim_head, scale, dropout):
     super(Transformer, self).__init__()
     self.layers = nn.ModuleList([])
     for _ in range(depth):
         # Each layer is a pre-norm convolutional attention block followed by
         # a pre-norm feed-forward block with an expansion factor of `scale`.
         self.layers.append(
             nn.ModuleList([
                 PreNorm(dim, ConvAttention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
                 PreNorm(dim, FeedForward(dim, dim * scale, dropout=dropout))
             ]))
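These snippets assume lucidrains-style helper modules that are not shown on this page. A minimal sketch of the PreNorm wrapper they all rely on (an assumption about the helper, not code taken from the DeH40/CoaT-pytorch project):

import torch.nn as nn

class PreNorm(nn.Module):
    # Assumed helper: applies LayerNorm before the wrapped module (pre-norm block).
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        # Normalize first, then run attention/FFN; residual adds are left to the caller.
        return self.fn(self.norm(x), **kwargs)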
Example #2
 def __init__(self, dim, heads, dim_head, mlp_dim, dropout=0.):
     super().__init__()
     self.layers = nn.ModuleList([])
     # A single pre-norm attention/feed-forward pair (no depth loop):
     # LCAttention followed by a standard MLP.
     self.layers.append(
         nn.ModuleList([
             PreNorm(
                 dim,
                 LCAttention(dim,
                             heads=heads,
                             dim_head=dim_head,
                             dropout=dropout)),
             PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
         ]))
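FeedForward is likewise assumed to be the standard two-layer transformer MLP; a minimal sketch under that assumption:

import torch.nn as nn

class FeedForward(nn.Module):
    # Assumed helper: expand to hidden_dim, apply GELU and dropout, project back to dim.
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)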
Example #3
 def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
     super().__init__()
     self.layers = nn.ModuleList([])
     # Final LayerNorm applied after the last block (pre-norm architecture).
     self.norm = nn.LayerNorm(dim)
     for _ in range(depth):
         self.layers.append(
             nn.ModuleList([
                 PreNorm(
                     dim,
                     Attention(dim,
                               heads=heads,
                               dim_head=dim_head,
                               dropout=dropout)),
                 PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
             ]))
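Example #3 also registers a final nn.LayerNorm, which implies a forward pass along these lines (a hedged sketch; the actual forward method is not shown in these snippets):

def forward(self, x):
    # Each entry in self.layers is a (pre-norm attention, pre-norm FFN) pair;
    # apply both with residual connections, then normalize the final output.
    for attn, ff in self.layers:
        x = attn(x) + x
        x = ff(x) + x
    return self.norm(x)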
Example #4
File: coat.py  Project: DeH40/CoaT-pytorch
    def __init__(self, in_channels, image_size, num_classes, out_channels=[64, 128, 256, 320], depths=[2, 2, 2, 2],
                 heads=8, scales=[8, 8, 4, 4], downscales=[4, 2, 2, 2], kernels=[7, 3, 3, 3], use_parallel=False,
                 parallel_depth=6, parallel_channels=152, dropout=0.):
        super(CoaT, self).__init__()

        assert len(out_channels) == len(depths) == len(scales) == len(downscales) == len(kernels)
        feature_size = image_size
        self.cls_token = nn.Parameter(torch.randn(1, 1, in_channels))
        # Serial group: each stage downsamples the feature map and widens the channels.
        self.serial_layers = nn.ModuleList([])
        for out_channel, depth, scale, downscale, kernel in zip(out_channels, depths, scales, downscales, kernels):
            feature_size = feature_size // downscale
            self.serial_layers.append(
                SerialBlock(feature_size, in_channels, out_channel, depth, heads, scale, kernel, downscale, dropout)
            )
            in_channels = out_channel

        self.use_parallel = use_parallel
        if use_parallel:
            # Parallel group: attention and feed-forward layers applied for
            # parallel_depth steps on parallel_channels-wide features.
            self.parallel_conv_attn = nn.ModuleList([])
            self.parallel_ffn = nn.ModuleList([])
            for _ in range(parallel_depth):
                self.parallel_conv_attn.append(
                    ParallelBlock(parallel_channels, heads, dropout)
                )
                self.parallel_ffn.append(
                    PreNorm(parallel_channels, FeedForward(parallel_channels, parallel_channels * 4, dropout=dropout))
                )

            # Classification head for the parallel branch; expects the
            # concatenation of three branch outputs (hence in_channels * 3).
            self.parallel_mlp_head = nn.Sequential(
                nn.LayerNorm(in_channels * 3),
                nn.Linear(in_channels * 3, num_classes)
            )

        # Classification head for the serial-only path.
        self.serial_mlp_head = nn.Sequential(
            nn.LayerNorm(in_channels),
            nn.Linear(in_channels, num_classes)
        )
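With the default downscales=[4, 2, 2, 2], each serial stage divides feature_size by its downscale and rebinds in_channels to the stage's out_channel, so the serial head ends up operating on out_channels[-1] = 320 features. A quick sanity check of that geometry for a 224-pixel input (224 is an assumed example size, not from the snippet):

# Hypothetical walk-through of the serial-stage geometry implied by the defaults.
out_channels = [64, 128, 256, 320]
downscales = [4, 2, 2, 2]

feature_size = 224  # assumed image_size
for out_channel, downscale in zip(out_channels, downscales):
    feature_size //= downscale
    print(f"{feature_size}x{feature_size} feature map, {out_channel} channels")
# -> 56x56/64, 28x28/128, 14x14/256, 7x7/320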