class UNet(nn.Module):
    def __init__(self, in_channels: int = 1, base_channel_size: int = 64, bilinear=True, depth=4):
        super(UNet, self).__init__()
        self.name = 'UNet'
        self.n_channels = in_channels
        self.base_channel_size = base_channel_size
        self.bilinear = bilinear
        self.depth = depth
        self.inc = DoubleConv(in_channels, base_channel_size)
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        down_channel = base_channel_size
        factor = 2 if bilinear else 1
        # go down:
        # 64 -> 128 -> 256 -> 512 -> 1024
        for i in range(1, self.depth):
            self.downs.append(Down(down_channel, down_channel * 2))
            down_channel *= 2
        self.downs.append(Down(down_channel, down_channel * 2 // factor))
        for i in range(1, self.depth):
            self.ups.append(Up(down_channel * 2, down_channel // factor, bilinear))
            down_channel = down_channel // 2
        self.ups.append(Up(down_channel * 2, base_channel_size, bilinear))
        self.pad_to = PadToX(32)

    def forward(self, x):
        diffX, diffY, x = self.pad_to(x)
        x = self.inc(x)
        intermediates = []
        for layer in self.downs:
            intermediates.append(x)
            x = layer(x)
        for layer, intermediate in zip(self.ups, intermediates[::-1]):
            x = layer(x, intermediate)
        x = self.pad_to.remove_pad(x, diffX, diffY)
        return x

    def initialize(self):
        def init_layer(layer):
            if isinstance(layer, nn.Conv2d):
                nn.init.xavier_normal_(layer.weight)
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, val=0)
            elif isinstance(layer, nn.BatchNorm2d):
                nn.init.constant_(layer.weight, 1)
                nn.init.constant_(layer.bias, 0)

        self.apply(init_layer)
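A minimal smoke test, not part of the original source, sketched under the assumption that DoubleConv, Down, Up, and PadToX behave as they are used above: a single-channel input whose sides are not multiples of 32 should come back at its original resolution with base_channel_size channels.

import torch

# Hypothetical check: pad_to pads to a multiple of 32 before the encoder and
# remove_pad crops afterwards, so the spatial size should round-trip unchanged.
net = UNet(in_channels=1, base_channel_size=64, depth=4)
net.initialize()
with torch.no_grad():
    out = net(torch.randn(2, 1, 250, 333))  # 250x333 is not a multiple of 32
assert out.shape == (2, 64, 250, 333)  # base_channel_size feature map, original size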
class Model(nn.Module):
    def __init__(self, backbone_name='ResUNet50_bc64', head_type='regression'):
        super(Model, self).__init__()
        self.name = f'ColorModel-{backbone_name}'
        self.backbone_name = backbone_name
        self.backbone = getattr(backbones_mod, self.backbone_name)()
        self.head_type = head_type

        def make_head(only_activation):
            if self.head_type.startswith('regression'):
                if only_activation:
                    return TanHActivation()
                return OutConv(self.backbone.base_channel_size, 2)

        self.head = make_head(only_activation=isinstance(self.backbone, SMPModel))
        self.up = None
        self.pad_to = PadToX(32)

    def forward(self, x):
        # pad to multiples of 32
        diffX, diffY, x = self.pad_to(x)
        h, w = x.size()[2:]
        x = self.backbone(x)
        x = self.head(x)
        h_post, w_post = x.size()[2:]
        if h_post < h or w_post < w:
            if not self.up:
                scale_factor = h // h_post
                self.up = nn.Upsample(scale_factor=scale_factor, mode='bicubic', align_corners=True)
            x = self.up(x)
        x = self.pad_to.remove_pad(x, diffX, diffY)
        return x

    def initialize(self):
        print('Init backbone')
        self.backbone.initialize()
        # Initialize head
        print('Init head')
        self.head.initialize()

    def __repr__(self):
        return '\n'.join([
            f' model: {self.name}',
            f' backbone: {self.backbone_name}',
            f' head: {self.head_type}',
        ])

    def save(self, state, iteration):
        checkpoint = {
            'backbone_name': self.backbone_name,
            'head_type': self.head_type,
            'state_dict': self.state_dict()
        }
        for key in ('epoch', 'optimizer', 'scheduler', 'iteration', 'scaler', 'sampler'):
            if key in state:
                checkpoint[key] = state[key]
        # get real concrete save path:
        concrete_path = build_model_file_name(state['path'], iteration)
        assert not os.path.isfile(concrete_path)
        torch.save(checkpoint, concrete_path)

    @classmethod
    def load(cls, filename):
        if not os.path.isfile(filename):
            raise ValueError('No checkpoint {}'.format(filename))
        checkpoint = torch.load(filename, map_location=lambda storage, loc: storage)
        # Recreate model from checkpoint instead of from individual backbones
        model = cls(backbone_name=checkpoint['backbone_name'], head_type=checkpoint['head_type'])
        model.load_state_dict(checkpoint['state_dict'])
        state = {}
        for key in ('epoch', 'optimizer', 'scheduler', 'iteration', 'scaler', 'sampler'):
            if key in checkpoint:
                state[key] = checkpoint[key]
        state['path'] = filename
        del checkpoint
        torch.cuda.empty_cache()
        return model, state
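A hedged usage sketch rather than code from the repository: it assumes the default 'ResUNet50_bc64' constructor is exported by backbones_mod and that the regression head's OutConv(base_channel_size, 2) predicts two channels (presumably the chroma components) for a grayscale input.

import torch

model = Model(backbone_name='ResUNet50_bc64', head_type='regression')
model.initialize()
model.eval()
with torch.no_grad():
    luminance = torch.randn(1, 1, 200, 300)  # grayscale / L-channel input
    chroma = model(luminance)                # 2-channel regression output
# forward() pads to a multiple of 32, upsamples if the backbone output is smaller
# than its input, and crops the padding again, so the size should be preserved.
assert chroma.shape == (1, 2, 200, 300)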
class ResnetPixShuffle(nn.Module):
    def __init__(self, features, bilinear: bool = True):
        super(ResnetPixShuffle, self).__init__()
        self.features = features
        self.name = 'ResnetPixShuffle'
        is_light = self.features.bottleneck == vrn.BasicBlock
        channels = [64, 64, 128, 256, 512] if is_light else [64, 256, 512, 1024, 2048]
        self.base_channel_size = channels[0]

        self.smooth5 = nn.Sequential(
            RDB(n_channels=channels[4], nDenselayer=3, growthRate=channels[4] // 4),
            RDB(n_channels=channels[4], nDenselayer=3, growthRate=channels[4] // 4))
        self.smooth4 = nn.Conv2d(channels[3], channels[3], kernel_size=1)
        self.smooth3 = nn.Conv2d(channels[2], channels[2], kernel_size=1)
        self.smooth2 = nn.Conv2d(channels[1], channels[1], kernel_size=1)
        self.smooth1 = nn.Conv2d(channels[0], channels[0], kernel_size=1)

        # up + skip, out
        self.up1 = RDBPixShuffle(channels[4], channels[3])  # 2048, 1024 -> 512
        nChannels = channels[4] // 4 + channels[3]  # 1536
        self.up2 = RDBPixShuffle(nChannels, channels[2])  # 1536 + 512 -> 896
        nChannels = nChannels // 4 + channels[2]
        self.up3 = RDBPixShuffle(nChannels, channels[1])  # 896 + 256 -> 480
        nChannels = nChannels // 4 + channels[1]
        self.up4 = RDBPixShuffle(nChannels, channels[0])  # 480 + 64 -> 184
        nChannels = nChannels // 4 + channels[0]
        self.last_up = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(nChannels, channels[0], kernel_size=3, padding=1))
        self.pad_to = PadToX(32)

    def forward(self, x):
        # need 3 channels for the pretrained resnet
        diffX, diffY, x = self.pad_to(x)
        x = x.repeat(1, 3, 1, 1)
        c1, c2, c3, c4, c5 = self.features(x)
        c1 = self.smooth1(c1)
        c2 = self.smooth2(c2)
        c3 = self.smooth3(c3)
        c4 = self.smooth4(c4)
        c5 = self.smooth5(c5)
        x = self.up1(c5, c4)
        x = self.up2(x, c3)
        x = self.up3(x, c2)
        x = self.up4(x, c1)
        x = self.last_up(x)
        x = self.pad_to.remove_pad(x, diffX, diffY)
        return x

    def initialize(self):
        def init_layer(layer):
            if isinstance(layer, nn.Conv2d):
                nn.init.xavier_uniform_(layer.weight)
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, val=0)

        self.apply(init_layer)
        self.features.initialize()
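The channel arithmetic behind the up1..up4 comments, written out for the deep (ResNet-50-style) case; it assumes RDBPixShuffle pixel-shuffles its first input (dividing its channels by 4) and then concatenates the skip connection, which is what the nChannels bookkeeping above implies.

# Worked channel counts matching the inline comments above (1536, 896, 480, 184).
channels = [64, 256, 512, 1024, 2048]
n = channels[4]
for skip in (channels[3], channels[2], channels[1], channels[0]):
    n = n // 4 + skip   # pixel shuffle (C/4), then concat with the skip features
    print(n)            # 1536, 896, 480, 184
# last_up then maps these 184 channels back to channels[0] == 64 with a 3x3 conv.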
class DenseNetUNetRDB(nn.Module):
    def __init__(self, features):
        super(DenseNetUNetRDB, self).__init__()
        self.features = features
        self.name = 'DenseUnetRDB'
        channels = [64, 256, 512, 1024, 1024]
        self.base_channel_size = channels[0]

        self.smooth5 = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(channels[4], channels[4], kernel_size=1),
            nn.ReLU(),
            RDB(n_channels=channels[4], nDenselayer=3, growthRate=channels[4] // 4))
        self.smooth4 = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(channels[3], channels[3], kernel_size=1),
            nn.ReLU())
        self.smooth3 = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(channels[2], channels[2], kernel_size=1),
            nn.ReLU())
        self.smooth2 = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(channels[1], channels[1], kernel_size=1),
            nn.ReLU())
        self.smooth1 = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(channels[0], channels[0], kernel_size=1),
            nn.ReLU())

        # up + skip, out
        self.up1 = RDBUp(channels[4], channels[3], channels[3])  # 1024 + 1024, 1024
        self.up2 = RDBUp(channels[3], channels[2], channels[2])  # 1024 + 512, 512
        self.up3 = RDBUp(channels[2], channels[1], channels[1])  # 512 + 256, 256
        self.up4 = RDBUp(channels[1], channels[0], channels[0])  # 256 + 64, 64
        self.last_up = UpPad()
        self.conv_3x3 = nn.Conv2d(channels[0], channels[0], kernel_size=3, padding=1)
        self.pad_to = PadToX(32)

    def forward(self, x):
        # pad to multiples of 32
        diffX, diffY, x = self.pad_to(x)
        orig = x
        # need 3 channels for the pretrained densenet encoder
        x = x.repeat(1, 3, 1, 1)
        c1, c2, c3, c4, c5 = self.features(x)
        c1 = self.smooth1(c1)
        c2 = self.smooth2(c2)
        c3 = self.smooth3(c3)
        c4 = self.smooth4(c4)
        c5 = self.smooth5(c5)
        x = self.up1(c5, c4)
        x = self.up2(x, c3)
        x = self.up3(x, c2)
        x = self.up4(x, c1)
        # need the orig just for the size
        x = self.last_up(x, orig)
        x = self.conv_3x3(x)
        x = self.pad_to.remove_pad(x, diffX, diffY)
        return x

    def initialize(self):
        def init_layer(layer):
            if isinstance(layer, nn.Conv2d):
                nn.init.xavier_uniform_(layer.weight)
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, val=0)

        self.apply(init_layer)
        self.features.initialize()
class DualDenseNetUNetRDB(nn.Module):
    def __init__(self, features, global_features, v2=False):
        super(DualDenseNetUNetRDB, self).__init__()
        self.features = features
        self.global_features = global_features
        self.name = 'DualDenseUnetRDB'
        self.v2 = v2
        channels = [64, 256, 512, 1024, 1024]
        self.base_channel_size = channels[0]
        if not v2:
            self.smooth5 = nn.Sequential(
                nn.ReLU(),
                nn.Conv2d(channels[4] * 2, channels[4], kernel_size=1),
                nn.ReLU(),
                RDB(n_channels=channels[4], nDenselayer=3, growthRate=channels[4] // 4))
            self.smooth4 = nn.Sequential(
                nn.ReLU(),
                nn.Conv2d(channels[3], channels[3], kernel_size=1),
                nn.ReLU())
            self.smooth3 = nn.Sequential(
                nn.ReLU(),
                nn.Conv2d(channels[2], channels[2], kernel_size=1),
                nn.ReLU())
            self.smooth2 = nn.Sequential(
                nn.ReLU(),
                nn.Conv2d(channels[1], channels[1], kernel_size=1),
                nn.ReLU())
            self.smooth1 = nn.Sequential(
                nn.ReLU(),
                nn.Conv2d(channels[0], channels[0], kernel_size=1),
                nn.ReLU())
            self.up1 = RDBUp(channels[4], channels[3], channels[3])  # 1024 + 1024, 1024
            self.up2 = RDBUp(channels[3], channels[2], channels[2])  # 1024 + 512, 512
            self.up3 = RDBUp(channels[2], channels[1], channels[1])  # 512 + 256, 256
            self.up4 = RDBUp(channels[1], channels[0], channels[0])  # 256 + 64, 64
            self.last_up = UpPad()
            self.conv_3x3 = nn.Conv2d(channels[0], channels[0], kernel_size=3, padding=1)
        else:
            self.down5 = nn.Sequential(
                _Transition(channels[4] * 2, channels[4]),
                _DenseBlock(num_layers=16, num_input_features=channels[4], bn_size=4,
                            growth_rate=32, drop_rate=0, memory_efficient=False))
            self.up0 = ConvUp(channels[4] + 512, channels[4], channels[4])
            self.smooth4 = nn.Sequential(nn.BatchNorm2d(channels[3]), nn.LeakyReLU())
            self.smooth3 = nn.Sequential(nn.BatchNorm2d(channels[2]), nn.LeakyReLU())
            self.smooth2 = nn.Sequential(nn.BatchNorm2d(channels[1]), nn.LeakyReLU())
            self.smooth1 = nn.Sequential(nn.BatchNorm2d(channels[0]), nn.LeakyReLU())
            # up + skip, out
            self.up1 = ConvUp(channels[4], channels[3], channels[3])  # 1024 + 1024, 1024
            self.up2 = ConvUp(channels[3], channels[2], channels[2])  # 1024 + 512, 512
            self.up3 = ConvUp(channels[2], channels[1], channels[1])  # 512 + 256, 256
            self.up4 = ConvUp(channels[1], channels[0], channels[0])  # 256 + 64, 64
        self.pad_to = PadToX(32)
class UNetDualEncoder(nn.Module):
    def __init__(self, in_channels: int = 1, base_channel_size: int = 64, bilinear=True, depth=4,
                 second_encoder='vgg16'):
        super(UNetDualEncoder, self).__init__()
        self.name = 'UNet'
        self.n_channels = in_channels
        self.base_channel_size = base_channel_size
        self.bilinear = bilinear
        self.depth = depth
        self.second_encoder_name = second_encoder
        self.second_encoder = VGG(make_layers(cfgs['D'], False), init_weights=False)
        self.inc = DoubleConv(in_channels, base_channel_size)
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.fuse = nn.Sequential(
            nn.Conv2d(512 + 1024, 1024, kernel_size=1),
            nn.ReLU(inplace=True),
            DoubleConv(1024, 1024))
        self.pad_to = PadToX(32)
        down_channel = base_channel_size
        factor = 2 if bilinear else 1
        # go down:
        # 64 -> 128 -> 256 -> 512 -> 1024
        for i in range(1, self.depth):
            self.downs.append(Down(down_channel, down_channel * 2))
            down_channel *= 2
        self.downs.append(Down(down_channel, down_channel * 2 // factor))
        for i in range(1, self.depth):
            self.ups.append(Up(down_channel * 2, down_channel // factor, bilinear))
            down_channel = down_channel // 2
        self.ups.append(Up(down_channel * 2, base_channel_size, bilinear))

    def forward(self, x):
        diffX, diffY, x = self.pad_to(x)
        x_exp = x.repeat(1, 3, 1, 1)
        extra_features = self.second_encoder(x_exp)
        x = self.inc(x)
        intermediates = []
        for layer in self.downs:
            intermediates.append(x)
            x = layer(x)
        x = torch.cat([x, extra_features], 1)
        x = self.fuse(x)
        for layer, intermediate in zip(self.ups, intermediates[::-1]):
            x = layer(x, intermediate)
        x = self.pad_to.remove_pad(x, diffX, diffY)
        return x

    def initialize(self):
        def init_layer(layer):
            if isinstance(layer, nn.Conv2d):
                nn.init.xavier_uniform_(layer.weight, gain=nn.init.calculate_gain('relu'))
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, val=0)
            elif isinstance(layer, nn.BatchNorm2d):
                nn.init.constant_(layer.weight, 1)
                nn.init.constant_(layer.bias, 0)

        self.apply(init_layer)

        state_dict = load_state_dict_from_url(model_urls[self.second_encoder_name], progress=True)
        # https://discuss.pytorch.org/t/how-to-load-part-of-pre-trained-model/1113/3
        model_dict = self.second_encoder.state_dict()
        # 1. filter out unnecessary keys
        state_dict = OrderedDict({k: v for k, v in state_dict.items() if k in model_dict})
        # 2. overwrite entries in the existing state dict
        model_dict.update(state_dict)
        # 3. load the new (updated) state dict
        self.second_encoder.load_state_dict(model_dict)
        # Freeze the parameters of the second encoder
        for param in self.second_encoder.parameters():
            param.requires_grad = False
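An illustrative check, assumed rather than taken from the repository, of what initialize() is meant to guarantee: the torchvision VGG16 weights are copied into second_encoder and that encoder is then frozen, leaving only the UNet branch trainable.

# Sketch only: relies on the torchvision VGG helpers used above
# (make_layers, cfgs, model_urls, load_state_dict_from_url) being importable.
net = UNetDualEncoder(second_encoder='vgg16')
net.initialize()  # downloads VGG16 weights, then sets requires_grad = False
assert all(not p.requires_grad for p in net.second_encoder.parameters())
trainable = [name for name, p in net.named_parameters() if p.requires_grad]
assert not any(name.startswith('second_encoder.') for name in trainable)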
class ResUNet(nn.Module):
    def __init__(self, features, bilinear: bool = True, v2=False):
        super(ResUNet, self).__init__()
        self.features = features
        self.name = 'ResUNet'
        self.v2 = v2
        is_light = self.features.bottleneck == vrn.BasicBlock
        channels = [64, 64, 128, 256, 512] if is_light else [64, 256, 512, 1024, 2048]
        self.base_channel_size = channels[0]
        factor = 2 if bilinear else 1

        self.top = nn.Sequential(
            nn.Conv2d(1, channels[0], 3, padding=1, bias=False),
            nn.BatchNorm2d(channels[0]),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels[0], channels[0], 3, padding=1, bias=False),
            nn.BatchNorm2d(channels[0]),
            nn.ReLU(inplace=True),
        )
        if v2:
            self.smooth4 = nn.Conv2d(channels[3], channels[3], kernel_size=1)
            self.smooth3 = nn.Conv2d(channels[2], channels[2], kernel_size=1)
            self.smooth2 = nn.Conv2d(channels[1], channels[1], kernel_size=1)
            self.smooth1 = nn.Conv2d(channels[0], channels[0], kernel_size=1)
            self.bottom = nn.Sequential(
                Bottleneck(channels[4], channels[4] // 4),
                Bottleneck(channels[4], channels[4] // 4),
                Bottleneck(channels[4], channels[4] // 4)
            )
        else:
            self.bottom = nn.Sequential(
                nn.Conv2d(channels[4], channels[4], 1),
                nn.ReLU(inplace=True),
            )
        if v2:
            Up = Upv2
        else:
            Up = Upv1
        # up + skip, out
        self.up1 = Up(channels[4] + channels[3], channels[4] // factor, bilinear)  # 2048 + 1024, 1024
        self.up2 = Up(channels[3] + channels[2], channels[3] // factor, bilinear)  # 1024 + 512, 512
        self.up3 = Up(channels[2] + channels[1], channels[2] // factor, bilinear)  # 512 + 256, 256
        self.up4 = Up(channels[1] + channels[0], channels[1] // factor, bilinear)  # 256 + 64, 128
        self.up5 = Up(channels[1] // factor + channels[0], channels[0], bilinear)  # 256 + 64, 64
        self.last_up = nn.Sequential(
            nn.Conv2d(channels[1] // factor, channels[0], kernel_size=1),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        )
        self.pad_to = PadToX(32)

    def forward(self, x):
        # need 3 channels for the pretrained resnet
        # top = self.top(x)
        diffX, diffY, x = self.pad_to(x)
        x = x.repeat(1, 3, 1, 1)
        c1, c2, c3, c4, c5 = self.features(x)
        if self.v2:
            c1 = self.smooth1(c1)
            c2 = self.smooth2(c2)
            c3 = self.smooth3(c3)
            c4 = self.smooth4(c4)
        bottom = self.bottom(c5)
        x = self.up1(bottom, c4)
        x = self.up2(x, c3)
        x = self.up3(x, c2)
        x = self.up4(x, c1)
        # x = self.up5(x, top)
        x = self.last_up(x)
        x = self.pad_to.remove_pad(x, diffX, diffY)
        return x

    def initialize(self):
        def init_layer(layer):
            if isinstance(layer, nn.Conv2d):
                nn.init.xavier_uniform_(layer.weight)
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, val=0)

        self.apply(init_layer)
        self.features.initialize()