Example #1
    def __init__(self,
                 encoder: nn.Module,
                 n_classes: int,
                 final_bias: float = 0.,
                 chs: int = 256,
                 n_anchors: int = 9,
                 flatten: bool = True):
        super().__init__()
        self.n_classes, self.flatten = n_classes, flatten
        # Probe the encoder with a dummy 256x256 input to find the layers
        # where the feature-map size changes, and hook their outputs.
        imsize = (256, 256)
        sfs_szs = model_sizes(encoder, size=imsize)
        sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs])
        self.encoder = encoder
        # Top of the feature pyramid: P5 from C5 via a 1x1 lateral conv,
        # P6 from C5 via a stride-2 conv, P7 from P6 via ReLU + stride-2 conv.
        self.c5top5 = conv2d(sfs_szs[-1][1], chs, ks=1, bias=True)
        self.c5top6 = conv2d(sfs_szs[-1][1], chs, stride=2, bias=True)
        self.p6top7 = nn.Sequential(nn.ReLU(),
                                    conv2d(chs, chs, stride=2, bias=True))
        # Lateral connections: upsample and merge with the next two
        # shallower encoder stages (C4, then C3).
        self.merges = nn.ModuleList([
            LateralUpsampleMerge(chs, sfs_szs[idx][1], hook)
            for idx, hook in zip(sfs_idxs[-2:-4:-1], self.sfs[-2:-4:-1])
        ])
        # 3x3 convs that smooth each merged pyramid level.
        self.smoothers = nn.ModuleList(
            [conv2d(chs, chs, 3, bias=True) for _ in range(3)])
        # Shared subnets applied at every pyramid level: one for
        # classification, one for box regression (4 coordinates per anchor).
        self.classifier = self._head_subnet(n_classes,
                                            n_anchors,
                                            final_bias,
                                            chs=chs)
        self.box_regressor = self._head_subnet(4, n_anchors, 0., chs=chs)
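
For context, here is a minimal instantiation sketch, assuming the __init__ above belongs to a class named RetinaNet and that fastai v1's create_body is available; the class name, backbone choice, and final_bias value are assumptions, not part of the example:

# Hypothetical usage sketch (fastai v1); RetinaNet is assumed to be the
# class that owns the __init__ above.
from fastai.vision import create_body, models

encoder = create_body(models.resnet50, cut=-2)  # backbone without its head
model = RetinaNet(encoder, n_classes=21, final_bias=-4., chs=256, n_anchors=9)
# A final_bias around -4 matches the usual focal-loss prior, so initial
# classification scores start near a small probability.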
Example #2
    def __init__(self, encoder: nn.Module, n_classes: int,
                 y_range: Optional[Tuple[float, float]] = None, skip_connections=True,
                 **kwargs):

        # Fold the stem (first three layers) into a single module at index 2,
        # then drop the two originals so the stem counts as one encoder child.
        encoder[2] = encoder[0:3]
        encoder = nn.Sequential(*list(encoder.children())[2:])

        attended_layers = []
        filters = []

        # Probe the encoder with a dummy 256x256 input to find the layers
        # where the feature-map size changes, and hook their outputs.
        imsize = (256, 256)
        sfs_szs = model_sizes(encoder, size=imsize)
        sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs])
        sfs_idxs = sfs_idxs[:-1]

        # Record the hooked encoder stages (shallow to deep) plus the final
        # encoder layer, along with their channel counts.
        attended_layers.extend([encoder[ind] for ind in sfs_idxs[::-1]])
        attended_layers.append(encoder[-1])

        filters.extend([sfs_szs[ind][1] for ind in sfs_idxs[::-1]])

        # Dummy forward pass so tensor shapes can be tracked while the
        # decoder is assembled.
        x = dummy_eval(encoder, imsize).detach()

        ni = sfs_szs[-1][1]
        filters.append(ni)

        # Bottleneck: double the channels, then bring them back down.
        middle_conv_enc = conv_layer(ni, ni * 2, **kwargs).eval()
        middle_conv_dec = conv_layer(ni * 2, ni, **kwargs).eval()

        x = middle_conv_enc(x)
        x = middle_conv_dec(x)

        layers = list(encoder)
        layers += [batchnorm_2d(ni), nn.ReLU(), middle_conv_enc, middle_conv_dec]

        attended_layers.append(middle_conv_enc)
        attended_layers.append(middle_conv_dec)

        filters.extend([ni * 2, ni * 2, ni])

        # Decoder: one up-block per hooked encoder stage, deepest first.
        for i, idx in enumerate(sfs_idxs):
            up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
            if skip_connections:
                # final_div=True (keep the full channel count) only on the
                # last block.
                final = i == len(sfs_idxs) - 1
                unet_block = UnetBlock(up_in_c, x_in_c, self.sfs[i],
                                       final_div=final, **kwargs).eval()
            else:
                unet_block = UnetBlockWithoutSkipConnection(up_in_c, **kwargs).eval()

            layers.append(unet_block)
            x = unet_block(x)

            attended_layers.append(layers[-1])
            filters.append(x_in_c)  # input channel count for this stage's attention block

        filters = filters[:-1]

        ni = x.shape[1]

        # Final up-block: the in-channel argument (10) is a placeholder,
        # since the block's conv layers are replaced just below.
        unet_block_last = UnetBlockWithoutSkipConnection(10, blur=False,
                                                         self_attention=False,
                                                         **kwargs)

        # Upsample back to the input resolution only if the first feature
        # map is smaller than the input.
        if imsize != sfs_szs[0][-2:]:
            unet_block_last.shuf = PixelShuffle_ICNR(ni, **kwargs)
        else:
            unet_block_last.shuf = nn.Identity()

        # Collapse to n_classes channels with a 1x1 conv, no activation.
        unet_block_last.conv1 = conv_layer(ni, n_classes, ks=1, use_activ=False, **kwargs)
        unet_block_last.conv2 = nn.Identity()
        unet_block_last.relu = nn.Identity()

        layers.append(unet_block_last)
        attended_layers.append(unet_block_last)
        filters.extend([ni, n_classes])

        if y_range is not None: layers.append(SigmoidRange(*y_range))

        super().__init__(*layers)
        self.attended_layers = attended_layers
        self.filter = filters
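
Again a hedged sketch rather than part of the example: assuming the __init__ above belongs to a class named AttentionUnet whose base class accepts the layer list via super().__init__(*layers) (as fastai v1's SequentialEx-style models do), it could be driven like this; the class name and backbone are assumptions:

# Hypothetical usage sketch (fastai v1); AttentionUnet is assumed to be the
# class that owns the __init__ above.
from fastai.vision import create_body, models

encoder = create_body(models.resnet34)  # backbone without its head
unet = AttentionUnet(encoder, n_classes=2, skip_connections=True)
print(unet.filter)  # channel counts recorded for the attention blocks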