Esempio n. 1
0
class UNet(nn.Module):
    """Configurable-depth UNet assembled from ``DoubleConv`` / ``Down`` / ``Up`` blocks.

    Args:
        in_channels: channels of the input image.
        base_channel_size: width of the stem ``DoubleConv``; each ``Down`` stage
            doubles it.
        bilinear: forwarded to every ``Up``. When True the last ``Down`` output and
            the decoder widths are divided by 2 (``halve == 2``).
        depth: number of ``Down`` stages (and of ``Up`` stages).

    The input is padded with ``PadToX(32)`` before the network runs and the padding
    is stripped from the output afterwards.
    """

    def __init__(self, in_channels: int = 1, base_channel_size: int = 64, bilinear=True, depth=4):
        super(UNet, self).__init__()
        self.name = 'UNet'
        self.n_channels = in_channels
        self.base_channel_size = base_channel_size
        self.bilinear = bilinear
        self.depth = depth

        halve = 2 if bilinear else 1
        self.inc = DoubleConv(in_channels, base_channel_size)

        # Encoder: the width doubles at every stage; the final stage's output is
        # divided by `halve` (64 -> 128 -> 256 -> 512 -> 1024 // halve for the defaults).
        encoder_blocks = []
        width = base_channel_size
        for _ in range(depth - 1):
            encoder_blocks.append(Down(width, width * 2))
            width *= 2
        encoder_blocks.append(Down(width, width * 2 // halve))

        # Decoder: walk the widths back down; the last Up returns to base width.
        decoder_blocks = []
        for _ in range(depth - 1):
            decoder_blocks.append(Up(width * 2, width // halve, bilinear))
            width //= 2
        decoder_blocks.append(Up(width * 2, base_channel_size, bilinear))

        self.downs = nn.ModuleList(encoder_blocks)
        self.ups = nn.ModuleList(decoder_blocks)
        self.pad_to = PadToX(32)

    def forward(self, x):
        """Pad, encode while collecting skips, decode with the skips reversed, un-pad."""
        diffX, diffY, padded = self.pad_to(x)
        out = self.inc(padded)

        # Each Down's *input* is kept as the skip for the matching Up.
        skips = []
        for down in self.downs:
            skips.append(out)
            out = down(out)

        for up, skip in zip(self.ups, reversed(skips)):
            out = up(out, skip)

        return self.pad_to.remove_pad(out, diffX, diffY)

    def initialize(self):
        """Re-initialize layers in place.

        ``Conv2d``: Xavier-normal weights, zero bias (if present).
        ``BatchNorm2d``: weight 1, bias 0.
        """
        def reset(module):
            if isinstance(module, nn.Conv2d):
                nn.init.xavier_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, val=0)
                return
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

        self.apply(reset)
Esempio n. 2
0
class Model(nn.Module):
    """Colorization model: a named backbone followed by a prediction head.

    The backbone class is looked up by name on ``backbones_mod``. Only heads whose
    ``head_type`` starts with ``'regression'`` are supported. The input is padded to
    a multiple of 32 before the backbone and un-padded afterwards.

    Args:
        backbone_name: attribute name in ``backbones_mod`` of the backbone factory.
        head_type: head selector; must start with ``'regression'``.

    Raises:
        ValueError: if ``head_type`` is not a supported head.
    """

    def __init__(self, backbone_name='ResUNet50_bc64', head_type='regression'):
        super(Model, self).__init__()

        self.name = f'ColorModel-{backbone_name}'
        self.backbone_name = backbone_name
        self.backbone = getattr(backbones_mod, self.backbone_name)()
        self.head_type = head_type

        def make_head(only_activation):
            # Build the prediction head. Fail loudly on an unknown head type instead
            # of returning None, which would only surface later as a crash in forward.
            if self.head_type.startswith('regression'):
                if only_activation:
                    return TanHActivation()
                return OutConv(self.backbone.base_channel_size, 2)
            raise ValueError(
                f"Unsupported head_type {self.head_type!r}; expected one starting with 'regression'")

        # SMP backbones get only the activation; presumably they already emit the
        # output channels themselves — NOTE(review): confirm against SMPModel.
        self.head = make_head(
            only_activation=isinstance(self.backbone, SMPModel))
        # Upsampler created lazily in forward when the head output is smaller than the input.
        self.up = None
        self.pad_to = PadToX(32)

    def forward(self, x):
        """Run backbone + head, upsampling the result if it came out smaller than the input."""
        # pad to multiples of 32
        diffX, diffY, x, = self.pad_to(x)
        h, w = x.size()[2:]

        x = self.backbone(x)
        x = self.head(x)

        h_post, w_post = x.size()[2:]
        if h_post < h or w_post < w:
            # Built once, on the first call that needs it; the integer factor comes
            # from the height ratio of that call and is reused afterwards.
            if self.up is None:
                scale_factor = h // h_post
                self.up = nn.Upsample(scale_factor=scale_factor,
                                      mode='bicubic',
                                      align_corners=True)
            x = self.up(x)

        x = self.pad_to.remove_pad(x, diffX, diffY)
        return x

    def initialize(self):
        """Initialize the backbone, then the head, via their own ``initialize()``."""
        print('Init backbone')
        self.backbone.initialize()

        # Initialize head
        print('Init head')
        self.head.initialize()

    def __repr__(self):
        return '\n'.join([
            f'     model: {self.name}',
            f'  backbone: {self.backbone_name}',
            f'      head: {self.head_type}',
        ])

    def save(self, state, iteration):
        """Write a checkpoint for ``iteration``.

        Args:
            state: dict that must contain ``'path'``. Optional keys ``epoch``,
                ``optimizer``, ``scheduler``, ``iteration``, ``scaler`` and
                ``sampler`` are copied into the checkpoint when present.
            iteration: forwarded to ``build_model_file_name`` to form the file name.

        Raises:
            FileExistsError: if the target checkpoint file already exists.
        """
        checkpoint = {
            'backbone_name': self.backbone_name,
            'head_type': self.head_type,
            'state_dict': self.state_dict()
        }

        for key in ('epoch', 'optimizer', 'scheduler', 'iteration', 'scaler',
                    'sampler'):
            if key in state:
                checkpoint[key] = state[key]

        # get real concrete save path:
        concrete_path = build_model_file_name(state['path'], iteration)
        # Refuse to overwrite an existing checkpoint. An explicit raise (not `assert`)
        # keeps this guard active under `python -O`.
        if os.path.isfile(concrete_path):
            raise FileExistsError(f'Checkpoint already exists: {concrete_path}')

        torch.save(checkpoint, concrete_path)

    @classmethod
    def load(cls, filename):
        """Rebuild a model from a checkpoint written by :meth:`save`.

        Returns:
            ``(model, state)`` where ``state`` holds whichever training-state keys the
            checkpoint contained plus ``'path'`` set to ``filename``.

        Raises:
            ValueError: if ``filename`` is not an existing file.
        """
        if not os.path.isfile(filename):
            raise ValueError('No checkpoint {}'.format(filename))

        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(filename,
                                map_location=lambda storage, loc: storage)
        # Recreate model from checkpoint instead of from individual backbones
        model = cls(backbone_name=checkpoint['backbone_name'],
                    head_type=checkpoint['head_type'])
        model.load_state_dict(checkpoint['state_dict'])

        state = {}
        for key in ('epoch', 'optimizer', 'scheduler', 'iteration', 'scaler',
                    'sampler'):
            if key in checkpoint:
                state[key] = checkpoint[key]

        state['path'] = filename

        del checkpoint
        torch.cuda.empty_cache()

        return model, state
class ResnetPixShuffle(nn.Module):
    """ResNet-style encoder followed by a pixel-shuffle RDB decoder.

    Args:
        features: encoder module. Its ``bottleneck`` attribute is compared with
            ``vrn.BasicBlock`` to choose the stage widths. Calling it must yield
            five feature maps (shallow to deep), and it must provide ``initialize()``.
        bilinear: accepted for signature compatibility; not used by this class.
    """

    def __init__(self, features, bilinear: bool = True):
        super(ResnetPixShuffle, self).__init__()
        self.features = features
        self.name = 'ResnetPixShuffle'

        # BasicBlock encoders have narrower stages than bottleneck-block encoders.
        if self.features.bottleneck == vrn.BasicBlock:
            channels = [64, 64, 128, 256, 512]
        else:
            channels = [64, 256, 512, 1024, 2048]
        self.base_channel_size = channels[0]
        deepest = channels[4]

        # Two residual-dense blocks refine the deepest feature map.
        self.smooth5 = nn.Sequential(*[
            RDB(n_channels=deepest, nDenselayer=3, growthRate=deepest // 4)
            for _ in range(2)
        ])
        # Width-preserving 1x1 convs on the shallower maps (smooth4 .. smooth1).
        for idx in (3, 2, 1, 0):
            setattr(self, f'smooth{idx + 1}',
                    nn.Conv2d(channels[idx], channels[idx], kernel_size=1))

        def merged_width(width, skip):
            # Channel count the decoder arithmetic assumes after one RDBPixShuffle
            # stage: input width // 4 concatenated with the skip width.
            return width // 4 + skip

        # Decoder stages: (up input, skip) -> merged width. For the wide variant:
        # 2048+1024 -> 1536, 1536+512 -> 896, 896+256 -> 480, 480+64 -> 184.
        self.up1 = RDBPixShuffle(deepest, channels[3])
        width = merged_width(deepest, channels[3])
        self.up2 = RDBPixShuffle(width, channels[2])
        width = merged_width(width, channels[2])
        self.up3 = RDBPixShuffle(width, channels[1])
        width = merged_width(width, channels[1])
        self.up4 = RDBPixShuffle(width, channels[0])
        width = merged_width(width, channels[0])

        self.last_up = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(width, channels[0], kernel_size=3, padding=1),
        )
        self.pad_to = PadToX(32)

    def forward(self, x):
        """Pad to a multiple of 32, encode, decode with skips, then strip the padding."""
        diffX, diffY, padded = self.pad_to(x)
        # The channel dimension is tiled 3x because the pretrained encoder expects
        # 3 input channels (presumably the input is single-channel).
        c1, c2, c3, c4, c5 = self.features(padded.repeat(1, 3, 1, 1))

        s1 = self.smooth1(c1)
        s2 = self.smooth2(c2)
        s3 = self.smooth3(c3)
        s4 = self.smooth4(c4)
        s5 = self.smooth5(c5)

        out = self.up1(s5, s4)
        for up, skip in ((self.up2, s3), (self.up3, s2), (self.up4, s1)):
            out = up(out, skip)
        out = self.last_up(out)

        return self.pad_to.remove_pad(out, diffX, diffY)

    def initialize(self):
        """Xavier-uniform init of every Conv2d (zero bias), then ``features.initialize()``."""
        def reset_conv(module):
            if not isinstance(module, nn.Conv2d):
                return
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, val=0)

        self.apply(reset_conv)
        self.features.initialize()
Esempio n. 4
0
class DenseNetUNetRDB(nn.Module):
    """UNet-style decoder built from ``RDBUp`` blocks on top of an injected encoder.

    Args:
        features: encoder module. Calling it must yield five feature maps (shallow
            to deep) whose widths match the hard-coded ``[64, 256, 512, 1024, 1024]``,
            and it must provide ``initialize()``.
    """

    def __init__(self, features):
        super(DenseNetUNetRDB, self).__init__()
        self.features = features
        self.name = 'DenseUnetRDB'

        widths = [64, 256, 512, 1024, 1024]
        self.base_channel_size = widths[0]

        def relu_conv1x1_relu(ch):
            # ReLU -> width-preserving 1x1 conv -> ReLU
            return [nn.ReLU(), nn.Conv2d(ch, ch, kernel_size=1), nn.ReLU()]

        deepest = widths[4]
        self.smooth5 = nn.Sequential(
            *relu_conv1x1_relu(deepest),
            RDB(n_channels=deepest, nDenselayer=3, growthRate=deepest // 4),
        )
        self.smooth4 = nn.Sequential(*relu_conv1x1_relu(widths[3]))
        self.smooth3 = nn.Sequential(*relu_conv1x1_relu(widths[2]))
        self.smooth2 = nn.Sequential(*relu_conv1x1_relu(widths[1]))
        self.smooth1 = nn.Sequential(*relu_conv1x1_relu(widths[0]))

        # up1..up4 take (deeper width, skip width, out width = skip width):
        # 1024+1024 -> 1024, 1024+512 -> 512, 512+256 -> 256, 256+64 -> 64.
        pairs = zip(widths[:0:-1], widths[-2::-1])
        for stage, (deep, skip) in enumerate(pairs, start=1):
            setattr(self, f'up{stage}', RDBUp(deep, skip, skip))

        self.last_up = UpPad()
        self.conv_3x3 = nn.Conv2d(widths[0], widths[0], kernel_size=3, padding=1)
        self.pad_to = PadToX(32)

    def forward(self, x):
        """Pad to a multiple of 32, encode, decode with skips, refine, strip the padding."""
        diffX, diffY, padded = self.pad_to(x)

        # The channel dimension is tiled 3x because the pretrained encoder expects
        # 3 input channels.
        c1, c2, c3, c4, c5 = self.features(padded.repeat(1, 3, 1, 1))

        s1 = self.smooth1(c1)
        s2 = self.smooth2(c2)
        s3 = self.smooth3(c3)
        s4 = self.smooth4(c4)
        s5 = self.smooth5(c5)

        out = self.up1(s5, s4)
        for up, skip in ((self.up2, s3), (self.up3, s2), (self.up4, s1)):
            out = up(out, skip)

        # The padded input is handed to UpPad only as a size reference.
        out = self.conv_3x3(self.last_up(out, padded))

        return self.pad_to.remove_pad(out, diffX, diffY)

    def initialize(self):
        """Xavier-uniform init of every Conv2d (zero bias), then ``features.initialize()``."""
        def reset_conv(module):
            if not isinstance(module, nn.Conv2d):
                return
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, val=0)

        self.apply(reset_conv)
        self.features.initialize()
Esempio n. 5
0
class UNetDualEncoder(nn.Module):
    """UNet whose bottleneck is fused with features from a second, frozen VGG encoder.

    The VGG branch sees the padded input with its channel dimension tiled 3x. Its
    output is concatenated with the UNet bottleneck along dim 1. ``fuse`` then reduces
    the result back to the bottleneck width before decoding.

    Args:
        in_channels: input channels of the UNet branch.
        base_channel_size: width of the stem ``DoubleConv``; doubles per ``Down``.
        bilinear: forwarded to ``Up``. When True the bottleneck and decoder widths
            are divided by 2 (``factor == 2``).
        depth: number of ``Down``/``Up`` stages.
        second_encoder: key into ``model_urls`` for the weights loaded by
            :meth:`initialize`. NOTE(review): the VGG architecture is always built
            from ``cfgs['D']`` regardless of this name.
    """

    def __init__(self, in_channels: int = 1, base_channel_size: int = 64, bilinear=True, depth=4,
                 second_encoder='vgg16'):
        super(UNetDualEncoder, self).__init__()
        # NOTE(review): same name as the plain UNet; kept in case callers key on it.
        self.name = 'UNet'
        self.n_channels = in_channels
        self.base_channel_size = base_channel_size
        self.bilinear = bilinear
        self.depth = depth
        self.second_encoder_name = second_encoder
        self.second_encoder = VGG(make_layers(cfgs['D'], False), init_weights=False)

        factor = 2 if bilinear else 1
        # Width of the deepest skip (base doubled depth-1 times) and of the
        # bottleneck produced by the last Down stage.
        deepest_skip = base_channel_size * 2 ** max(depth - 1, 0)
        bottom_channels = deepest_skip * 2 // factor
        # Channel count of the VGG branch output. 512 is the value this module has
        # always assumed — presumably the last conv width of cfgs['D'].
        vgg_channels = 512

        self.inc = DoubleConv(in_channels, base_channel_size)
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        # Map concat(bottleneck, VGG features) back to the bottleneck width, so the
        # first Up receives the same width the last Down produces, as in a plain UNet.
        # Derived from the configuration: a fixed 512 + 1024 -> 1024 only matched
        # bilinear=False with the default base_channel_size/depth.
        self.fuse = nn.Sequential(
            nn.Conv2d(vgg_channels + bottom_channels, bottom_channels, kernel_size=1),
            nn.ReLU(inplace=True),
            DoubleConv(bottom_channels, bottom_channels))
        self.pad_to = PadToX(32)

        down_channel = base_channel_size
        # go down:
        # 64 -> 128 -> 256 -> 512 -> 1024 // factor
        for _ in range(1, self.depth):
            self.downs.append(Down(down_channel, down_channel * 2))
            down_channel *= 2
        self.downs.append(Down(down_channel, down_channel * 2 // factor))
        for _ in range(1, self.depth):
            self.ups.append(Up(down_channel * 2, down_channel // factor, bilinear))
            down_channel = down_channel // 2
        self.ups.append(Up(down_channel * 2, base_channel_size, bilinear))

    def forward(self, x):
        """Pad, run both encoders, fuse at the bottleneck, decode with skips, un-pad."""
        diffX, diffY, x, = self.pad_to(x)
        # The VGG branch is fed the channel dimension tiled 3x.
        # NOTE(review): torch.cat below assumes the VGG output has the same spatial
        # size as the UNet bottleneck — confirm against the local VGG implementation.
        x_exp = x.repeat(1, 3, 1, 1)
        extra_features = self.second_encoder(x_exp)
        x = self.inc(x)
        intermediates = []
        for layer in self.downs:
            intermediates.append(x)
            x = layer(x)
        x = torch.cat([x, extra_features], 1)
        x = self.fuse(x)
        for layer, intermediate in zip(self.ups, intermediates[::-1]):
            x = layer(x, intermediate)
        x = self.pad_to.remove_pad(x, diffX, diffY)
        return x

    def initialize(self):
        """Initialize layers, load pretrained VGG weights, and freeze the VGG branch.

        ``Conv2d`` layers get Xavier-uniform weights (ReLU gain) and zero bias.
        ``BatchNorm2d`` layers get weight 1 and bias 0. The pretrained state dict is
        downloaded from ``model_urls[self.second_encoder_name]``. Only entries whose
        keys exist in ``self.second_encoder`` are applied; encoder entries missing
        from the download keep their freshly initialized values.
        """
        def init_layer(layer):
            if isinstance(layer, nn.Conv2d):
                nn.init.xavier_uniform_(layer.weight, gain=nn.init.calculate_gain('relu'))
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, val=0)
            elif isinstance(layer, nn.BatchNorm2d):
                nn.init.constant_(layer.weight, 1)
                nn.init.constant_(layer.bias, 0)

        self.apply(init_layer)
        pretrained = load_state_dict_from_url(model_urls[self.second_encoder_name],
                                              progress=True)
        # https://discuss.pytorch.org/t/how-to-load-part-of-pre-trained-model/1113/3
        model_dict = self.second_encoder.state_dict()
        # 1. filter out pretrained entries the encoder does not have
        pretrained = OrderedDict((k, v) for k, v in pretrained.items() if k in model_dict)
        # 2. overwrite the matching entries in the encoder's own state dict
        model_dict.update(pretrained)
        # 3. load the merged dict. Keys absent from the download keep their current
        #    values instead of failing the strict load.
        self.second_encoder.load_state_dict(model_dict)

        # Fix params in the second encoder
        for param in self.second_encoder.parameters():
            param.requires_grad = False
Esempio n. 6
0
class ResUNet(nn.Module):
    """UNet-style decoder on top of an injected ResNet-style encoder.

    Args:
        features: encoder module. Its ``bottleneck`` attribute is compared with
            ``vrn.BasicBlock`` to choose the stage widths. Calling it must yield
            five feature maps (shallow to deep), and it must provide ``initialize()``.
        bilinear: forwarded to the Up blocks. When True the decoder widths are
            divided by 2 (``factor == 2``).
        v2: use ``Upv2`` blocks, 1x1 "smooth" convs on the skips and a 3-block
            ``Bottleneck`` bottom; otherwise ``Upv1`` with a 1x1 conv + ReLU bottom.

    NOTE: ``top`` and ``up5`` are constructed (so they are in ``state_dict``) but
    ``forward`` does not use them.
    """

    def __init__(self, features, bilinear: bool = True, v2=False):
        super(ResUNet, self).__init__()
        self.features = features
        self.name = 'ResUNet'
        self.v2 = v2

        # BasicBlock encoders have narrower stages than bottleneck-block encoders.
        if self.features.bottleneck == vrn.BasicBlock:
            channels = [64, 64, 128, 256, 512]
        else:
            channels = [64, 256, 512, 1024, 2048]
        self.base_channel_size = channels[0]
        factor = 2 if bilinear else 1
        base, deepest = channels[0], channels[4]

        def conv_bn_relu(cin, cout):
            # 3x3 conv (no bias) -> BatchNorm -> in-place ReLU
            return [nn.Conv2d(cin, cout, 3, padding=1, bias=False),
                    nn.BatchNorm2d(cout),
                    nn.ReLU(inplace=True)]

        self.top = nn.Sequential(*conv_bn_relu(1, base), *conv_bn_relu(base, base))

        if v2:
            for idx in (3, 2, 1, 0):
                setattr(self, f'smooth{idx + 1}',
                        nn.Conv2d(channels[idx], channels[idx], kernel_size=1))
            self.bottom = nn.Sequential(
                *[Bottleneck(deepest, deepest // 4) for _ in range(3)])
            up_block = Upv2
        else:
            self.bottom = nn.Sequential(
                nn.Conv2d(deepest, deepest, 1),
                nn.ReLU(inplace=True),
            )
            up_block = Upv1

        # up_k: concat(deeper, skip) -> deeper // factor for k = 1..4
        # (2048+1024 -> 1024, 1024+512 -> 512, 512+256 -> 256, 256+64 -> 128 with
        # the wide encoder and bilinear=True).
        for k in range(1, 5):
            deeper, skip = channels[5 - k], channels[4 - k]
            setattr(self, f'up{k}', up_block(deeper + skip, deeper // factor, bilinear))
        self.up5 = up_block(channels[1] // factor + base, base, bilinear)

        self.last_up = nn.Sequential(
            nn.Conv2d(channels[1] // factor, base, kernel_size=1),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
        )
        self.pad_to = PadToX(32)

    def forward(self, x):
        """Pad to a multiple of 32, encode, decode with skips, upsample, strip the padding."""
        diffX, diffY, padded = self.pad_to(x)
        # The channel dimension is tiled 3x because the pretrained encoder expects
        # 3 input channels.
        c1, c2, c3, c4, c5 = self.features(padded.repeat(1, 3, 1, 1))

        if self.v2:
            c1, c2, c3, c4 = (self.smooth1(c1), self.smooth2(c2),
                              self.smooth3(c3), self.smooth4(c4))

        out = self.up1(self.bottom(c5), c4)
        for up, skip in ((self.up2, c3), (self.up3, c2), (self.up4, c1)):
            out = up(out, skip)
        out = self.last_up(out)

        return self.pad_to.remove_pad(out, diffX, diffY)

    def initialize(self):
        """Xavier-uniform init of every Conv2d (zero bias), then ``features.initialize()``."""
        def reset_conv(module):
            if not isinstance(module, nn.Conv2d):
                return
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, val=0)

        self.apply(reset_conv)
        self.features.initialize()