Esempio n. 1
0
    def __init__(self,
                 encoder: nn.Module,
                 n_classes: int,
                 blur: bool = False,
                 blur_final=True,
                 self_attention: bool = False,
                 y_range: Optional[Tuple[float, float]] = None,
                 last_cross: bool = True,
                 bottle: bool = False,
                 **kwargs):
        """Build a U-Net on ``encoder`` whose decoder outputs also feed a hypercolumn head.

        Args:
            encoder: Indexable encoder body. Its activations at the indices
                returned by ``_get_sfs_idxs`` are hooked as skip connections.
            n_classes: Output channels of the final 1x1 ``conv_layer``.
            blur: Enable blur in the decoder ``UnetBlock``s.
            blur_final: When ``False``, the last decoder block is not blurred
                even if ``blur`` is ``True``.
            self_attention: Add self-attention to the third-from-last decoder block.
            y_range: If given, append ``SigmoidRange(*y_range)``.
            last_cross: Append ``MergeLayer(dense=True)`` (adding the encoder's
                input channels) followed by a ``res_block``.
            bottle: Forwarded to ``res_block``.
            **kwargs: Forwarded to ``conv_layer``, ``UnetBlock``,
                ``PixelShuffle_ICNR`` and ``res_block``.
        """
        # NOTE(review): the dummy-pass size comes from a global `args`; confirm
        # it is defined wherever this class is instantiated.
        imsize = (args.size, args.size)
        sfs_szs = model_sizes(encoder, size=imsize)
        sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs])
        # Dummy forward pass so every layer below can be sized from real activations.
        x = dummy_eval(encoder, imsize).detach()

        ni = sfs_szs[-1][1]
        middle_conv = nn.Sequential(conv_layer(ni, ni * 2, **kwargs),
                                    conv_layer(ni * 2, ni, **kwargs)).eval()
        x = middle_conv(x)
        layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]

        # Hypercolumn sources: the middle conv plus every decoder block below.
        self.hc_hooks = [Hook(layers[-1], _hook_inner, detach=False)]
        hc_c = [x.shape[1]]

        for i, idx in enumerate(sfs_idxs):
            not_final = i != len(sfs_idxs) - 1
            up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
            do_blur = blur and (not_final or blur_final)
            sa = self_attention and (i == len(sfs_idxs) - 3)
            unet_block = UnetBlock(up_in_c,
                                   x_in_c,
                                   self.sfs[i],
                                   final_div=not_final,
                                   # Previously `blur=blur`, which left `do_blur`
                                   # unused and made `blur_final` a no-op.
                                   blur=do_blur,
                                   self_attention=sa,
                                   **kwargs).eval()
            layers.append(unet_block)
            x = unet_block(x)
            self.hc_hooks.append(Hook(layers[-1], _hook_inner, detach=False))
            hc_c.append(x.shape[1])

        ni = x.shape[1]
        # Upsample once more if the decoder has not reached the input resolution.
        if imsize != sfs_szs[0][-2:]:
            layers.append(PixelShuffle_ICNR(ni, **kwargs))
        if last_cross:
            layers.append(MergeLayer(dense=True))
            ni += in_channels(encoder)
            layers.append(res_block(ni, bottle=bottle, **kwargs))
        hc_c.append(ni)
        layers.append(Hcolumns(self.hc_hooks, hc_c))
        # Hcolumns output is sized as ni * len(hc_c) channels here.
        layers += [
            conv_layer(ni * len(hc_c),
                       n_classes,
                       ks=1,
                       use_activ=False,
                       **kwargs)
        ]
        if y_range is not None: layers.append(SigmoidRange(*y_range))
        super().__init__(*layers)
Esempio n. 2
0
    def __init__(self, encoder, n_classes, final_bias=0., chs=256, n_anchors=9, flatten=True, chip_size=(256,256), n_bands=3):
        """Assemble an FPN/RetinaNet-style detector on top of ``encoder``.

        ``chs`` is the channel width of the top-down FPN layers. After all
        submodules are built, a random dummy batch of shape
        ``(1, n_bands, chip_size[0], chip_size[1])`` is passed to
        ``self._create_p_states``. The sizes of dims 2 and 3 of every
        resulting state are stored in ``self.sizes``.
        """
        super().__init__()
        self.n_classes = n_classes
        self.flatten = flatten
        self.chip_size = chip_size

        # Activation shapes of the backbone layers for a chip-sized input.
        layer_shapes = model_sizes(encoder, size=self.chip_size)
        enc_hooks = hook_outputs(encoder)
        top_channels = layer_shapes[-1][1]

        self.encoder = encoder
        self.c5top5 = conv2d(top_channels, chs, ks=1, bias=True)
        self.c5top6 = conv2d(top_channels, chs, stride=2, bias=True)
        self.p6top7 = nn.Sequential(nn.ReLU(), conv2d(chs, chs, stride=2, bias=True))
        # Lateral merges take the layers at index -2 then -3 of the backbone.
        lateral_pairs = zip(layer_shapes[-2:-4:-1], enc_hooks[-2:-4:-1])
        self.merges = nn.ModuleList([LateralUpsampleMerge(chs, shape[1], hook)
                                     for shape, hook in lateral_pairs])
        self.smoothers = nn.ModuleList([conv2d(chs, chs, 3, bias=True) for _ in range(3)])
        self.classifier = self._head_subnet(n_classes, n_anchors, final_bias, chs=chs)
        self.box_regressor = self._head_subnet(4, n_anchors, 0., chs=chs)

        # Record the size of dims 2 and 3 (presumably H, W) of every pyramid state.
        height, width = self.chip_size[0], self.chip_size[1]
        dummy = torch.rand(n_bands, height, width).unsqueeze(0)
        self.sizes = [[state.size(2), state.size(3)]
                      for state in self._create_p_states(dummy)]
Esempio n. 3
0
    def __init__(self, num_classes, backbone_fn, chip_size=224, pretrained=True):
        """Build a DeepLab-style segmenter on a dilated ``backbone_fn`` body.

        Backbone modules from ``modify_dilation_index`` onward get stride-1,
        dilated convolutions. A hook on the module at that index captures the
        input for the auxiliary ``FCNHead``. A ``DeepLabHead`` classifies the
        final backbone features.

        Args:
            num_classes: Number of output classes of both heads.
            backbone_fn: Backbone constructor passed to ``create_body``. Its
                ``__name__`` selects which modules are dilated.
            chip_size: Side of the square dummy input used to infer channel counts.
            pretrained: Forwarded to ``create_body``. The default ``True``
                matches the previously hard-coded behaviour.
        """
        super().__init__()
        if getattr(backbone_fn, '_is_multispectral', False):
            self.backbone = create_body(backbone_fn,
                                        pretrained=pretrained,
                                        cut=_get_backbone_meta(
                                            backbone_fn.__name__)['cut'])
        else:
            self.backbone = create_body(backbone_fn, pretrained=pretrained)

        backbone_name = backbone_fn.__name__

        # densenet/vgg bodies wrap their layers in a single child container.
        if "densenet" in backbone_name or "vgg" in backbone_name:
            hookable_modules = list(self.backbone.children())[0]
        else:
            hookable_modules = list(self.backbone.children())

        if "vgg" in backbone_name:
            modify_dilation_index = -5
        else:
            modify_dilation_index = -2

        # Sub-module names whose convs get dilated (resnet18/34 vs. other backbones).
        if backbone_name == 'resnet18' or backbone_name == 'resnet34':
            module_to_check = 'conv'
        else:
            module_to_check = 'conv2'

        # Hook the module whose output feeds the auxiliary classifier.
        self.hook = hook_output(hookable_modules[modify_dilation_index])

        custom_idx = 0
        for i, module in enumerate(hookable_modules[modify_dilation_index:]):
            dilation = 2 * (i + 1)
            padding = 2 * (i + 1)
            for n, m in module.named_modules():
                if module_to_check in n:
                    m.dilation, m.padding, m.stride = (dilation, dilation), (
                        padding, padding), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)

            # vgg modules are bare layers; dilate each Conv2d with its own counter.
            if "vgg" in backbone_name:
                if isinstance(module, nn.Conv2d):
                    dilation = 2 * (custom_idx + 1)
                    padding = 2 * (custom_idx + 1)
                    module.dilation, module.padding, module.stride = (
                        dilation, dilation), (padding, padding), (1, 1)
                    custom_idx += 1

        # Dummy forward pass: returns activation shapes and fills self.hook.stored.
        feature_sizes = model_sizes(self.backbone, size=(chip_size, chip_size))
        # Channels of the hooked activation.
        num_channels_aux_classifier = self.hook.stored.shape[1]
        # Channels of the last backbone layer.
        num_channels_classifier = feature_sizes[-1][1]

        self.classifier = DeepLabHead(num_channels_classifier, num_classes)
        self.aux_classifier = FCNHead(num_channels_aux_classifier, num_classes)
    def __init__(self, model, chip_size, num_classes):
        """Wrap ``model`` and add a 1x1 conv that produces auxiliary logits.

        A forward hook is attached to the layer that precedes, in
        ``flatten_model`` order, the first layer whose dilation is greater
        than 1. If no layer is dilated, the loop runs to the end and the hook
        lands on the second-to-last flattened layer.

        Args:
            model: Network to wrap (stored as ``self.model``).
            chip_size: Side of the square dummy input used to size the hook output.
            num_classes: Output channels of ``self.aux_logits``.
        """
        super(AuxPSUnet, self).__init__()
        self.model = model

        # Flatten once and reuse the list for both the search and the hook.
        flat_layers = flatten_model(self.model)
        for idx, layer in enumerate(flat_layers):
            if not hasattr(layer, 'dilation'):
                continue
            rate = layer.dilation
            if isinstance(rate, tuple):
                rate = rate[0]
            if rate > 1:
                break

        self.hook = hook_output(flat_layers[idx - 1])

        # Dummy forward pass; this fills self.hook.stored with the hooked activation.
        model_sizes(self.model, size=(chip_size, chip_size))
        aux_in_channels = self.hook.stored.shape[1]
        # Drop the dummy activation so it is not kept alive.
        del self.hook.stored
        self.aux_logits = nn.Conv2d(aux_in_channels, num_classes, kernel_size=1)
Esempio n. 5
0
    def __init__(self,
                 encoder=None,
                 n_classes=2,
                 last_filters=32,
                 imsize=(256, 256),
                 y_range=None,
                 **kwargs):
        """Build a hypercolumn U-Net whose decoder stages are ``DecoderBlock``s.

        Args:
            encoder: Indexable encoder body. Its activations at the indices
                from ``_get_sfs_idxs`` are hooked as skip connections.
                NOTE(review): the ``None`` default cannot work, since
                ``model_sizes`` is called on it unconditionally.
            n_classes: Number of output classes, passed to ``FinalBlock``.
            last_filters: Passed to ``FinalBlock`` as its second argument.
            imsize: Spatial size of the dummy input used to size the layers.
            y_range: If given, append ``SigmoidRange(*y_range)``.
            **kwargs: NOTE(review): accepted but never used in this constructor.
        """

        self.n_classes = n_classes

        layers = nn.ModuleList()

        # Encoder
        sfs_szs = model_sizes(encoder, size=imsize)
        sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs])
        layers.append(encoder)

        # Dummy forward pass so the blocks below can be sized from real activations.
        x = dummy_eval(encoder, imsize).detach()

        # Hypercolumn sources; unlike the other variants, the middle conv is not hooked.
        self.hc_hooks = []
        hc_c = []

        ni = sfs_szs[-1][1]
        middle_conv = nn.Sequential(conv_layer(ni, ni * 2),
                                    conv_layer(ni * 2, ni)).eval()
        x = middle_conv(x)
        layers.extend([batchnorm_2d(ni), nn.ReLU(), middle_conv])

        # self.hc_hooks = [Hook(layers[-1], _hook_inner, detach=False)]
        # hc_c = [x.shape[1]]

        # Decoder
        n_filters = [64, 128, 256, 512]
        n = len(n_filters)
        is_deconv = True

        # The last skip index is not given its own DecoderBlock.
        for i, idx in enumerate(sfs_idxs[:-1]):
            # in_c is the mean of two consecutive n_filters entries, read from the end.
            # NOTE(review): if sfs_idxs has more than 4 entries, n - i - 2 goes
            # negative and silently wraps around the list; confirm the encoder.
            in_c, out_c = int(n_filters[n - i - 1] +
                              n_filters[n - i - 2]) // 2, int(sfs_szs[idx][1])

            dec_bloc = DecoderBlock(in_c, out_c, self.sfs[i], is_deconv,
                                    True).eval()
            layers.append(dec_bloc)

            x = dec_bloc(x)

            self.hc_hooks.append(Hook(layers[-1], _hook_inner, detach=False))
            hc_c.append(x.shape[1])

        ni = x.shape[1]

        # NOTE(review): sized with n_filters[0] rather than ni; confirm they agree.
        layers.append(PixelShuffle_ICNR(n_filters[0], scale=2))

        layers.append(Hcolumns(self.hc_hooks, hc_c))

        # One ni-wide slice per hooked block plus one for the main path;
        # presumably Hcolumns maps every hook to ni channels.
        fin_block = FinalBlock(ni * (len(hc_c) + 1), last_filters, n_classes)
        layers.append(fin_block)

        if y_range is not None:
            layers.append(SigmoidRange(*y_range))
        super().__init__(*layers)
Esempio n. 6
0
    def __init__(self,
                 encoder: nn.Module,
                 n_classes: int,
                 blur: bool = False,
                 blur_final=True,
                 self_attention: bool = False,
                 y_range: Optional[Tuple[float, float]] = None,
                 last_cross: bool = True,
                 bottle: bool = False,
                 norm_type: Optional[NormType] = NormType.Batch,
                 nf_factor: int = 1,
                 **kwargs):
        """Build a wide U-Net whose decoder blocks all have a fixed output width.

        Args:
            encoder: Indexable encoder body. Its activations at the indices
                from ``_get_sfs_idxs`` are hooked (``detach=False``) as skip
                connections.
            n_classes: Output channels of the final 1x1 conv.
            blur: Enable blur in the decoder ``UnetBlockWide``s.
            blur_final: When ``False``, the last decoder block is not blurred
                even if ``blur`` is ``True``.
            self_attention: Add self-attention to the third-from-last decoder block.
            y_range: If given, append ``SigmoidRange(*y_range)``.
            last_cross: Append ``MergeLayer(dense=True)`` (adding the encoder's
                input channels) followed by a ``res_block``.
            bottle: Forwarded to ``res_block``.
            norm_type: Normalization type. ``NormType.Spectral`` also enables
                ``extra_bn``.
            nf_factor: Decoder width multiplier. Blocks output ``512 * nf_factor``
                channels, and the last block outputs half of that.
            **kwargs: Forwarded to the conv/decoder/upsampling layers.
        """
        nf = 512 * nf_factor
        extra_bn = norm_type == NormType.Spectral
        imsize = (256, 256)
        sfs_szs = model_sizes(encoder, size=imsize)
        sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs], detach=False)
        # Dummy forward pass so the layers below can be sized from real activations.
        x = dummy_eval(encoder, imsize).detach()

        ni = sfs_szs[-1][1]
        middle_conv = nn.Sequential(
            custom_conv_layer(ni,
                              ni * 2,
                              norm_type=norm_type,
                              extra_bn=extra_bn,
                              **kwargs),
            custom_conv_layer(ni * 2,
                              ni,
                              norm_type=norm_type,
                              extra_bn=extra_bn,
                              **kwargs),
        ).eval()
        x = middle_conv(x)
        layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]

        for i, idx in enumerate(sfs_idxs):
            not_final = i != len(sfs_idxs) - 1
            up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
            do_blur = blur and (not_final or blur_final)
            sa = self_attention and (i == len(sfs_idxs) - 3)

            n_out = nf if not_final else nf // 2

            unet_block = UnetBlockWide(up_in_c,
                                       x_in_c,
                                       n_out,
                                       self.sfs[i],
                                       final_div=not_final,
                                       # Previously `blur=blur`, which left
                                       # `do_blur` unused and ignored `blur_final`.
                                       blur=do_blur,
                                       self_attention=sa,
                                       norm_type=norm_type,
                                       extra_bn=extra_bn,
                                       **kwargs).eval()
            layers.append(unet_block)
            x = unet_block(x)

        ni = x.shape[1]
        # Upsample once more if the decoder has not reached the input resolution.
        if imsize != sfs_szs[0][-2:]:
            layers.append(PixelShuffle_ICNR(ni, **kwargs))
        if last_cross:
            layers.append(MergeLayer(dense=True))
            ni += in_channels(encoder)
            layers.append(
                res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))
        # The final 1x1 conv intentionally does not receive **kwargs (as before).
        layers += [
            custom_conv_layer(ni,
                              n_classes,
                              ks=1,
                              use_activ=False,
                              norm_type=norm_type)
        ]
        if y_range is not None:
            layers.append(SigmoidRange(*y_range))
        super().__init__(*layers)
Esempio n. 7
0
def cnn_activations_count(model, size=None):
    """Return ``ch * h * w`` for the last activation of a CNN body.

    Args:
        model: Architecture constructor accepted by ``create_body`` (for
            example ``models.resnet18``). It was previously ignored and
            ``models.resnet18`` was always measured instead.
        size: Side of the square dummy input. Defaults to the global ``SZ``.

    Returns:
        Channels times height times width of the body's final activation.
    """
    if size is None:
        size = SZ
    # Only activation shapes are needed, so pretrained weights are not requested.
    body = create_body(model, pretrained=False)
    _, ch, h, w = model_sizes(body, (size, size))[-1]
    return ch * h * w
Esempio n. 8
0
    def __init__(self,
                 data,
                 grids=None,
                 zooms=None,
                 ratios=None,
                 backbone=None,
                 drop=0.3,
                 bias=-4.,
                 focal_loss=False,
                 pretrained_path=None,
                 location_loss_factor=None):
        """Create a single-shot detector (SSD) and its fastai learner.

        Args:
            data: Databunch-like object. ``data.chip_size`` and ``data.c`` are
                read here, and it is passed to ``create_cnn``.
            grids: Anchor grid sizes. ``None`` means ``[4, 2, 1]``.
            zooms: Anchor zoom factors. ``None`` means ``[0.7, 1., 1.3]``.
            ratios: Anchor aspect ratios. ``None`` means
                ``[[1., 1.], [1., 0.5], [0.5, 1.]]``.
            backbone: ``None`` (``models.resnet34``), the name of an attribute
                of ``models``, or an architecture callable.
            drop: Dropout passed to ``SSDHead``.
            bias: Bias passed to ``SSDHead``.
            focal_loss: Use ``FocalLoss`` instead of ``BCE_Loss``.
            pretrained_path: Optional checkpoint loaded with ``self.load``.
            location_loss_factor: Optional weight, strictly between 0 and 1.

        Raises:
            Exception: if ``location_loss_factor`` is given and not in (0, 1).
        """
        super().__init__()

        # Build fresh lists per call. List defaults in the signature would be
        # a single object shared by every instance.
        if grids is None:
            grids = [4, 2, 1]
        if zooms is None:
            zooms = [0.7, 1., 1.3]
        if ratios is None:
            ratios = [[1., 1.], [1., 0.5], [0.5, 1.]]

        self._device = torch.device(
            'cuda') if torch.cuda.is_available() else torch.device('cpu')

        if location_loss_factor is not None:
            if not 0 < location_loss_factor < 1:
                raise Exception(
                    '`location_loss_factor` should be greater than 0 and less than 1'
                )
        self.location_loss_factor = location_loss_factor

        if not HAS_FASTAI:
            _raise_fastai_import_error()

        if backbone is None:
            self._backbone = models.resnet34
        elif type(backbone) is str:
            self._backbone = getattr(models, backbone)
        else:
            self._backbone = backbone

        self._create_anchors(grids, zooms, ratios)

        # Size the head from the backbone's last activation for a chip-sized input.
        feature_sizes = model_sizes(create_body(self._backbone),
                                    size=(data.chip_size, data.chip_size))
        num_features = feature_sizes[-1][-1]
        num_channels = feature_sizes[-1][1]

        ssd_head = SSDHead(grids,
                           self._anchors_per_cell,
                           data.c,
                           num_features=num_features,
                           drop=drop,
                           bias=bias,
                           num_channels=num_channels)

        self._data = data
        self.learn = create_cnn(data=data,
                                arch=self._backbone,
                                custom_head=ssd_head)
        self.learn.model = self.learn.model.to(self._device)

        if pretrained_path is not None:
            self.load(pretrained_path)

        if focal_loss:
            self._loss_f = FocalLoss(data.c)
        else:
            self._loss_f = BCE_Loss(data.c)

        self.learn.loss_func = self._ssd_loss
def _pspnet_unet(num_classes, backbone_fn, chip_size=224, pyramid_sizes=(1, 2, 3, 6), pretrained=True):
    """Return a dilated backbone followed by a Pyramid Pooling Module.

    The result is an indexable ``nn.Sequential`` so it can be used as the
    encoder of a Dynamic U-Net.

    Args:
        num_classes: Number of classes (not used by this function).
        backbone_fn: Backbone constructor. Its ``__name__`` selects which
            layers are dilated.
        chip_size: Side of the square dummy input used to infer channel counts.
        pyramid_sizes: Pooling sizes passed to ``_PyramidPoolingModule``.
        pretrained: Forwarded to ``create_body``.

    Returns:
        ``nn.Sequential`` of the backbone layers, the PPM, and a 3x3 conv
        reducing the channels to 512.
    """
    arch_name = backbone_fn.__name__
    is_vgg = "vgg" in arch_name
    is_densenet = "densenet" in arch_name

    body_kwargs = {'pretrained': pretrained}
    if getattr(backbone_fn, '_is_multispectral', False):
        body_kwargs['cut'] = _get_backbone_meta(arch_name)['cut']
    backbone = create_body(backbone_fn, **body_kwargs)

    # densenet/vgg bodies wrap their layers in one child container.
    top_children = list(backbone.children())
    hookable_modules = top_children[0] if (is_densenet or is_vgg) else top_children
    start = -5 if is_vgg else -2
    module_to_check = 'conv' if arch_name in ('resnet18', 'resnet34') else 'conv2'

    vgg_conv_count = 0
    for offset, block in enumerate(hookable_modules[start:]):
        rate = 2 * (offset + 1)
        for sub_name, sub_module in block.named_modules():
            if module_to_check in sub_name:
                sub_module.dilation = (rate, rate)
                sub_module.padding = (rate, rate)
                sub_module.stride = (1, 1)
            elif 'downsample.0' in sub_name:
                sub_module.stride = (1, 1)

        # vgg modules are bare layers; each Conv2d gets its own dilation step.
        if is_vgg and isinstance(block, nn.Conv2d):
            vgg_rate = 2 * (vgg_conv_count + 1)
            block.dilation = (vgg_rate, vgg_rate)
            block.padding = (vgg_rate, vgg_rate)
            block.stride = (1, 1)
            vgg_conv_count += 1

    # Channels of the last backbone activation for a chip-sized input.
    num_channels = model_sizes(backbone, size=(chip_size, chip_size))[-1][1]
    branch_channels = int(num_channels / len(pyramid_sizes))
    ppm = _PyramidPoolingModule(num_channels, branch_channels, pyramid_sizes)

    in_final = branch_channels * len(pyramid_sizes) + num_channels
    # Reduce channel size after pyramid pooling module to avoid CUDA OOM error.
    final_conv = nn.Conv2d(in_channels=in_final, out_channels=512, kernel_size=3, padding=1)

    # Dynamic U-Net expects an indexable backbone, so unwrap the container.
    if is_densenet or is_vgg:
        backbone = backbone[0]
    return nn.Sequential(*backbone, ppm, final_conv)
    def __init__(self, num_classes, backbone_fn, chip_size=224, pyramid_sizes=(1, 2, 3, 6), pretrained=True):
        """Build a PSPNet on a dilated ``backbone_fn`` body, with an auxiliary head.

        Backbone modules from ``modify_dilation_index`` onward get stride-1,
        dilated convolutions. A hook on the module at that index supplies the
        input of ``self.aux_logits``. A Pyramid Pooling Module (``self.ppm``)
        and ``self.final`` produce the main logits.

        Args:
            num_classes: Number of output classes of both heads.
            backbone_fn: Backbone constructor. Its ``__name__`` selects which
                layers are dilated.
            chip_size: Side of the square dummy input used to infer channel counts.
            pyramid_sizes: Pooling sizes passed to ``_PyramidPoolingModule``.
            pretrained: Forwarded to ``create_body``.
        """
        super(PSPNet, self).__init__()

        if getattr(backbone_fn, '_is_multispectral', False):
            self.backbone = create_body(backbone_fn, pretrained=pretrained, cut=_get_backbone_meta(backbone_fn.__name__)['cut'])
        else:
            self.backbone = create_body(backbone_fn, pretrained=pretrained)

        backbone_name = backbone_fn.__name__

        ## densenet/vgg bodies wrap their layers in a single child container
        if "densenet" in backbone_name or "vgg" in backbone_name:
            hookable_modules = list(self.backbone.children())[0]
        else:
            hookable_modules = list(self.backbone.children())

        if "vgg" in backbone_name:
            modify_dilation_index = -5
        else:
            modify_dilation_index = -2

        ## Sub-module names whose convs get dilated (resnet18/34 vs. other backbones)
        if backbone_name == 'resnet18' or backbone_name == 'resnet34':
            module_to_check = 'conv'
        else:
            module_to_check = 'conv2'

        ## Hook the module whose output feeds the auxiliary logits.
        ## Registered before the dilation edits and before model_sizes runs.
        self.hook = hook_output(hookable_modules[modify_dilation_index])

        custom_idx = 0
        for i, module in enumerate(hookable_modules[modify_dilation_index:]):
            dilation = 2 * (i + 1)
            padding = 2 * (i + 1)
            for n, m in module.named_modules():
                if module_to_check in n:
                    m.dilation, m.padding, m.stride = (dilation, dilation), (padding, padding), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)

            ## vgg modules are bare layers; each Conv2d gets its own dilation step
            if "vgg" in backbone_fn.__name__:
                if isinstance(module, nn.Conv2d):
                    dilation = 2 * (custom_idx + 1)
                    padding = 2 * (custom_idx + 1)
                    module.dilation, module.padding, module.stride = (dilation, dilation), (padding, padding), (1, 1)
                    custom_idx += 1

        ## Dummy forward pass: returns activation shapes and fills self.hook.stored
        feature_sizes = model_sizes(self.backbone, size=(chip_size, chip_size))

        ## Channels of the activation stored by the hook (not parameters)
        aux_in_channels = self.hook.stored.shape[1]

        ## Get number of channels in the last layer
        num_channels = feature_sizes[-1][1]

        penultimate_channels = num_channels / len(pyramid_sizes)
        self.ppm = _PyramidPoolingModule(num_channels, int(penultimate_channels), pyramid_sizes)


        self.final = nn.Sequential(
            ## Input width uses int() (floor) of the per-branch width; the hidden
            ## width uses ceil, which differs when num_channels is not divisible
            ## by len(pyramid_sizes)
            nn.Conv2d(int(penultimate_channels) * len(pyramid_sizes) + num_channels, math.ceil(penultimate_channels), kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(math.ceil(penultimate_channels)),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Conv2d(math.ceil(penultimate_channels), num_classes, kernel_size=1)
        )


        self.aux_logits = nn.Conv2d(aux_in_channels, num_classes, kernel_size=1)

        initialize_weights(self.aux_logits)
        initialize_weights(self.ppm, self.final)
Esempio n. 11
0
    def __init__(self,
                 data,
                 grids=None,
                 zooms=None,
                 ratios=None,
                 backbone=None,
                 drop=0.3,
                 bias=-4.,
                 focal_loss=False,
                 pretrained_path=None,
                 location_loss_factor=None,
                 ssd_version=2,
                 backend='pytorch'):
        """Build a single-shot detector with a PyTorch or TensorFlow backend.

        Args:
            data: Databunch-like object. ``data.chip_size`` and ``data.c`` are
                read. ``data.height_width`` is also read when ``grids`` is
                ``None`` and ``ssd_version == 2``.
            grids: Anchor grid sizes. ``None`` means ``[4, 2, 1]`` for
                ``ssd_version == 1``, and a k-means based estimate for
                ``ssd_version == 2``.
            zooms: Anchor zoom factors. ``None`` means ``[1.]``.
            ratios: Anchor aspect ratios. ``None`` means ``[[1., 1.]]``.
            backbone: ``None`` (``models.resnet34``), the name of an attribute
                of ``models``, or an architecture callable.
            drop: Dropout passed to the SSD head.
            bias: Bias passed to the SSD head.
            focal_loss: Use ``FocalLoss`` instead of ``BCE_Loss``.
            pretrained_path: Optional checkpoint loaded with ``self.load``.
            location_loss_factor: Optional weight, strictly between 0 and 1.
            ssd_version: 1 (``SSDHead``) or 2 (``SSDHeadv2``).
            backend: ``'tensorflow'`` delegates to ``_intialize_tensorflow``;
                any other value builds the PyTorch model.

        Raises:
            Exception: on an unsupported ``ssd_version``, an out-of-range
                ``location_loss_factor``, or an incompatible backbone.
        """
        super().__init__(data, backbone)

        # Fresh lists per call instead of list defaults shared across instances.
        if zooms is None:
            zooms = [1.]
        if ratios is None:
            ratios = [[1., 1.]]

        self._backend = backend
        if self._backend == 'tensorflow':
            self._intialize_tensorflow(data, grids, zooms, ratios, backbone,
                                       drop, bias, pretrained_path,
                                       location_loss_factor)
        else:
            if ssd_version not in [1, 2]:
                raise Exception("ssd_version can be only [1,2]")

            if location_loss_factor is not None:
                if not 0 < location_loss_factor < 1:
                    raise Exception(
                        '`location_loss_factor` should be greater than 0 and less than 1'
                    )
            self.location_loss_factor = location_loss_factor

            # NOTE(review): `code` is not defined in this method; presumably a
            # module-level name — confirm.
            self._code = code
            self.ssd_version = ssd_version

            backbone_cut = None
            backbone_split = None

            # When `_orig_backbone` exists (presumably multispectral models),
            # build with it and its cut/split, then swap back further below.
            if hasattr(self, '_orig_backbone'):
                self._backbone_ms = self._backbone
                self._backbone = self._orig_backbone
                _backbone_meta = cnn_config(self._orig_backbone)
                backbone_cut = _backbone_meta['cut']
                backbone_split = _backbone_meta['split']

            if backbone is None:
                self._backbone = models.resnet34
                backbone_name = 'res'
            elif type(backbone) is str:
                self._backbone = getattr(models, backbone)
                backbone_name = backbone[:3]
            else:
                self._backbone = backbone
                backbone_name = 'custom'

            if not self._check_backbone_support(self._backbone):
                raise Exception(
                    f"Enter only compatible backbones from {', '.join(self.supported_backbones)}"
                )

            if self._backbone == models.mobilenet_v2:
                backbone_cut = -1
                backbone_split = _mobilenet_split

            if ssd_version == 1:
                # `is None`: `== None` is ambiguous if grids is a numpy array.
                if grids is None:
                    grids = [4, 2, 1]

                self._create_anchors(grids, zooms, ratios)

                feature_sizes = model_sizes(create_body(self._backbone,
                                                        cut=backbone_cut),
                                            size=(data.chip_size,
                                                  data.chip_size))
                num_features = feature_sizes[-1][-1]
                num_channels = feature_sizes[-1][1]

                ssd_head = SSDHead(grids,
                                   self._anchors_per_cell,
                                   data.c,
                                   num_features=num_features,
                                   drop=drop,
                                   bias=bias,
                                   num_channels=num_channels)
            elif ssd_version == 2:
                if grids is None:
                    logger.info("Computing optimal grid size...")
                    # Bounding-box heights and widths of the dataset.
                    hw = np.array(data.height_width)

                    # Add k-means centroids until the average-IoU gain drops below 0.05.
                    centroid = kmeans(hw, 1)
                    avg = avg_iou(hw, centroid)

                    for num_anchor in range(2, 5):
                        new_centroid = kmeans(hw, num_anchor)
                        new_avg = avg_iou(hw, new_centroid)
                        if (new_avg - avg) < 0.05:
                            break
                        avg = new_avg
                        centroid = new_centroid.copy()

                    # One grid per centroid: chip size over the centroid's larger dimension.
                    grids = list(
                        map(
                            int,
                            map(
                                round, data.chip_size /
                                np.sort(np.max(centroid, axis=1)))))
                    # Deduplicate, map a 0 grid to 1, and keep descending order so
                    # grids[0] is the largest grid (as in the v1 default [4, 2, 1]).
                    # A trailing list(set(...)) used to discard this ordering,
                    # which made the grids[0] check below test an arbitrary grid.
                    grids = sorted({max(g, 1) for g in grids}, reverse=True)

                self._create_anchors(grids, zooms, ratios)

                feature_sizes = model_sizes(create_body(self._backbone,
                                                        cut=backbone_cut),
                                            size=(data.chip_size,
                                                  data.chip_size))
                num_features = feature_sizes[-1][-1]
                num_channels = feature_sizes[-1][1]

                # For resnets, if the largest grid is > 8 and far from the last
                # feature size, use the second-to-last feature map (cut=-3).
                if grids[0] > 8 and abs(num_features - grids[0]
                                        ) > 4 and backbone_name == 'res':
                    num_features = feature_sizes[-2][-1]
                    num_channels = feature_sizes[-2][1]
                    backbone_cut = -3
                ssd_head = SSDHeadv2(grids,
                                     self._anchors_per_cell,
                                     data.c,
                                     num_features=num_features,
                                     drop=drop,
                                     bias=bias,
                                     num_channels=num_channels)

            else:
                raise Exception('SSDVersion can only be 1 or 2')

            if hasattr(self, '_backbone_ms'):
                self._orig_backbone = self._backbone
                self._backbone = self._backbone_ms

            self.learn = cnn_learner(data=data,
                                     base_arch=self._backbone,
                                     cut=backbone_cut,
                                     split_on=backbone_split,
                                     custom_head=ssd_head)
            self._arcgis_init_callback()  # make first conv weights learnable
            self.learn.model = self.learn.model.to(self._device)

            if focal_loss:
                self._loss_f = FocalLoss(data.c)
            else:
                self._loss_f = BCE_Loss(data.c)
            self.learn.loss_func = self._ssd_loss

            _set_multigpu_callback(self)
            if pretrained_path is not None:
                self.load(pretrained_path)
Esempio n. 12
0
        
    def forward(self, x: Tensor) -> Tensor:
        """Concatenate upsampled hooked activations with ``x`` (hypercolumns).

        For each of the first ``self.n`` hooks, the activation stored by the
        hook is optionally passed through ``self.factorization[i]``. It is then
        bilinearly upsampled by ``2 ** (self.n - i)``. The upsampled tensors,
        followed by ``x``, are concatenated along dim 1.
        """
        out = []
        for i in range(self.n):
            feat = self.hooks[i].stored
            if self.factorization is not None:
                feat = self.factorization[i](feat)
            # Earlier hooks get larger scale factors (presumably coarser maps).
            out.append(F.interpolate(feat, scale_factor=2 ** (self.n - i),
                                     mode='bilinear', align_corners=False))
        out.append(x)
        return torch.cat(out, dim=1)

class DynamicUnet_Hcolumns(SequentialEx):
    """Create a U-Net from a given architecture, with hypercolumn hooks on the decoder."""
    def __init__(self, encoder:nn.Module, n_classes:int, blur:bool=False, blur_final=True, 
                 self_attention:bool=False,
                 y_range:Optional[Tuple[float,float]]=None,
                 last_cross:bool=True, bottle:bool=False, **kwargs):
        # NOTE(review): the dummy-pass size comes from a global `args`; confirm it exists.
        imsize = (args.size, args.size)
        sfs_szs = model_sizes(encoder, size=imsize)
        sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs])
        # Dummy forward pass so the layers below can be sized from real activations.
        x = dummy_eval(encoder, imsize).detach()

        ni = sfs_szs[-1][1]
        middle_conv = nn.Sequential(conv_layer(ni, ni*2, **kwargs),
                                    conv_layer(ni*2, ni, **kwargs)).eval()
        x = middle_conv(x)
        layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]

        # Hypercolumn sources start with the middle conv.
        self.hc_hooks = [Hook(layers[-1], _hook_inner, detach=False)]
        hc_c = [x.shape[1]]
        
        for i,idx in enumerate(sfs_idxs):
            not_final = i!=len(sfs_idxs)-1
            # NOTE(review): this snippet is truncated here. The rest of the
            # decoder loop, the head and the super().__init__ call are missing.
Esempio n. 13
0
    def __init__(
        self,
        encoder: nn.Module,
        n_classes: int,
        blur: bool = False,
        blur_final=True,
        self_attention: bool = False,
        y_range: Optional[Tuple[float, float]] = None,
        last_cross: bool = True,
        bottle: bool = False,
        small=True,
        **kwargs,
    ):
        """Build a hypercolumn U-Net on ``encoder``, optionally in a reduced "small" form.

        Args:
            encoder: Indexable encoder body. Its activations at the indices
                from ``_get_sfs_idxs`` are hooked as skip connections.
            n_classes: Output channels of the final 1x1 ``conv_layer``.
            blur: Enable blur in the decoder blocks.
            blur_final: When ``False``, the last decoder block is not blurred
                even if ``blur`` is ``True``.
            self_attention: Add self-attention to the third-from-last decoder block.
            y_range: If given, append ``SigmoidRange(*y_range)``.
            last_cross: Append ``MergeLayer(dense=True)`` (adding the encoder's
                input channels) followed by a ``res_block``.
            bottle: Forwarded to ``res_block``.
            small: Use a single channel-halving middle conv and
                ``UnetBlockSmall`` decoder blocks, and do not hook the middle
                conv as a hypercolumn source.
            **kwargs: Forwarded to the conv/decoder/upsampling layers.
        """
        imsize = (256, 256)
        # NOTE: larger encoders (e.g. resnet50) were tried here but ran out of memory.
        sfs_szs = model_sizes(encoder, size=imsize)
        sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs])
        # Dummy forward pass so the layers below can be sized from real activations.
        x = dummy_eval(encoder, imsize).detach()

        ni = sfs_szs[-1][1]
        if small:
            middle_conv_size_down_scale = 2
            middle_conv = conv_layer(ni, ni // middle_conv_size_down_scale,
                                     **kwargs).eval()
        else:
            middle_conv_size_scale = 2
            middle_conv = nn.Sequential(
                conv_layer(ni, ni * middle_conv_size_scale, **kwargs),
                conv_layer(ni * middle_conv_size_scale, ni, **kwargs),
            ).eval()
        x = middle_conv(x)
        layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]

        if small:
            self.hc_hooks = []
            hc_c = []
        else:
            self.hc_hooks = [Hook(layers[-1], _hook_inner, detach=False)]
            hc_c = [x.shape[1]]

        unet_block_class = UnetBlockSmall if small else UnetBlock
        for i, idx in enumerate(sfs_idxs):
            final_unet_flag = i == len(sfs_idxs) - 1
            up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
            # Blur every block except the last, which is blurred only if
            # `blur_final`. The old expression used `final_unet_flag` instead
            # of its negation, and its value was never passed on.
            do_blur = blur and (not final_unet_flag or blur_final)
            sa = self_attention and (i == len(sfs_idxs) - 3)
            unet_block = unet_block_class(
                up_in_c,
                x_in_c,
                self.sfs[i],
                # NOTE(review): fastai's DynamicUnet passes `final_div=not_final`;
                # here it is True only for the last block. Kept as-is — confirm.
                final_div=final_unet_flag,
                blur=do_blur,
                self_attention=sa,
                **kwargs,
            ).eval()
            layers.append(unet_block)
            x = unet_block(x)
            # Every decoder block is a hypercolumn source.
            self.hc_hooks.append(Hook(layers[-1], _hook_inner, detach=False))
            hc_c.append(x.shape[1])

        ni = x.shape[1]
        # Upsample once more if the decoder has not reached the input resolution.
        if imsize != sfs_szs[0][-2:]:
            layers.append(PixelShuffle_ICNR(ni, **kwargs))
        if last_cross:
            layers.append(MergeLayer(dense=True))
            ni += in_channels(encoder)
            layers.append(res_block(ni, bottle=bottle, **kwargs))
        hc_c.append(ni)
        layers.append(Hcolumns(self.hc_hooks, hc_c))
        # Hcolumns output is sized as ni * len(hc_c) channels here.
        layers += [
            conv_layer(ni * len(hc_c),
                       n_classes,
                       ks=1,
                       use_activ=False,
                       **kwargs)
        ]
        if y_range is not None:
            layers.append(SigmoidRange(*y_range))
        super().__init__(*layers)