Example #1
0
class ResNet9Stride32(nn.Module):
    """Strided residual encoder with a nominal total stride of 32.

    A 3x3 stem convolution with stride 2 is followed by four ResBlockStrided
    blocks, each built with stride=2 (five stride-2 stages, 2**5 = 32; see
    get_downscale_factor).

    Args:
        in_channels: number of channels of the input tensor.
        channels: feature channels of the stem conv and of every block.
        down_pad: if truthy, the stem conv and the blocks use padding 1,
            otherwise padding 0.

    Note: ``res_norm`` is constructed but forward() does not apply it.
    """

    def __init__(self, in_channels, channels, down_pad=True):
        super().__init__()
        pad = 1 if down_pad else 0

        # Submodules are registered in order: conv1, block1..block4, res_norm.
        self.conv1 = nn.Conv2d(in_channels, channels, 3, stride=2, padding=pad)
        for idx in range(1, 5):
            setattr(self, "block%d" % idx,
                    ResBlockStrided(channels, stride=2, down_padding=pad))

        self.res_norm = nn.InstanceNorm2d(channels)

    def get_downscale_factor(self):
        """Return the nominal spatial downscale of the encoder (five stride-2 stages)."""
        return 32

    def init_weights(self):
        """Initialise the residual blocks, then the stem conv (Kaiming-uniform weights, zero bias)."""
        for stage in (self.block1, self.block2, self.block3, self.block4):
            stage.init_weights()
        torch.nn.init.kaiming_uniform_(self.conv1.weight)
        self.conv1.bias.data.fill_(0)

    def forward(self, input):
        """Run the stem conv and then the four strided blocks in sequence."""
        features = self.conv1(input)
        for stage in (self.block1, self.block2, self.block3, self.block4):
            features = stage(features)
        return features
Example #2
0
    def __init__(self, channels=32, factor=4):
        """Build the strided residual blocks needed to downsample by ``factor``.

        One ResBlockStrided (stride=2, padding 1, nonorm=False) is created for
        each power of two up to ``factor`` and stored as ``res2``, ``res4``,
        ``res8`` and ``res16`` respectively.

        Args:
            channels: feature channels of every residual block.
            factor: total downscale; must be one of (2, 4, 8, 16).

        Raises:
            ValueError: if ``factor`` is not one of (2, 4, 8, 16).
        """
        super(DownsampleResidual, self).__init__()
        # Without this check, unsupported values would silently round down to
        # the nearest supported factor below them (e.g. 6 -> 4, 32 -> 16) or
        # build no blocks at all (factor < 2).
        if factor not in (2, 4, 8, 16):
            raise ValueError(
                "DownsampleResidual factor must be one of (2, 4, 8, 16), got %r"
                % (factor,))

        pad = 1
        self.factor = factor
        self.channels = channels

        # Registered in ascending order under the same names (res2..res16), so
        # parameter names and ordering stay stable.
        for scale in (2, 4, 8, 16):
            if factor >= scale:
                setattr(self, "res%d" % scale,
                        ResBlockStrided(channels,
                                        stride=2,
                                        down_padding=pad,
                                        nonorm=False))

        self.res_norm = nn.InstanceNorm2d(channels)
Example #3
0
class DownsampleResidual(torch.nn.Module):
    """
    Downsample a feature map by repeatedly applying residual blocks with strided convolutions.

    One ResBlockStrided (stride=2) is built for each power of two up to ``factor``; the
    blocks are stored as ``res2``, ``res4``, ``res8`` and ``res16`` and applied in that
    order. ``factor`` must be one of (2, 4, 8, 16).
    """

    # Supported per-stage scales, in the order the blocks are built and applied.
    _SCALES = (2, 4, 8, 16)

    def __init__(self, channels=32, factor=4):
        """Build the residual blocks for ``factor``.

        Args:
            channels: feature channels of every residual block.
            factor: total downscale; must be one of (2, 4, 8, 16).

        Raises:
            ValueError: if ``factor`` is not one of (2, 4, 8, 16).
        """
        super(DownsampleResidual, self).__init__()
        # Without this check, unsupported values would silently round down to
        # the nearest supported factor below them (e.g. 6 -> 4, 32 -> 16) or
        # build no blocks at all (factor < 2).
        if factor not in self._SCALES:
            raise ValueError(
                "DownsampleResidual factor must be one of (2, 4, 8, 16), got %r"
                % (factor,))

        pad = 1
        self.factor = factor
        self.channels = channels

        # Registered in ascending order under the same names (res2..res16), so
        # parameter names and ordering stay stable.
        for scale in self._SCALES:
            if factor >= scale:
                setattr(self, "res%d" % scale,
                        ResBlockStrided(channels,
                                        stride=2,
                                        down_padding=pad,
                                        nonorm=False))

        # NOTE(review): res_norm is built here but forward() never applies it;
        # confirm whether that is intended.
        self.res_norm = nn.InstanceNorm2d(channels)

    def _active_blocks(self):
        """Return the residual blocks built for ``self.factor``, in application order."""
        return [getattr(self, "res%d" % scale)
                for scale in self._SCALES if self.factor >= scale]

    def init_weights(self):
        """Call init_weights() on every residual block, in order."""
        for block in self._active_blocks():
            block.init_weights()

    def forward(self, image):
        """Apply the residual blocks to ``image`` in sequence and return the result."""
        x = image
        for block in self._active_blocks():
            x = block(x)
        return x
Example #4
0
    def __init__(self, channels, down_pad=False):
        """Build the stem conv, three residual stages and the final InstanceNorm.

        Each stage is three ResBlock modules followed by one ResBlockStrided
        built with stride=2. The blocks are named block1..block12 in execution
        order (block4, block8 and block12 are the strided ones).

        Args:
            channels: feature channels of the stem conv and of every block.
            down_pad: if truthy, padding 1 for the stem conv and the strided
                blocks, otherwise padding 0.
        """
        super(ResNet30, self).__init__()
        pad = 1 if down_pad else 0

        # Stem: 3 input channels -> `channels`, 3x3 kernel, stride 1.
        self.conv1 = nn.Conv2d(3, channels, 3, stride=1, padding=pad)

        # Registers block1..block12 in numeric order.
        for stage in range(3):
            first = 4 * stage + 1
            for idx in range(first, first + 3):
                setattr(self, "block%d" % idx, ResBlock(channels))
            setattr(self, "block%d" % (first + 3),
                    ResBlockStrided(channels, stride=2, down_padding=pad))

        self.res_norm = nn.InstanceNorm2d(channels)
Exemple #5
0
    def __init__(self, in_channels, channels, down_pad=True):
        """Build a stride-2 stem conv, four stride-2 residual blocks and an InstanceNorm.

        Args:
            in_channels: number of channels of the input tensor.
            channels: feature channels of the stem conv and of every block.
            down_pad: if truthy, the stem conv and the blocks use padding 1,
                otherwise padding 0.
        """
        super(ResNet9Stride32, self).__init__()
        pad = 1 if down_pad else 0

        # Submodules are registered in order: conv1, block1..block4, res_norm.
        self.conv1 = nn.Conv2d(in_channels, channels, 3, stride=2, padding=pad)
        for idx in range(1, 5):
            setattr(self, "block%d" % idx,
                    ResBlockStrided(channels, stride=2, down_padding=pad))

        self.res_norm = nn.InstanceNorm2d(channels)
Exemple #6
0
class ResNet13S(CudaModule):
    """Small residual encoder.

    A 3x3 stem conv (3 input channels, stride 1) is followed by three
    (ResBlockStrided, ResBlock) pairs and a final InstanceNorm. The strided
    blocks are built with stride 2, 2 and 1 respectively.
    NOTE(review): assuming ResBlockStrided's ``stride`` is its spatial stride,
    the overall downscale is 4; confirm against ResBlockStrided.

    Args:
        channels: feature channels of the stem conv and of every block.
        down_pad: if truthy, padding 1 for the stem conv and the strided
            blocks, otherwise padding 0.
    """
    def __init__(self, channels, down_pad=False):
        super(ResNet13S, self).__init__()

        down_padding = 0
        if down_pad:
            down_padding = 1

        # inchannels, outchannels, kernel size
        self.conv1 = nn.Conv2d(3, channels, 3, stride=1, padding=down_padding)

        self.block1 = ResBlockStrided(channels, stride=2, down_padding=down_padding)
        self.block15 = ResBlock(channels)
        self.block2 = ResBlockStrided(channels, stride=2, down_padding=down_padding)
        self.block25 = ResBlock(channels)
        self.block3 = ResBlockStrided(channels, stride=1, down_padding=down_padding)
        self.block35 = ResBlock(channels)

        self.res_norm = nn.InstanceNorm2d(channels)

    def init_weights(self):
        """Initialise the blocks (strided ones first), then the stem conv."""
        self.block1.init_weights()
        self.block2.init_weights()
        self.block3.init_weights()

        self.block15.init_weights()
        self.block25.init_weights()
        self.block35.init_weights()

        # The non-underscore kaiming_uniform is a deprecated alias that only
        # adds a warning; kaiming_uniform_ is the supported in-place API.
        torch.nn.init.kaiming_uniform_(self.conv1.weight)
        self.conv1.bias.data.fill_(0)

    def forward(self, input):
        """Stem conv, then the block pairs in sequence, then InstanceNorm."""
        x = self.conv1(input)
        x = self.block1(x)
        x = self.block15(x)
        x = self.block2(x)
        x = self.block25(x)
        x = self.block3(x)
        x = self.block35(x)
        x = self.res_norm(x)
        return x
Exemple #7
0
class ResNet30(torch.nn.Module):
    """Residual encoder with three downsampling stages.

    A 3x3 stem conv (3 input channels, stride 1) is followed by three stages,
    each made of three ResBlock modules and one ResBlockStrided built with
    stride=2 (block4, block8, block12). forward() finishes with InstanceNorm.

    Args:
        channels: feature channels of the stem conv and of every block.
        down_pad: if truthy, padding 1 for the stem conv and the strided
            blocks, otherwise padding 0.
    """
    def __init__(self, channels, down_pad=False):
        super(ResNet30, self).__init__()

        down_padding = 0
        if down_pad:
            down_padding = 1

        # inchannels, outchannels, kernel size
        self.conv1 = nn.Conv2d(3, channels, 3, stride=1, padding=down_padding)

        self.block1 = ResBlock(channels)
        self.block2 = ResBlock(channels)
        self.block3 = ResBlock(channels)
        self.block4 = ResBlockStrided(channels,
                                      stride=2,
                                      down_padding=down_padding)

        self.block5 = ResBlock(channels)
        self.block6 = ResBlock(channels)
        self.block7 = ResBlock(channels)
        self.block8 = ResBlockStrided(channels,
                                      stride=2,
                                      down_padding=down_padding)

        self.block9 = ResBlock(channels)
        self.block10 = ResBlock(channels)
        self.block11 = ResBlock(channels)
        self.block12 = ResBlockStrided(channels,
                                       stride=2,
                                       down_padding=down_padding)

        self.res_norm = nn.InstanceNorm2d(channels)

    def _blocks(self):
        """Return block1..block12 in execution order."""
        return (self.block1, self.block2, self.block3, self.block4,
                self.block5, self.block6, self.block7, self.block8,
                self.block9, self.block10, self.block11, self.block12)

    def init_weights(self):
        """Initialise every block in order, then the stem conv."""
        for block in self._blocks():
            block.init_weights()

        # The non-underscore kaiming_uniform is a deprecated alias that only
        # adds a warning; kaiming_uniform_ is the supported in-place API.
        torch.nn.init.kaiming_uniform_(self.conv1.weight)
        self.conv1.bias.data.fill_(0)

    def forward(self, input):
        """Stem conv, then block1..block12 in sequence, then InstanceNorm."""
        x = self.conv1(input)
        for block in self._blocks():
            x = block(x)
        x = self.res_norm(x)
        return x