def __init__(self, dim: tuple = (4, 4), wave_name='haar',
             init_mode='he_normal'):
    """Create the capsule convolution and the wavelet-transform submodules.

    Args:
        dim: passed unchanged to ``CapsConv2d``.
        wave_name: passed positionally as the single argument of
            ``DWTForward``.
        init_mode: accepted by the signature but not read anywhere in this
            constructor. NOTE(review): confirm whether it was meant to be
            forwarded to one of the submodules.
    """
    super(PrimaryDWT, self).__init__()
    # Build and register the capsule convolution before the wavelet
    # transform, so submodule registration order (and any random
    # initialisation the constructors may perform) matches the original.
    capsule_conv = CapsConv2d(dim)
    self.caps_conv = capsule_conv
    wavelet = DWTForward(wave_name)
    self.dwt = wavelet
    def _make_layer(self, down_sample_times: int, channel: int):
        layer_list = []
        for i in range(down_sample_times):
            layer_list.append(nn.Conv2d(channel, channel // 4, 1))
            layer_list.append(DWTForward())
            layer_list.append(nn.BatchNorm2d(channel))

        return nn.Sequential(*layer_list)
Example #3
0
    def __init__(self, in_planes, planes, stride=1):
        """Build a residual basic block whose stride-2 path uses a DWT.

        When ``stride == 2`` the first layer is ``DWTForward()`` instead of a
        strided 3x3 convolution. The following batch norm and 3x3 convolution
        are sized for ``in_planes * 4`` channels, which assumes the transform
        stacks four sub-bands along the channel axis (TODO confirm against
        ``DWTForward``). For every other stride, a regular 3x3 convolution
        with that stride is used. The shortcut is a strided 1x1 convolution
        plus batch norm whenever the stride or the channel count changes;
        otherwise it is the identity.

        Args:
            in_planes: number of input channels.
            planes: output channels of the main branch. The shortcut targets
                ``self.expansion * planes``; ``expansion`` is presumably a
                class attribute defined outside this method.
            stride: spatial stride of the block; ``2`` selects the DWT path.
        """
        super(BasicBlockDWT, self).__init__()
        self.in_planes = in_planes
        self.planes = planes
        self.stride = stride

        # Each layer is built exactly once. Registration order
        # (conv1, bn1, conv2, bn2) is unchanged, so state_dict keys match.
        if self.stride == 2:
            # Wavelet down-sampling replaces the strided convolution.
            self.conv1 = DWTForward()
            mid_planes = self.in_planes * 4
        else:
            # Use self.stride rather than a hard-coded 1 so the main branch
            # down-samples by the same factor as the shortcut below.
            self.conv1 = nn.Conv2d(self.in_planes,
                                   self.planes,
                                   kernel_size=3,
                                   stride=self.stride,
                                   padding=1,
                                   bias=False)
            mid_planes = self.planes
        self.bn1 = nn.BatchNorm2d(mid_planes)
        self.conv2 = nn.Conv2d(mid_planes,
                               self.planes,
                               kernel_size=3,
                               stride=1,
                               padding=1,
                               bias=False)
        self.bn2 = nn.BatchNorm2d(self.planes)

        self.relu = nn.ReLU()
        self.shortcut = nn.Sequential()
        if self.stride != 1 or self.in_planes != self.expansion * self.planes:
            # Project the identity path to the main branch's shape.
            self.shortcut = nn.Sequential(
                nn.Conv2d(self.in_planes,
                          self.expansion * self.planes,
                          kernel_size=1,
                          stride=self.stride,
                          bias=False),
                nn.BatchNorm2d(self.expansion * self.planes))