Example #1
    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 upsample=None,
                 norm_layer=None,
                 large_kernel=False):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self.stride = stride
        conv = conv5x5 if large_kernel else conv3x3
        # Both self.conv1 and self.upsample layers upsample the input when stride != 1
        if self.stride > 1:
            self.conv1 = SpectralNorm(
                nn.ConvTranspose2d(inplanes,
                                   inplanes,
                                   kernel_size=4,
                                   stride=2,
                                   padding=1,
                                   bias=False))
        else:
            self.conv1 = SpectralNorm(conv(inplanes, inplanes))
        self.bn1 = norm_layer(inplanes)
        self.activation = nn.LeakyReLU(0.2, inplace=True)
        self.conv2 = SpectralNorm(conv(inplanes, planes))
        self.bn2 = norm_layer(planes)
        self.upsample = upsample
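The conv helper selected above (conv3x3 or conv5x5) is defined elsewhere in the repository. A minimal sketch of what such factories typically look like, following the torchvision ResNet convention (an assumption, not code from this file):

import torch.nn as nn

def conv3x3(in_planes, out_planes, stride=1):
    # 3x3 convolution; padding=1 preserves spatial size at stride 1
    return nn.Conv2d(in_planes, out_planes, kernel_size=3,
                     stride=stride, padding=1, bias=False)

def conv5x5(in_planes, out_planes, stride=1):
    # 5x5 variant used when large_kernel=True; padding=2 preserves size
    return nn.Conv2d(in_planes, out_planes, kernel_size=5,
                     stride=stride, padding=2, bias=False)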
Example #2
    def _make_layer(self, block, planes, blocks, stride=1, inplane_multi=1):
        if blocks == 0:
            return nn.Sequential(nn.Identity())
        norm_layer = self._norm_layer
        upsample = None
        self.inplanes = int(self.inplanes * inplane_multi)
        if stride != 1:
            # decoder skip path: double the resolution, then project channels
            upsample = nn.Sequential(
                nn.UpsamplingNearest2d(scale_factor=2),
                SpectralNorm(conv1x1(self.inplanes, planes * block.expansion)),
                norm_layer(planes * block.expansion),
            )
        elif self.inplanes != planes * block.expansion:
            upsample = nn.Sequential(
                SpectralNorm(conv1x1(self.inplanes, planes * block.expansion)),
                norm_layer(planes * block.expansion),
            )

        layers = [
            block(int(self.inplanes), planes, stride, upsample, norm_layer,
                  self.large_kernel)
        ]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(self.inplanes,
                      planes,
                      norm_layer=norm_layer,
                      large_kernel=self.large_kernel))

        return nn.Sequential(*layers)
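The stride != 1 branch above builds a skip path that doubles resolution with nearest-neighbor upsampling and matches channels with a 1x1 projection. A standalone shape check of that path (SpectralNorm omitted and the 1x1 conv written inline, both simplifications):

import torch
import torch.nn as nn

skip = nn.Sequential(
    nn.UpsamplingNearest2d(scale_factor=2),
    nn.Conv2d(512, 256, kernel_size=1, bias=False),  # stands in for conv1x1
    nn.BatchNorm2d(256),
)
x = torch.randn(1, 512, 16, 16)
print(skip(x).shape)  # torch.Size([1, 256, 32, 32])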
Example #3
    def _make_shortcut(self, inplane, planes):
        return nn.Sequential(
            SpectralNorm(nn.Conv2d(inplane, planes, kernel_size=3, padding=1, bias=False)),
            nn.ReLU(inplace=True),
            self._norm_layer(planes),
            SpectralNorm(nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)),
            nn.ReLU(inplace=True),
            self._norm_layer(planes)
        )
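A quick shape check of this shortcut head, assuming a 6-channel input (RGB plus a 3-channel trimap) and nn.BatchNorm2d standing in for self._norm_layer; SpectralNorm is dropped for brevity. Note the conv -> ReLU -> norm ordering, which differs from the more common conv -> norm -> ReLU:

import torch
import torch.nn as nn

shortcut = nn.Sequential(
    nn.Conv2d(6, 32, kernel_size=3, padding=1, bias=False),
    nn.ReLU(inplace=True),
    nn.BatchNorm2d(32),
    nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=False),
    nn.ReLU(inplace=True),
    nn.BatchNorm2d(32),
)
x = torch.randn(1, 6, 64, 64)
print(shortcut(x).shape)  # torch.Size([1, 32, 64, 64]) -- resolution preserved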
Example #4
    def __init__(self, block, layers, norm_layer=None, late_downsample=False):
        super(ResGuidedCxtAtten,
              self).__init__(block,
                             layers,
                             norm_layer,
                             late_downsample=late_downsample)
        first_inplane = 3 + TRIMAP_CHANNEL
        self.shortcut_inplane = [first_inplane, self.midplanes, 64, 128, 256]
        self.shortcut_plane = [32, self.midplanes, 64, 128, 256]

        self.shortcut = nn.ModuleList()
        for stage, inplane in enumerate(self.shortcut_inplane):
            self.shortcut.append(
                self._make_shortcut(inplane, self.shortcut_plane[stage]))

        self.guidance_head = nn.Sequential(
            nn.ReflectionPad2d(1),
            SpectralNorm(
                nn.Conv2d(3, 16, kernel_size=3, padding=0, stride=2,
                          bias=False)),
            nn.ReLU(inplace=True),
            self._norm_layer(16),
            nn.ReflectionPad2d(1),
            SpectralNorm(
                nn.Conv2d(16, 32, kernel_size=3, padding=0, stride=2,
                          bias=False)),
            nn.ReLU(inplace=True),
            self._norm_layer(32),
            nn.ReflectionPad2d(1),
            SpectralNorm(
                nn.Conv2d(32, 128, kernel_size=3, padding=0, stride=2,
                          bias=False)),
            nn.ReLU(inplace=True),
            self._norm_layer(128))

        self.gca = GuidedCxtAtten(128, 128)

        # initialize guidance head
        for m in self.guidance_head:
            if isinstance(m, nn.Conv2d):
                if hasattr(m, "weight_bar"):
                    nn.init.xavier_uniform_(m.weight_bar)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
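The weight_bar attribute checked above is created by the SpectralNorm wrapper, which re-registers the conv's raw weight under that name and recomputes the normalized weight on each forward pass. PyTorch's built-in torch.nn.utils.spectral_norm does the same thing under the name weight_orig, which makes the pattern easy to demonstrate:

import torch.nn as nn
from torch.nn.utils import spectral_norm

conv = spectral_norm(nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False))
print(hasattr(conv, "weight_orig"))  # True -- the raw, trainable parameter
# Initialization must target the raw parameter, exactly as the loop above
# targets weight_bar; "weight" itself is recomputed every forward pass.
nn.init.xavier_uniform_(conv.weight_orig)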
Example #5
    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = SpectralNorm(conv3x3(inplanes, planes, stride))
        self.bn1 = norm_layer(planes)
        self.activation = nn.ReLU(inplace=True)
        self.conv2 = SpectralNorm(conv3x3(planes, planes))
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride
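Only __init__ is shown; the matching forward would follow the standard ResNet pattern. A sketch inferred from the attributes defined above (not copied from the source), to be placed inside the same BasicBlock class:

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.activation(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # project the identity whenever conv1 changed shape (stride != 1
        # or a channel change), so the residual addition lines up
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.activation(out)
        return out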
Example #6
    def _make_layer(self, block, planes, blocks, stride=1):
        if blocks == 0:
            return nn.Sequential(nn.Identity())
        norm_layer = self._norm_layer
        downsample = None
        if stride != 1:
            # encoder skip path: average-pool down, then project channels
            downsample = nn.Sequential(
                nn.AvgPool2d(2, stride),
                SpectralNorm(conv1x1(self.inplanes, planes * block.expansion)),
                norm_layer(planes * block.expansion),
            )
        elif self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                SpectralNorm(
                    conv1x1(self.inplanes, planes * block.expansion, stride)),
                norm_layer(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample, norm_layer)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, norm_layer=norm_layer))

        return nn.Sequential(*layers)
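This is the encoder counterpart of the decoder _make_layer above: when stride != 1 the skip path halves resolution with average pooling before the 1x1 projection, rather than striding the 1x1 conv itself (which would simply drop pixels) -- the "ResNet-D" downsampling trick suggested by the class name. A standalone shape check (SpectralNorm omitted):

import torch
import torch.nn as nn

skip = nn.Sequential(
    nn.AvgPool2d(2, 2),  # pool first, so the 1x1 conv sees every pixel
    nn.Conv2d(64, 128, kernel_size=1, bias=False),  # stands in for conv1x1
    nn.BatchNorm2d(128),
)
x = torch.randn(1, 64, 32, 32)
print(skip(x).shape)  # torch.Size([1, 128, 16, 16])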
Example #7
    def __init__(self, block, layers, norm_layer=None, late_downsample=False):
        super(ResNet_D, self).__init__()
        self.logger = logging.getLogger("Logger")
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.late_downsample = late_downsample
        self.midplanes = 64 if late_downsample else 32
        self.start_stride = [1, 2, 1, 2] if late_downsample else [2, 1, 2, 1]
        self.conv1 = SpectralNorm(
            nn.Conv2d(3 + TRIMAP_CHANNEL,
                      32,
                      kernel_size=3,
                      stride=self.start_stride[0],
                      padding=1,
                      bias=False))
        self.conv2 = SpectralNorm(
            nn.Conv2d(32,
                      self.midplanes,
                      kernel_size=3,
                      stride=self.start_stride[1],
                      padding=1,
                      bias=False))
        self.conv3 = SpectralNorm(
            nn.Conv2d(self.midplanes,
                      self.inplanes,
                      kernel_size=3,
                      stride=self.start_stride[2],
                      padding=1,
                      bias=False))
        self.bn1 = norm_layer(32)
        self.bn2 = norm_layer(self.midplanes)
        self.bn3 = norm_layer(self.inplanes)
        self.activation = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block,
                                       64,
                                       layers[0],
                                       stride=self.start_stride[3])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer_bottleneck = self._make_layer(block,
                                                 512,
                                                 layers[3],
                                                 stride=2)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight_bar)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677

        for m in self.modules():
            if isinstance(m, BasicBlock):
                nn.init.constant_(m.bn2.weight, 0)

        self.logger.debug("encoder conv1 weight shape: {}".format(
            str(self.conv1.module.weight_bar.data.shape)))
        # zero the weights for the extra (trimap) input channels so they
        # contribute nothing at the start of training
        self.conv1.module.weight_bar.data[:, 3:, :, :] = 0

        self.logger.debug(self)
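A hedged construction example. The stage depths [3, 4, 4, 2] and TRIMAP_CHANNEL = 3 are illustrative assumptions, not values taken from this file:

import torch

encoder = ResNet_D(BasicBlock, [3, 4, 4, 2])  # depths are illustrative
x = torch.randn(2, 3 + 3, 512, 512)  # RGB + 3-channel trimap, batch of 2
# With the default start_stride (2, 1, 2, 1) plus stride-2 layer2, layer3
# and layer_bottleneck, the deepest features sit at 1/32 input resolution.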
Example #8
    def __init__(self,
                 block,
                 layers,
                 norm_layer=None,
                 large_kernel=False,
                 late_downsample=False,
                 layer_multi=(1, 1, 1)):
        super(ResNet_D_Dec, self).__init__()
        self.logger = logging.getLogger("Logger")
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer
        self.large_kernel = large_kernel
        self.kernel_size = 5 if self.large_kernel else 3

        self.inplanes = 512 if layers[0] > 0 else 256
        self.late_downsample = late_downsample
        self.midplanes = 64 if late_downsample else 32

        self.conv1 = SpectralNorm(
            nn.ConvTranspose2d(self.midplanes,
                               32,
                               kernel_size=4,
                               stride=2,
                               padding=1,
                               bias=False))
        self.bn1 = norm_layer(32)
        self.leaky_relu = nn.LeakyReLU(0.2, inplace=True)
        self.conv2 = nn.Conv2d(32,
                               1,
                               kernel_size=self.kernel_size,
                               stride=1,
                               padding=self.kernel_size // 2)
        self.upsample = nn.UpsamplingNearest2d(scale_factor=2)
        self.tanh = nn.Tanh()
        self.layer1 = self._make_layer(block,
                                       256,
                                       layers[0],
                                       stride=2,
                                       inplane_multi=layer_multi[0])
        self.layer2 = self._make_layer(block,
                                       128,
                                       layers[1],
                                       stride=2,
                                       inplane_multi=layer_multi[1])
        self.layer3 = self._make_layer(block,
                                       64,
                                       layers[2],
                                       stride=2,
                                       inplane_multi=layer_multi[2])
        self.layer4 = self._make_layer(block,
                                       self.midplanes,
                                       layers[3],
                                       stride=2)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if hasattr(m, "weight_bar"):
                    nn.init.xavier_uniform_(m.weight_bar)
                else:
                    nn.init.xavier_uniform_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        for m in self.modules():
            if isinstance(m, BasicBlock):
                nn.init.constant_(m.bn2.weight, 0)

        self.logger.debug(self)
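The ConvTranspose2d settings used throughout the decoder (kernel_size=4, stride=2, padding=1) are the standard exact-doubling configuration: H_out = (H_in - 1) * 2 - 2 * 1 + 4 = 2 * H_in. A standalone check:

import torch
import torch.nn as nn

up = nn.ConvTranspose2d(32, 32, kernel_size=4, stride=2, padding=1, bias=False)
x = torch.randn(1, 32, 16, 16)
print(up(x).shape)  # torch.Size([1, 32, 32, 32]) -- exactly doubled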