示例#1
0
class ResNet13Light(nn.Module):
    """Light residual encoder: a strided stem conv, three block pairs, and a norm.

    Each pair is a ``ResBlockStrided`` (strides 2, 2 and 1 respectively)
    followed by a plain ``ResBlock``; the output is instance-normalised.

    Args:
        channels: Number of feature channels used by every layer.
        down_pad: When true, the stem conv and the strided blocks are built
            with padding 1 instead of 0.
    """

    def __init__(self, channels, down_pad=False):
        super().__init__()

        down_padding = 1 if down_pad else 0

        # Stem: 3 input channels -> `channels`, 3x3 kernel, stride 2.
        self.conv1 = nn.Conv2d(3, channels, 3, stride=2, padding=down_padding)

        # Assignment order is significant: it fixes module registration order
        # (state_dict / parameters()) and the order of default-init RNG draws.
        self.block1 = ResBlockStrided(channels, stride=2, down_padding=down_padding)
        self.block15 = ResBlock(channels)
        self.block2 = ResBlockStrided(channels, stride=2, down_padding=down_padding)
        self.block25 = ResBlock(channels)
        self.block3 = ResBlockStrided(channels, stride=1, down_padding=down_padding)
        self.block35 = ResBlock(channels)

        self.res_norm = nn.InstanceNorm2d(channels)

    def get_downscale_factor(self):
        """Return the fixed spatial downscale factor of this encoder (8).

        Presumably the product of the stem and block strides (2 * 2 * 2 * 1),
        assuming each strided block downsamples by its stride — TODO confirm.
        """
        return 8

    def init_weights(self):
        """Re-initialise the residual blocks, then the stem convolution.

        The strided blocks are initialised before the plain ones, and all
        blocks before ``conv1``; the stem weight gets Kaiming-uniform init and
        its bias is zeroed.
        """
        init_order = (
            self.block1, self.block2, self.block3,
            self.block15, self.block25, self.block35,
        )
        for blk in init_order:
            blk.init_weights()

        torch.nn.init.kaiming_uniform_(self.conv1.weight)
        self.conv1.bias.data.zero_()

    def forward(self, input):
        """Run the stem, the six residual blocks in sequence, then the norm."""
        out = self.conv1(input)
        stages = (
            self.block1, self.block15,
            self.block2, self.block25,
            self.block3, self.block35,
        )
        for stage in stages:
            out = stage(out)
        return self.res_norm(out)
示例#2
0
    def __init__(self, channels, down_pad=False):
        """Build the stem convolution, six residual blocks and an instance norm.

        Args:
            channels: Number of feature channels used by every layer.
            down_pad: When true, the stem conv and the strided blocks use
                padding 1 instead of 0.
        """
        super(ResNet13S, self).__init__()

        down_padding = 1 if down_pad else 0

        def strided(stride):
            # All strided blocks share `channels` and the padding choice.
            return ResBlockStrided(channels, stride=stride, down_padding=down_padding)

        # Stem: 3 input channels -> `channels`, 3x3 kernel, stride 1.
        self.conv1 = nn.Conv2d(3, channels, 3, stride=1, padding=down_padding)

        # Keep this exact assignment order: it determines module registration
        # order and the sequence of default-initialisation RNG draws.
        self.block1 = strided(2)
        self.block15 = ResBlock(channels)
        self.block2 = strided(2)
        self.block25 = ResBlock(channels)
        self.block3 = strided(1)
        self.block35 = ResBlock(channels)

        self.res_norm = nn.InstanceNorm2d(channels)
示例#3
0
文件: resnet_30.py 项目: pianpwk/drif
class ResNet30(torch.nn.Module):
    """Residual encoder: a stem conv, twelve residual blocks and a final norm.

    Layout: stem conv (3 -> ``channels``, 3x3 kernel, stride 1), then three
    groups of three plain ``ResBlock``s, each group followed by a
    ``ResBlockStrided`` with stride 2, and finally an ``InstanceNorm2d``.

    Args:
        channels: Number of feature channels used by every layer.
        down_pad: When true, the stem conv and the strided blocks use
            padding 1 instead of 0.
    """

    def __init__(self, channels, down_pad=False):
        super().__init__()

        down_padding = 1 if down_pad else 0

        # inchannels, outchannels, kernel size
        self.conv1 = nn.Conv2d(3, channels, 3, stride=1, padding=down_padding)

        # Assignment order is kept as-is: it fixes module registration order
        # (state_dict keys / parameters()) and default-init RNG draw order.
        self.block1 = ResBlock(channels)
        self.block2 = ResBlock(channels)
        self.block3 = ResBlock(channels)
        self.block4 = ResBlockStrided(channels,
                                      stride=2,
                                      down_padding=down_padding)

        self.block5 = ResBlock(channels)
        self.block6 = ResBlock(channels)
        self.block7 = ResBlock(channels)
        self.block8 = ResBlockStrided(channels,
                                      stride=2,
                                      down_padding=down_padding)

        self.block9 = ResBlock(channels)
        self.block10 = ResBlock(channels)
        self.block11 = ResBlock(channels)
        self.block12 = ResBlockStrided(channels,
                                       stride=2,
                                       down_padding=down_padding)

        self.res_norm = nn.InstanceNorm2d(channels)

    def _blocks(self):
        """Return the twelve residual blocks in forward/initialisation order."""
        return (
            self.block1, self.block2, self.block3, self.block4,
            self.block5, self.block6, self.block7, self.block8,
            self.block9, self.block10, self.block11, self.block12,
        )

    def init_weights(self):
        """Re-initialise every residual block (block1..block12), then the stem.

        The stem weight gets Kaiming-uniform initialisation and its bias is
        zeroed.
        """
        for block in self._blocks():
            block.init_weights()

        # Use the in-place `kaiming_uniform_`; the bare `kaiming_uniform`
        # alias is deprecated in PyTorch and emits a warning on every call.
        torch.nn.init.kaiming_uniform_(self.conv1.weight)
        self.conv1.bias.data.fill_(0)

    def forward(self, input):
        """Apply the stem, the twelve blocks in order, then instance norm."""
        x = self.conv1(input)
        for block in self._blocks():
            x = block(x)
        return self.res_norm(x)
示例#4
0
    def __init__(self, embed_size, channels, c_out):
        super(ResNetConditional, self).__init__()

        self.block1 = ResBlock(channels)  # RF: 5x5
        self.block1a = ResBlock(channels)  # RF: 9x9
        self.cblock1 = ResBlockConditional(embed_size, channels)  # RF: 9x9
        self.block2 = ResBlock(channels)  # RF: 13x13
        self.block2a = ResBlock(channels)  # RF: 17x17
        self.cblock2 = ResBlockConditional(embed_size, channels)  # RF: 17x17
        self.block3 = ResBlock(channels)  # RF: 21x21
        self.block3a = ResBlock(channels)  # RF: 25x25
        self.cblock3 = ResBlockConditional(embed_size, channels)  # RF: 25x25
        self.block4 = ResBlock(channels)  # RF: 29x29
        self.block4a = ResBlock(channels)  # RF: 33x33
        self.cblock4 = ResBlockConditional(embed_size, channels)  # RF: 33x33
        self.block5 = ResBlock(channels)  # RF: 37x37
        self.block5a = ResBlock(channels)  # RF: 41x41
        self.cblock5 = ResBlockConditional(embed_size, channels)  # RF: 41x41
        self.block6 = ResBlock(channels)  # RF: 45x45
        self.block6a = ResBlock(channels)  # RF: 49x49
        self.cblock6 = ResBlockConditional(embed_size, channels)  # RF: 49x49
        self.block7 = ResBlock(channels)  # RF: 53x53
        self.block7a = ResBlock(channels)  # RF: 57x57
        self.cblock7 = ResBlockConditional(embed_size, channels)  # RF: 57x57
        self.block8 = ResBlock(channels)  # RF: 61x61
        self.block8a = ResBlock(channels)  # RF: 65x65
        self.cblock8 = ResBlockConditional(embed_size, channels,
                                           c_out)  # RF: 65x65