class ResNet13Light(nn.Module):
    """Lightweight residual encoder with a nominal spatial downscale of 8.

    Layout: a stride-2 3x3 stem convolution (3 -> ``channels``), then three
    stages, each a ``ResBlockStrided`` followed by a plain ``ResBlock``
    (strided-block strides 2, 2 and 1), then ``InstanceNorm2d``.

    Args:
        channels: Number of feature channels used throughout the network.
        down_pad: If truthy, the stem and strided convolutions use padding 1;
            otherwise they use padding 0.
    """

    def __init__(self, channels, down_pad=False):
        super().__init__()
        pad = 1 if down_pad else 0

        # Stem: 3 input channels -> `channels`, 3x3 kernel, stride 2.
        self.conv1 = nn.Conv2d(3, channels, 3, stride=2, padding=pad)
        # Stage 1.
        self.block1 = ResBlockStrided(channels, stride=2, down_padding=pad)
        self.block15 = ResBlock(channels)
        # Stage 2.
        self.block2 = ResBlockStrided(channels, stride=2, down_padding=pad)
        self.block25 = ResBlock(channels)
        # Stage 3 (stride 1 in its strided block).
        self.block3 = ResBlockStrided(channels, stride=1, down_padding=pad)
        self.block35 = ResBlock(channels)
        self.res_norm = nn.InstanceNorm2d(channels)

    def get_downscale_factor(self):
        """Return the nominal spatial downscale factor of ``forward``."""
        # 2 (stem) * 2 * 2 * 1 (stages), assuming ResBlockStrided
        # downscales by its stride.
        return 8

    def init_weights(self):
        """Initialise every residual block, then the stem convolution.

        The stem weight gets Kaiming-uniform initialisation and its bias is
        zeroed. The call order is fixed; it matters for reproducibility if
        the block initialisers draw from the global RNG.
        """
        for blk in (self.block1, self.block2, self.block3,
                    self.block15, self.block25, self.block35):
            blk.init_weights()
        torch.nn.init.kaiming_uniform_(self.conv1.weight)
        self.conv1.bias.data.fill_(0)

    def forward(self, input):
        """Run the stem and all six residual blocks, then instance-normalise."""
        out = self.conv1(input)
        for blk in (self.block1, self.block15,
                    self.block2, self.block25,
                    self.block3, self.block35):
            out = blk(out)
        return self.res_norm(out)
def __init__(self, channels, down_pad=False):
    """Build the stem convolution, three residual stages and output norm.

    Args:
        channels: Number of feature channels used throughout the network.
        down_pad: If truthy, the stem and strided convolutions use padding 1;
            otherwise they use padding 0.
    """
    super(ResNet13S, self).__init__()
    pad = 1 if down_pad else 0

    # Stem: 3 input channels -> `channels`, 3x3 kernel, stride 1.
    self.conv1 = nn.Conv2d(3, channels, 3, stride=1, padding=pad)

    # Each stage is a strided block "block<k>" followed by a plain block
    # "block<k>5". Modules are created and registered in the order
    # block1, block15, block2, block25, block3, block35.
    for stage, stride in ((1, 2), (2, 2), (3, 1)):
        setattr(self, "block%d" % stage,
                ResBlockStrided(channels, stride=stride, down_padding=pad))
        setattr(self, "block%d5" % stage, ResBlock(channels))

    self.res_norm = nn.InstanceNorm2d(channels)
class ResNet30(torch.nn.Module):
    """Residual encoder: a stem convolution followed by twelve residual blocks.

    The stem is a stride-1 3x3 convolution from 3 to ``channels`` channels.
    Blocks 4, 8 and 12 are ``ResBlockStrided`` with stride 2; all others are
    plain ``ResBlock``s. The output is passed through ``InstanceNorm2d``.

    Args:
        channels: Number of feature channels used throughout the network.
        down_pad: If truthy, the stem and strided convolutions use padding 1;
            otherwise they use padding 0.
    """

    def __init__(self, channels, down_pad=False):
        super(ResNet30, self).__init__()
        down_padding = 1 if down_pad else 0

        # Stem: inchannels=3, outchannels=channels, kernel size 3, stride 1.
        self.conv1 = nn.Conv2d(3, channels, 3, stride=1, padding=down_padding)
        self.block1 = ResBlock(channels)
        self.block2 = ResBlock(channels)
        self.block3 = ResBlock(channels)
        self.block4 = ResBlockStrided(channels, stride=2, down_padding=down_padding)
        self.block5 = ResBlock(channels)
        self.block6 = ResBlock(channels)
        self.block7 = ResBlock(channels)
        self.block8 = ResBlockStrided(channels, stride=2, down_padding=down_padding)
        self.block9 = ResBlock(channels)
        self.block10 = ResBlock(channels)
        self.block11 = ResBlock(channels)
        self.block12 = ResBlockStrided(channels, stride=2, down_padding=down_padding)
        self.res_norm = nn.InstanceNorm2d(channels)

    def get_downscale_factor(self):
        """Return the nominal spatial downscale factor of ``forward``.

        Stem stride 1 times the three stride-2 strided blocks (4, 8, 12):
        1 * 2 * 2 * 2 = 8, assuming ResBlockStrided downscales by its stride.
        """
        return 8

    def _blocks(self):
        """Return the twelve residual blocks in forward (and init) order."""
        return (self.block1, self.block2, self.block3, self.block4,
                self.block5, self.block6, self.block7, self.block8,
                self.block9, self.block10, self.block11, self.block12)

    def init_weights(self):
        """Initialise every residual block, then the stem convolution.

        The stem weight gets Kaiming-uniform initialisation and its bias is
        zeroed.
        """
        for block in self._blocks():
            block.init_weights()
        # In-place `kaiming_uniform_`: the un-suffixed `kaiming_uniform`
        # alias is deprecated and emits a warning on every call.
        torch.nn.init.kaiming_uniform_(self.conv1.weight)
        self.conv1.bias.data.fill_(0)

    def forward(self, input):
        """Run the stem and all twelve residual blocks, then instance-normalise."""
        x = self.conv1(input)
        for block in self._blocks():
            x = block(x)
        return self.res_norm(x)
def __init__(self, embed_size, channels, c_out):
    """Build eight stages, each two ResBlocks followed by a conditional block.

    Attributes are named ``block<i>``, ``block<i>a`` and ``cblock<i>`` for
    ``i`` in 1..8. They are created and registered in that order, stage by
    stage.

    Args:
        embed_size: Passed as the first argument to every ResBlockConditional.
        channels: Number of feature channels for every block.
        c_out: Passed only to the last conditional block (``cblock8``);
            presumably its output channel count - TODO confirm.
    """
    super(ResNetConditional, self).__init__()
    num_stages = 8
    for i in range(1, num_stages + 1):
        # Receptive-field annotations: block<i> is (8i-3)x(8i-3); block<i>a
        # and cblock<i> are (8i+1)x(8i+1). For example, 5x5/9x9/9x9 at i=1
        # and 61x61/65x65/65x65 at i=8.
        setattr(self, "block%d" % i, ResBlock(channels))
        setattr(self, "block%da" % i, ResBlock(channels))
        if i == num_stages:
            cblock = ResBlockConditional(embed_size, channels, c_out)
        else:
            cblock = ResBlockConditional(embed_size, channels)
        setattr(self, "cblock%d" % i, cblock)