def _make_layer(self, block, planes, blocks, stride=1):
        """Create sequential layers in a stage.

        Arguments:
            blocks: Resnet block to use.
            planes: Number of channels.
            blocks: Number of blocks in this stage.
            stride: Stride for the first layer in the stage."""

        if blocks == 0:
            return nn.Sequential(nn.Identity())
        norm_layer = self._norm_layer
        downsample = None
        if stride != 1:
            downsample = nn.Sequential(
                nn.AvgPool2d(2, stride),
                SpectralNorm(conv1x1(self.inplanes, planes * block.expansion)),
                norm_layer(planes * block.expansion),
            )
        elif self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                SpectralNorm(conv1x1(self.inplanes, planes * block.expansion, stride)),
                norm_layer(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample, norm_layer)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, norm_layer=norm_layer))

        return nn.Sequential(*layers)
Exemplo n.º 2
0
    def _make_layer(self, block, planes, blocks, stride=1):
        if blocks == 0:
            return nn.Sequential(nn.Identity())
        norm_layer = self._norm_layer
        upsample = None
        if stride != 1:
            upsample = nn.Sequential(
                nn.UpsamplingNearest2d(scale_factor=2),
                SpectralNorm(
                    conv1x1(self.inplanes * self.enc_expansion,
                            planes * block.expansion)),
                norm_layer(planes * block.expansion),
            )
        elif self.inplanes != planes * block.expansion:
            upsample = nn.Sequential(
                SpectralNorm(
                    conv1x1(self.inplanes * self.enc_expansion,
                            planes * block.expansion)),
                norm_layer(planes * block.expansion),
            )

        layers = [
            block(self.inplanes, planes, stride, upsample, norm_layer,
                  self.large_kernel, self.enc_expansion)
        ]
        self.inplanes = planes * block.expansion
        self.enc_expansion = 1
        for _ in range(1, blocks):
            layers.append(
                block(self.inplanes,
                      planes,
                      norm_layer=norm_layer,
                      large_kernel=self.large_kernel))

        return nn.Sequential(*layers)
Exemplo n.º 3
0
 def __init__(self,
              inplanes,
              planes,
              stride=1,
              upsample=None,
              norm_layer=None,
              large_kernel=False,
              enc_expansion=1):
     """Set up a decoder residual block.

     Arguments:
         inplanes: Base channel count; conv1 reads ``inplanes * enc_expansion``
             channels and writes ``inplanes`` channels.
         planes: Output channel count of conv2.
         stride: If greater than 1, conv1 is a 4x4 transposed convolution with
             stride 2; otherwise a regular ``conv3x3``/``conv5x5``.
         upsample: Optional shortcut module, stored as ``self.upsample``.
         norm_layer: Normalization layer factory (default ``nn.BatchNorm2d``).
         large_kernel: Select ``conv5x5`` instead of ``conv3x3``.
         enc_expansion: Multiplier on ``inplanes`` for conv1's input channels."""
     super(BasicBlock, self).__init__()
     make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer
     self.stride = stride
     conv = conv5x5 if large_kernel else conv3x3
     conv1_in = inplanes * enc_expansion
     # When stride != 1, both self.conv1 and self.upsample enlarge the input.
     if self.stride <= 1:
         first = conv(conv1_in, inplanes)
     else:
         first = nn.ConvTranspose2d(conv1_in,
                                    inplanes,
                                    kernel_size=4,
                                    stride=2,
                                    padding=1,
                                    bias=False)
     self.conv1 = SpectralNorm(first)
     self.bn1 = make_norm(inplanes)
     self.activation = nn.LeakyReLU(0.2, inplace=True)
     self.conv2 = SpectralNorm(conv(inplanes, planes))
     self.bn2 = make_norm(planes)
     self.upsample = upsample
Exemplo n.º 4
0
 def _make_shortcut(self, inplane, planes):
     """Build a skip-connection branch of two spectral-normalized 3x3 convs.

     Each convolution is followed by ReLU and then the normalization layer.

     Arguments:
         inplane: Number of input channels.
         planes: Number of output channels."""
     def conv_relu_norm(in_ch, out_ch):
         # One (conv -> ReLU -> norm) unit, in that exact order.
         return [
             SpectralNorm(nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False)),
             nn.ReLU(inplace=True),
             self._norm_layer(out_ch),
         ]

     return nn.Sequential(*conv_relu_norm(inplane, planes), *conv_relu_norm(planes, planes))
Exemplo n.º 5
0
 def __init__(self, inplanes, planes, stride=1, downsample=None, norm_layer=None):
     """Set up an encoder residual block of two spectral-normalized 3x3 convs.

     Arguments:
         inplanes: Number of input channels.
         planes: Number of output channels.
         stride: Stride of the first convolution.
         downsample: Optional shortcut module, stored as ``self.downsample``.
         norm_layer: Normalization layer factory (default ``nn.BatchNorm2d``)."""
     super(BasicBlock, self).__init__()
     make_norm = norm_layer if norm_layer is not None else nn.BatchNorm2d
     # With stride != 1, self.conv1 and self.downsample both downsample the input.
     self.conv1 = SpectralNorm(conv3x3(inplanes, planes, stride))
     self.bn1 = make_norm(planes)
     self.activation = nn.ReLU(inplace=True)
     self.conv2 = SpectralNorm(conv3x3(planes, planes))
     self.bn2 = make_norm(planes)
     self.downsample = downsample
     self.stride = stride
    def __init__(self, block, layers, norm_layer=None, late_downsample=False):
        """Initialize the module.

        Arguments:
            block: Basic resnet block to use.
            layers: List of number of layers to use in each stage (four entries are read).
            norm_layer: Normalization layer to use (defaults to ``nn.BatchNorm2d``).
            late_downsample: Set to true if the first downsampling operation should be done one stage late."""

        super(ResNet_D, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.late_downsample = late_downsample
        self.midplanes = 64 if late_downsample else 32
        # Strides for [conv1, conv2, conv3, layer1].
        self.start_stride = [1, 2, 1, 2] if late_downsample else [2, 1, 2, 1]
        self.conv1 = SpectralNorm(nn.Conv2d(3 + 3, 32, kernel_size=3,
                                            stride=self.start_stride[0], padding=1, bias=False))
        self.conv2 = SpectralNorm(nn.Conv2d(32, self.midplanes, kernel_size=3, stride=self.start_stride[1], padding=1,
                                            bias=False))
        self.conv3 = SpectralNorm(nn.Conv2d(self.midplanes, self.inplanes, kernel_size=3, stride=self.start_stride[2],
                                            padding=1, bias=False))
        self.bn1 = norm_layer(32)
        self.bn2 = norm_layer(self.midplanes)
        self.bn3 = norm_layer(self.inplanes)
        self.activation = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=self.start_stride[3])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer_bottleneck = self._make_layer(block, 512, layers[3], stride=2)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Spectral-normalized convs keep their raw weight in ``weight_bar``.
                # Plain convs (e.g. inside a custom ``block``) only have ``weight``.
                # Handle both instead of raising AttributeError.
                if hasattr(m, "weight_bar"):
                    nn.init.xavier_uniform_(m.weight_bar)
                else:
                    nn.init.xavier_uniform_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to [https://arxiv.org/abs/1706.02677].

        for m in self.modules():
            if isinstance(m, BasicBlock):
                nn.init.constant_(m.bn2.weight, 0)

        # Zero conv1's weights for input channels 3 onward, so the extra
        # (non-first-three, presumably non-RGB) inputs initially contribute nothing.
        self.conv1.module.weight_bar.data[:, 3:, :, :] = 0
Exemplo n.º 7
0
    def __init__(self, kernel_size=3, features=32, stride=2, sigma=1):
        """Initialize the module.

        Arguments:
            kernel_size: Kernel size of the strided pooling convolution and of the
                output convolution (the transposed unpooling convolution always
                uses a 3x3 kernel).
            features: Number of channels for the intermediate feature maps.
            stride: Stride for the pooling and unpooling operations.
            sigma: Standard deviation of the normal distribution that is needed for the calculation of the gradient of the log-likelihood."""

        super(SpectralRIM, self).__init__()
        self.sigma = sigma
        input_nc = 7  # RGB FG + RGB BG + Alpha

        # Calculate padding based on the kernel size ("same"-style for odd kernels).
        padding = (kernel_size - 1) // 2

        # The pooling operation is a strided convolution.
        def pool(in_ch, out_ch):
            return SpectralNorm(
                nn.Conv2d(in_ch,
                          out_ch,
                          kernel_size=kernel_size,
                          stride=stride,
                          padding=padding,
                          bias=False))

        # No normalization, but using tanh as activation (the channel argument is unused).
        def norm(_channels):
            return nn.Sequential(nn.Tanh())

        # The unpooling operation is a transposed convolution. With kernel 3 and
        # padding 1, output_padding = stride - 1 makes the output exactly
        # ``stride`` times the input size. A hard-coded 1 is only valid for
        # stride=2, because PyTorch requires output_padding to be smaller than
        # stride or dilation.
        def unpool(in_ch, out_ch):
            return SpectralNorm(
                nn.ConvTranspose2d(in_ch,
                                   out_ch,
                                   kernel_size=3,
                                   stride=stride,
                                   padding=1,
                                   output_padding=stride - 1))

        # The rnn part of the network. Input -> pool -> norm -> ConvGRU -> unpool -> norm -> ConvGRU.
        # This means there is one ConvGRU at 1/stride spatial size and one at full spatial size.
        self.rnn = MultiRNN([
            EmbeddingWrapper(ConvGRU(features, 4 * features),
                             pool(2 * input_nc, features), norm(features))
        ] + [
            EmbeddingWrapper(ConvGRU(features, 4 * features),
                             unpool(4 * features, features), norm(features))
        ])
        # Final convolution at the end to reduce the number of channels back to the number of input channels.
        self.out = nn.Conv2d(4 * features,
                             input_nc,
                             kernel_size=kernel_size,
                             padding=padding,
                             bias=False)
    def _make_shortcut(self, inplane, planes):
        """Create a shortcut branch: two (3x3 conv -> ReLU -> norm) units.

        Arguments:
            inplane: Number of input channels.
            planes: Number of output channels."""

        modules = []
        for in_ch, out_ch in ((inplane, planes), (planes, planes)):
            modules += [
                SpectralNorm(nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False)),
                nn.ReLU(inplace=True),
                self._norm_layer(out_ch),
            ]
        return nn.Sequential(*modules)
Exemplo n.º 9
0
    def __init__(self,
                 block,
                 layers,
                 norm_layer=None,
                 large_kernel=False,
                 late_downsample=False):
        """Set up the decoder's four stages, final transposed conv and output conv.

        Arguments:
            block: Residual block class used in every stage.
            layers: Number of blocks per stage (four entries are read).
            norm_layer: Normalization layer factory; defaults to ``nn.BatchNorm2d``.
            large_kernel: Use 5x5 instead of 3x3 kernels.
            late_downsample: Selects 64 (True) or 32 (False) channels for the
                last stage and the input of ``conv1``.

        NOTE(review): a ``_make_layer`` variant that reads ``self.enc_expansion``
        requires that attribute to be defined elsewhere; it is not set here."""
        super(ResNet_D_Dec, self).__init__()
        self.logger = logging.getLogger("Logger")
        self._norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        norm_layer = self._norm_layer
        self.large_kernel = large_kernel
        self.kernel_size = 5 if self.large_kernel else 3

        # Stage 1 is skipped (identity) when layers[0] == 0, so start narrower.
        self.inplanes = 512 if layers[0] > 0 else 256
        self.late_downsample = late_downsample
        self.midplanes = 64 if late_downsample else 32

        final_up = nn.ConvTranspose2d(self.midplanes,
                                      32,
                                      kernel_size=4,
                                      stride=2,
                                      padding=1,
                                      bias=False)
        self.conv1 = SpectralNorm(final_up)
        self.bn1 = norm_layer(32)
        self.leaky_relu = nn.LeakyReLU(0.2, inplace=True)
        half_kernel = self.kernel_size // 2
        self.conv2 = nn.Conv2d(32,
                               1,
                               kernel_size=self.kernel_size,
                               stride=1,
                               padding=half_kernel)
        self.upsample = nn.UpsamplingNearest2d(scale_factor=2)
        self.tanh = nn.Tanh()
        self.layer1 = self._make_layer(block, 256, layers[0], stride=2)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 64, layers[2], stride=2)
        self.layer4 = self._make_layer(block, self.midplanes, layers[3], stride=2)

        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                # Spectral-normalized convs expose their raw weight as weight_bar.
                target = module.weight_bar if hasattr(module, "weight_bar") else module.weight
                nn.init.xavier_uniform_(target)
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

        # Zero-initialize the last BN of every residual branch so that each
        # BasicBlock starts out as an identity mapping
        # (reported +0.2~0.3% in https://arxiv.org/abs/1706.02677).
        residual_blocks = (m for m in self.modules() if isinstance(m, BasicBlock))
        for residual in residual_blocks:
            nn.init.constant_(residual.bn2.weight, 0)

        self.logger.debug(self)
Exemplo n.º 10
0
    def __init__(self, block, layers, norm_layer=None, late_downsample=False):
        """Build the spectral-normalized ResNet-D encoder.

        Arguments:
            block: Residual block class used in every stage.
            layers: Number of blocks per stage (four entries are read).
            norm_layer: Normalization layer factory; defaults to ``nn.BatchNorm2d``.
            late_downsample: Set to true if the first downsampling operation
                should be done one stage late.

        The stem takes ``3 + CONFIG.model.mask_channel`` input channels
        (presumably RGB plus a mask; CONFIG is defined elsewhere)."""
        super(ResNet_D, self).__init__()
        self.logger = logging.getLogger("Logger")
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.late_downsample = late_downsample
        self.midplanes = 64 if late_downsample else 32
        # Strides for [conv1, conv2, conv3, layer1].
        self.start_stride = [1, 2, 1, 2] if late_downsample else [2, 1, 2, 1]
        self.conv1 = SpectralNorm(nn.Conv2d(3 + CONFIG.model.mask_channel, 32, kernel_size=3,
                                            stride=self.start_stride[0], padding=1, bias=False))
        self.conv2 = SpectralNorm(nn.Conv2d(32, self.midplanes, kernel_size=3, stride=self.start_stride[1], padding=1,
                                            bias=False))
        self.conv3 = SpectralNorm(nn.Conv2d(self.midplanes, self.inplanes, kernel_size=3, stride=self.start_stride[2],
                                            padding=1, bias=False))
        self.bn1 = norm_layer(32)
        self.bn2 = norm_layer(self.midplanes)
        self.bn3 = norm_layer(self.inplanes)
        self.activation = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=self.start_stride[3])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer_bottleneck = self._make_layer(block, 512, layers[3], stride=2)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Spectral-normalized convs keep their raw weight in ``weight_bar``.
                # Plain convs (e.g. inside a custom ``block``) only have ``weight``.
                # Handle both instead of raising AttributeError.
                if hasattr(m, "weight_bar"):
                    nn.init.xavier_uniform_(m.weight_bar)
                else:
                    nn.init.xavier_uniform_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677

        for m in self.modules():
            if isinstance(m, BasicBlock):
                nn.init.constant_(m.bn2.weight, 0)

        self.logger.debug("encoder conv1 weight shape: {}".format(str(self.conv1.module.weight_bar.data.shape)))
        # Zero conv1's weights for input channels 3 onward, so the extra
        # (mask) inputs initially contribute nothing.
        self.conv1.module.weight_bar.data[:,3:,:,:] = 0

        self.logger.debug(self)
    def __init__(self, block, layers, norm_layer=None, late_downsample=False):
        """Initialize the module.

        Arguments:
            block: Basic block to use in each stage.
            layers: List of number of layers to use in each stage.
            norm_layer: Type of normalization layer.
            late_downsample: Set to true if the first downsampling operation should be done one stage late."""

        super(ResGuidedCxtAtten, self).__init__(block, layers, norm_layer, late_downsample=late_downsample)
        first_inplane = 3 + 3  # RGB image + 3 channel trimap.
        self.shortcut_inplane = [first_inplane, self.midplanes, 64, 128, 256]
        self.shortcut_plane = [32, self.midplanes, 64, 128, 256]

        # One shortcut branch per level, mapping shortcut_inplane[i] -> shortcut_plane[i] channels.
        self.shortcut = nn.ModuleList()
        for stage, inplane in enumerate(self.shortcut_inplane):
            self.shortcut.append(self._make_shortcut(inplane, self.shortcut_plane[stage]))

        # Three reflection-padded, stride-2 conv stages: 3 -> 16 -> 32 -> 128 channels.
        self.guidance_head = nn.Sequential(
            nn.ReflectionPad2d(1),
            SpectralNorm(nn.Conv2d(3, 16, kernel_size=3, padding=0, stride=2, bias=False)),
            nn.ReLU(inplace=True),
            self._norm_layer(16),
            nn.ReflectionPad2d(1),
            SpectralNorm(nn.Conv2d(16, 32, kernel_size=3, padding=0, stride=2, bias=False)),
            nn.ReLU(inplace=True),
            self._norm_layer(32),
            nn.ReflectionPad2d(1),
            SpectralNorm(nn.Conv2d(32, 128, kernel_size=3, padding=0, stride=2, bias=False)),
            nn.ReLU(inplace=True),
            self._norm_layer(128)
        )

        self.gca = GuidedCxtAtten(128, 128)

        # Initialize the guidance head. Its convolutions are wrapped in
        # SpectralNorm, so they are not direct children of the Sequential.
        # Iterating only the children would never match the Conv2d branch, so
        # walk the whole subtree with .modules() to reach the wrapped Conv2d
        # that carries ``weight_bar``.
        for m in self.guidance_head.modules():
            if isinstance(m, nn.Conv2d):
                if hasattr(m, "weight_bar"):
                    nn.init.xavier_uniform_(m.weight_bar)
                else:
                    nn.init.xavier_uniform_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
Exemplo n.º 12
0
 def __init__(self,
              inplanes,
              planes,
              stride=1,
              downsample=None,
              norm_layer=None):
     """Set up a bottleneck residual block (1x1 -> 3x3 -> 1x1 convs).

     Arguments:
         inplanes: Number of input channels.
         planes: Width of the inner convolutions; the block outputs
             ``planes * self.expansion`` channels.
         stride: Stride of the 3x3 convolution.
         downsample: Optional shortcut module, stored as ``self.downsample``.
         norm_layer: Normalization layer factory (default ``nn.BatchNorm2d``)."""
     super(Bottleneck, self).__init__()
     norm = nn.BatchNorm2d if norm_layer is None else norm_layer
     inner = planes
     outer = planes * self.expansion
     # With stride != 1, self.conv2 and self.downsample both downsample the input.
     self.conv1 = SpectralNorm(conv1x1(inplanes, inner))
     self.bn1 = norm(inner)
     self.conv2 = SpectralNorm(conv3x3(inner, inner, stride))
     self.bn2 = norm(inner)
     self.conv3 = SpectralNorm(conv1x1(inner, outer))
     self.bn3 = norm(outer)
     self.activation = nn.ReLU(inplace=True)
     self.downsample = downsample
     self.stride = stride
    def __init__(self, inplanes, planes, stride=1, downsample=None, norm_layer=None):
        """Build a two-convolution residual block.

        Arguments:
            inplanes: Channel count entering the block.
            planes: Channel count produced by both convolutions.
            stride: Stride applied by the first 3x3 convolution.
            downsample: Shortcut module (or None), kept as ``self.downsample``.
            norm_layer: Normalization factory; ``nn.BatchNorm2d`` when None."""

        super(BasicBlock, self).__init__()
        norm = norm_layer if norm_layer is not None else nn.BatchNorm2d

        # When stride != 1, the first conv and the shortcut both reduce resolution.
        first_conv = conv3x3(inplanes, planes, stride)
        self.conv1 = SpectralNorm(first_conv)
        self.bn1 = norm(planes)
        self.activation = nn.ReLU(inplace=True)
        second_conv = conv3x3(planes, planes)
        self.conv2 = SpectralNorm(second_conv)
        self.bn2 = norm(planes)
        self.downsample = downsample
        self.stride = stride
    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 upsample=None,
                 norm_layer=None,
                 large_kernel=False):
        """Build a decoder residual block.

        Arguments:
            inplanes: Channel count entering the block (conv1 preserves it).
            planes: Channel count produced by conv2.
            stride: Values above 1 make conv1 a 4x4, stride-2 transposed convolution.
            upsample: Shortcut module (or None), kept as ``self.upsample``.
            norm_layer: Normalization factory; ``nn.BatchNorm2d`` when None.
            large_kernel: Use ``conv5x5`` instead of ``conv3x3``."""

        super(BasicBlock, self).__init__()
        norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        self.stride = stride
        make_conv = conv5x5 if large_kernel else conv3x3
        # When stride != 1, self.conv1 and self.upsample both enlarge the input.
        first = (
            nn.ConvTranspose2d(inplanes, inplanes, kernel_size=4, stride=2,
                               padding=1, bias=False)
            if self.stride > 1
            else make_conv(inplanes, inplanes)
        )
        self.conv1 = SpectralNorm(first)
        self.bn1 = norm(inplanes)
        self.activation = nn.LeakyReLU(0.2, inplace=True)
        self.conv2 = SpectralNorm(make_conv(inplanes, planes))
        self.bn2 = norm(planes)
        self.upsample = upsample
    def _make_layer(self, block, planes, blocks, stride=1):
        """Create sequential layers in a stage.

        Arguments:
            blocks: Resnet block to use.
            planes: Number of channels.
            blocks: Number of blocks in this stage.
            stride: Stride for the first layer in the stage."""

        if blocks == 0:
            return nn.Sequential(nn.Identity())
        norm_layer = self._norm_layer
        upsample = None
        if stride != 1:
            upsample = nn.Sequential(
                nn.UpsamplingNearest2d(scale_factor=2),
                SpectralNorm(conv1x1(self.inplanes, planes * block.expansion)),
                norm_layer(planes * block.expansion),
            )
        elif self.inplanes != planes * block.expansion:
            upsample = nn.Sequential(
                SpectralNorm(conv1x1(self.inplanes, planes * block.expansion)),
                norm_layer(planes * block.expansion),
            )

        layers = [
            block(self.inplanes, planes, stride, upsample, norm_layer,
                  self.large_kernel)
        ]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(self.inplanes,
                      planes,
                      norm_layer=norm_layer,
                      large_kernel=self.large_kernel))

        return nn.Sequential(*layers)
Exemplo n.º 16
0
    def _make_layer(self, block, planes, blocks, stride=1):
        if blocks == 0:
            return nn.Sequential(nn.Identity())
        norm_layer = self._norm_layer
        downsample = None
        if stride != 1:
            downsample = nn.Sequential(
                nn.AvgPool2d(2, stride),
                SpectralNorm(conv1x1(self.inplanes, planes * block.expansion)),
                norm_layer(planes * block.expansion),
            )
        elif self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                SpectralNorm(conv1x1(self.inplanes, planes * block.expansion, stride)),
                norm_layer(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample, norm_layer)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, norm_layer=norm_layer))

        return nn.Sequential(*layers)
Exemplo n.º 17
0
    def __init__(self, block, layers, norm_layer=None, late_downsample=False):
        """Build the encoder plus shortcut branches, guidance head and GCA module.

        Arguments:
            block: Basic block to use in each stage.
            layers: List of number of layers to use in each stage.
            norm_layer: Type of normalization layer.
            late_downsample: Set to true if the first downsampling operation
                should be done one stage late.

        The first shortcut takes ``3 + CONFIG.model.trimap_channel`` channels
        (image plus trimap; CONFIG is defined elsewhere)."""
        super(ResGuidedCxtAtten, self).__init__(block, layers, norm_layer, late_downsample=late_downsample)
        first_inplane = 3 + CONFIG.model.trimap_channel
        self.shortcut_inplane = [first_inplane, self.midplanes, 64, 128, 256]
        self.shortcut_plane = [32, self.midplanes, 64, 128, 256]

        # One shortcut branch per level, mapping shortcut_inplane[i] -> shortcut_plane[i] channels.
        self.shortcut = nn.ModuleList()
        for stage, inplane in enumerate(self.shortcut_inplane):
            self.shortcut.append(self._make_shortcut(inplane, self.shortcut_plane[stage]))

        # Three reflection-padded, stride-2 conv stages: 3 -> 16 -> 32 -> 128 channels.
        self.guidance_head = nn.Sequential(
            nn.ReflectionPad2d(1),
            SpectralNorm(nn.Conv2d(3, 16, kernel_size=3, padding=0, stride=2, bias=False)),
            nn.ReLU(inplace=True),
            self._norm_layer(16),
            nn.ReflectionPad2d(1),
            SpectralNorm(nn.Conv2d(16, 32, kernel_size=3, padding=0, stride=2, bias=False)),
            nn.ReLU(inplace=True),
            self._norm_layer(32),
            nn.ReflectionPad2d(1),
            SpectralNorm(nn.Conv2d(32, 128, kernel_size=3, padding=0, stride=2, bias=False)),
            nn.ReLU(inplace=True),
            self._norm_layer(128)
        )

        self.gca = GuidedCxtAtten(128, 128)

        # Initialize the guidance head. Its convolutions are wrapped in
        # SpectralNorm, so they are not direct children of the Sequential.
        # Iterating only the children would never match the Conv2d branch, so
        # walk the whole subtree with .modules() to reach the wrapped Conv2d
        # that carries ``weight_bar``.
        for m in self.guidance_head.modules():
            if isinstance(m, nn.Conv2d):
                if hasattr(m, "weight_bar"):
                    nn.init.xavier_uniform_(m.weight_bar)
                else:
                    nn.init.xavier_uniform_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
Exemplo n.º 18
0
    def __init__(self, block, layers, norm_layer=None, late_downsample=False):
        """Build the encoder plus shortcut branches and four guidance heads.

        Arguments:
            block: Residual block class used in every encoder stage.
            layers: Number of blocks per encoder stage.
            norm_layer: Normalization layer factory (resolved by the parent).
            late_downsample: Set to true if the first downsampling operation
                should be done one stage late.

        The first shortcut takes ``3 + CONFIG.model.trimap_channel`` channels
        (image plus trimap; CONFIG is defined elsewhere). Later shortcuts take
        the encoder stage widths scaled by ``block.expansion``."""
        super(ResLocalHOP_PosEmb,
              self).__init__(block,
                             layers,
                             norm_layer,
                             late_downsample=late_downsample)
        first_inplane = 3 + CONFIG.model.trimap_channel
        self.shortcut_inplane = [
            first_inplane, self.midplanes, 64 * block.expansion,
            128 * block.expansion, 256 * block.expansion
        ]
        self.shortcut_plane = [32, self.midplanes, 64, 128, 256]

        self.shortcut = nn.ModuleList()
        for stage, inplane in enumerate(self.shortcut_inplane):
            self.shortcut.append(
                self._make_shortcut(inplane, self.shortcut_plane[stage]))

        # Each guidance head is one reflection-padded, stride-2 conv stage.
        # The shape comments assume a 512x512 input (not checked here).
        self.guidance_head1 = nn.Sequential(  # N x 16 x 256 x 256
            nn.ReflectionPad2d(1),
            SpectralNorm(
                nn.Conv2d(3,
                          16,
                          kernel_size=3,
                          padding=0,
                          stride=2,
                          bias=False)),
            nn.ReLU(inplace=True),
            self._norm_layer(16),
        )
        self.guidance_head2 = nn.Sequential(  # N x 32 x 128 x 128
            nn.ReflectionPad2d(1),
            SpectralNorm(
                nn.Conv2d(16,
                          32,
                          kernel_size=3,
                          padding=0,
                          stride=2,
                          bias=False)),
            nn.ReLU(inplace=True),
            self._norm_layer(32),
        )
        self.guidance_head3 = nn.Sequential(  # N x 64 x 64 x 64
            nn.ReflectionPad2d(1),
            SpectralNorm(
                nn.Conv2d(32,
                          64,
                          kernel_size=3,
                          padding=0,
                          stride=2,
                          bias=False)), nn.ReLU(inplace=True),
            self._norm_layer(64))
        self.guidance_head4 = nn.Sequential(  # N x 64 x 32 x 32
            nn.ReflectionPad2d(1),
            SpectralNorm(
                nn.Conv2d(64,
                          64,
                          kernel_size=3,
                          padding=0,
                          stride=2,
                          bias=False)), nn.ReLU(inplace=True),
            self._norm_layer(64))

        # Initialize only the modules created in this constructor. Walking
        # self.modules() would re-initialize every parent module as well. That
        # resets each BasicBlock's zero-initialized bn2 weight to 1 and
        # re-randomizes conv1's zeroed extra-channel weights.
        # NOTE(review): assumes the parent is the ResNet_D encoder that performs
        # those zero-initializations — confirm.
        new_parts = (self.shortcut, self.guidance_head1, self.guidance_head2,
                     self.guidance_head3, self.guidance_head4)
        for part in new_parts:
            for m in part.modules():
                if isinstance(m, nn.Conv2d):
                    if hasattr(m, "weight_bar"):
                        nn.init.xavier_uniform_(m.weight_bar)
                    else:
                        nn.init.xavier_uniform_(m.weight)
                elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)
    def __init__(self,
                 block,
                 layers,
                 norm_layer=None,
                 large_kernel=False,
                 late_downsample=False):
        """Set up the decoder: four upsampling stages, a final transposed conv
        and a convolution producing a single output channel.

        Arguments:
            block: Residual block class used in every stage.
            layers: Blocks per stage; entries 0-3 are used.
            norm_layer: Normalization factory; ``nn.BatchNorm2d`` when None.
            large_kernel: Use 5x5 kernels instead of 3x3.
            late_downsample: Chooses 64 (True) or 32 (False) channels for the
                last stage and for the input of ``conv1``."""

        super(ResNet_D_Dec, self).__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        self._norm_layer = norm_layer
        self.large_kernel = large_kernel
        self.kernel_size = 5 if large_kernel else 3

        # A skipped first stage (layers[0] == 0) means the input is 256 wide.
        self.inplanes = 256 if layers[0] <= 0 else 512
        self.late_downsample = late_downsample
        self.midplanes = 32 if not late_downsample else 64

        self.conv1 = SpectralNorm(
            nn.ConvTranspose2d(self.midplanes, 32, kernel_size=4, stride=2,
                               padding=1, bias=False))
        self.bn1 = norm_layer(32)
        self.leaky_relu = nn.LeakyReLU(0.2, inplace=True)
        self.conv2 = nn.Conv2d(32, 1, kernel_size=self.kernel_size, stride=1,
                               padding=self.kernel_size // 2)
        self.upsample = nn.UpsamplingNearest2d(scale_factor=2)
        self.tanh = nn.Tanh()
        self.layer1 = self._make_layer(block, 256, layers[0], stride=2)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 64, layers[2], stride=2)
        self.layer4 = self._make_layer(block, self.midplanes, layers[3], stride=2)

        for mod in self.modules():
            if isinstance(mod, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(mod.weight, 1)
                nn.init.constant_(mod.bias, 0)
            elif isinstance(mod, nn.Conv2d):
                # Spectral-normalized convs store the raw weight as weight_bar.
                raw = getattr(mod, "weight_bar") if hasattr(mod, "weight_bar") else mod.weight
                nn.init.xavier_uniform_(raw)

        # Start every BasicBlock as an identity mapping by zeroing the weight of
        # its last BN (reported +0.2~0.3%, https://arxiv.org/abs/1706.02677).
        for mod in self.modules():
            if not isinstance(mod, BasicBlock):
                continue
            nn.init.constant_(mod.bn2.weight, 0)