コード例 #1
0
ファイル: backbone_unet.py プロジェクト: vinthony/s2am
def rascpc(**kwargs):
    """Build a ``UnetGenerator`` whose attention layers use ``RASC_PC``.

    The generator is constructed with positional arguments ``(4, 3)``
    (presumably input/output channel counts -- confirm against
    ``UnetGenerator``). ``weights_init_kaiming`` is then applied via
    ``.apply``.

    Args:
        **kwargs: accepted for a uniform factory signature; ignored.

    Returns:
        The initialised generator.
    """
    attention_opts = dict(is_attention_layer=True, attention_model=RASC_PC)
    net = UnetGenerator(4, 3, **attention_opts)
    net.apply(weights_init_kaiming)
    return net
コード例 #2
0
ファイル: backbone_unet.py プロジェクト: vinthony/s2am
def nlnet(**kwargs):
    """Build a ``UnetGenerator`` with ``NLWapper`` as its attention model.

    Positional arguments ``(4, 3)`` are forwarded unchanged (presumably
    input/output channels -- confirm). The result is passed through
    ``.apply(weights_init_kaiming)`` before being returned.

    Args:
        **kwargs: ignored; present so all factories share one signature.
    """
    attention = NLWapper
    generator = UnetGenerator(4, 3, is_attention_layer=True, attention_model=attention)
    generator.apply(weights_init_kaiming)
    return generator
コード例 #3
0
ファイル: backbone_unet.py プロジェクト: vinthony/s2am
def cbam(**kwargs):
    """Build a ``UnetGenerator`` using ``CBAMConnect`` attention layers.

    Args:
        **kwargs: not used.

    Returns:
        A ``UnetGenerator(4, 3, ...)`` after ``weights_init_kaiming`` has
        been applied through ``.apply``.
    """
    unet_model = UnetGenerator(
        4, 3,
        is_attention_layer=True,
        attention_model=CBAMConnect,
    )
    unet_model.apply(weights_init_kaiming)
    return unet_model
コード例 #4
0
ファイル: backbone_unet.py プロジェクト: vinthony/s2am
def rascv1(**kwargs):
    """Build the ``RASC``-attention ``UnetGenerator`` (version 1).

    The original author's sketch of the region-aware data flow:

        splicing region:   features -> GlobalAttentionModule -> CNN -> * SplicingSmoothMask ->
        mixed region:      features -> GlobalAttentionModule  (arrow points up into the splicing path)
        background region: features -> GlobalAttentionModule -> * ReversedSmoothMask ->

    NOTE(review): this sketch presumably describes ``RASC`` itself, which
    is defined elsewhere -- confirm.

    Args:
        **kwargs: ignored.

    Returns:
        ``UnetGenerator(4, 3, ...)`` initialised with ``weights_init_kaiming``.
    """
    opts = {"is_attention_layer": True, "attention_model": RASC}
    built = UnetGenerator(4, 3, **opts)
    built.apply(weights_init_kaiming)
    return built
コード例 #5
0
def rascv2(**kwargs):
    """Build the ``RASC``-attention ``UnetGenerator`` on ``MinimalUnetV2`` blocks.

    This matches ``rascv1`` except that ``basicblock=MinimalUnetV2`` is
    passed as well.

    Args:
        **kwargs: ignored.

    Returns:
        The generator after ``.apply(weights_init_kaiming)``.
    """
    block_cls = MinimalUnetV2
    net = UnetGenerator(4, 3, is_attention_layer=True, attention_model=RASC, basicblock=block_cls)
    net.apply(weights_init_kaiming)
    return net
コード例 #6
0
ファイル: backbone_unet.py プロジェクト: vinthony/s2am
def pconv(**kwargs):
    """Build a ``UnetGenerator`` whose attention model is ``PCBlock``.

    The original note described this as "unet without mask."
    NOTE(review): it is constructed with first positional argument 4, the
    same as the mask-using variants in this file (``maskedurasc``/``uno``
    use 3). Confirm which description is accurate.

    Args:
        **kwargs: ignored.

    Returns:
        The generator after ``weights_init_kaiming`` has been applied.
    """
    settings = dict(is_attention_layer=True, attention_model=PCBlock)
    pc_net = UnetGenerator(4, 3, **settings)
    pc_net.apply(weights_init_kaiming)
    return pc_net
コード例 #7
0
ファイル: backbone_unet.py プロジェクト: vinthony/s2am
def maskedurasc(**kwargs):
    """Build the mask-free variant derived from the RASC-v2 setup.

    Per the original note, this learns without a mask input, building on
    RASCV2. It uses ``MaskedURASC`` attention with ``MaskedMinimalUnetV2``
    basic blocks, and first positional argument 3 instead of 4.

    Args:
        **kwargs: ignored.

    Returns:
        The generator after ``.apply(weights_init_kaiming)``.
    """
    generator = UnetGenerator(
        3,
        3,
        basicblock=MaskedMinimalUnetV2,
        attention_model=MaskedURASC,
        is_attention_layer=True,
    )
    generator.apply(weights_init_kaiming)
    return generator
コード例 #8
0
    def __init__(self, in_channels=3, depth=5, shared_depth=0, use_vm_decoder=False, blocks=1,
                 out_channels_image=3, out_channels_mask=1, start_filters=32, residual=True, batch_norm=True,
                 transpose=True, concat=True, transfer_data=True, long_skip=False, s2am='unet', use_coarser=True,
                 no_stage2=False):
        """Construct the encoder, the output decoders and the stage-2 ``s2am`` module.

        Sub-modules are created in a fixed order: encoder, image decoder,
        mask decoder, optional vm decoder, optional shared decoder, then s2am.
        This matters for random initialisation and for registration order.

        Args:
            in_channels: channel count passed to ``UnetEncoderD``.
            depth: encoder depth; decoders get ``depth - shared_depth`` levels.
            shared_depth: when non-zero, a ``UnetDecoderDatt`` shared decoder
                is built and ``shared_forward`` is selected; otherwise
                ``unshared_forward`` is used.
            use_vm_decoder: when true, also build ``self.vm_decoder``.
            blocks: a tuple giving the ``blocks`` value for (encoder, image
                decoder, mask decoder, vm decoder, shared decoder), in that
                order. Any non-tuple value is repeated for all five.
            out_channels_image: output channels of the image and vm decoders.
            out_channels_mask: output channels of the mask decoder.
            start_filters: base filter count; decoder input widths derive from it.
            residual, batch_norm, transpose: forwarded to encoder/decoders.
            concat: forwarded to the decoders; forced to ``False`` when
                ``transfer_data`` is falsy.
            transfer_data: stored on the instance; also gates ``concat``.
            long_skip, use_coarser, no_stage2: stored on the instance only.
            s2am: ``'unet'``, ``'vm'`` or ``'vms2am'`` selects the stage-2 module.
        """
        super(UnetVMS2AMv4, self).__init__()
        self.transfer_data = transfer_data
        self.shared = shared_depth

        # Optimizer slots start empty (NOTE(review): presumably filled in elsewhere).
        self.optimizer_encoder = None
        self.optimizer_image = None
        self.optimizer_vm = None
        self.optimizer_mask = None
        self.optimizer_shared = None

        # A scalar block count applies to all five sub-networks.
        if type(blocks) is not tuple:
            blocks = (blocks,) * 5
        # Without transferred encoder data there is nothing to concatenate.
        concat = concat if transfer_data else False

        self.encoder = UnetEncoderD(in_channels=in_channels, depth=depth, blocks=blocks[0],
                                    start_filters=start_filters, residual=residual,
                                    batch_norm=batch_norm, norm=nn.InstanceNorm2d, act=F.leaky_relu)

        # Settings shared by every decoder built below.
        branch_width = start_filters * 2 ** (depth - shared_depth - 1)
        branch_depth = depth - shared_depth
        decoder_kwargs = dict(residual=residual, batch_norm=batch_norm, transpose=transpose,
                              concat=concat, norm=nn.InstanceNorm2d)

        self.image_decoder = UnetDecoderD(in_channels=branch_width, out_channels=out_channels_image,
                                          depth=branch_depth, blocks=blocks[1], **decoder_kwargs)
        self.mask_decoder = UnetDecoderD(in_channels=branch_width, out_channels=out_channels_mask,
                                         depth=branch_depth, blocks=blocks[2], **decoder_kwargs)

        self.vm_decoder = None
        if use_vm_decoder:
            self.vm_decoder = UnetDecoderD(in_channels=branch_width, out_channels=out_channels_image,
                                           depth=branch_depth, blocks=blocks[3], **decoder_kwargs)

        self.shared_decoder = None
        self.use_coarser = use_coarser
        self.long_skip = long_skip
        self.no_stage2 = no_stage2

        # Pick the forward implementation; a shared decoder switches to shared_forward.
        self._forward = self.unshared_forward
        if self.shared != 0:
            self._forward = self.shared_forward
            self.shared_decoder = UnetDecoderDatt(in_channels=start_filters * 2 ** (depth - 1),
                                                  out_channels=branch_width, depth=shared_depth,
                                                  blocks=blocks[4], is_final=False, **decoder_kwargs)

        # Stage-2 module selected by name.
        # NOTE(review): any other ``s2am`` value leaves ``self.s2am`` unset; confirm that is intended.
        if s2am == 'unet':
            self.s2am = UnetGenerator(4, 3, is_attention_layer=True, attention_model=RASC,
                                      basicblock=MinimalUnetV2)
        elif s2am == 'vm':
            self.s2am = VMSingle(4)
        elif s2am == 'vms2am':
            self.s2am = VMSingleS2AM(4, down=ResDownNew, up=ResUpNew)
コード例 #9
0
def unet(**kwargs):
    """Build a plain ``UnetGenerator(3, 3)`` with no attention arguments.

    Args:
        **kwargs: ignored.

    Returns:
        The generator after ``.apply(weights_init_kaiming)``.
    """
    baseline = UnetGenerator(3, 3)
    baseline.apply(weights_init_kaiming)
    return baseline
コード例 #10
0
ファイル: backbone_unet.py プロジェクト: vinthony/s2am
def unet72(**kwargs):
    """Build the original (attention-free) U-Net with ``ngf=72``.

    It is constructed as ``UnetGenerator(4, 3, ngf=72)``.
    NOTE(review): ``ngf`` is presumably the base filter width -- confirm.

    Args:
        **kwargs: ignored.

    Returns:
        The generator after ``weights_init_kaiming`` has been applied.
    """
    width = 72
    plain_net = UnetGenerator(4, 3, ngf=width)
    plain_net.apply(weights_init_kaiming)
    return plain_net
コード例 #11
0
ファイル: backbone_unet.py プロジェクト: vinthony/s2am
def uno(**kwargs):
    """Build a U-Net without a mask input, using ``UNO`` attention layers.

    It is constructed as ``UnetGenerator(3, 3, ...)``.

    Args:
        **kwargs: ignored.

    Returns:
        The generator after ``.apply(weights_init_kaiming)``.
    """
    uno_opts = {"is_attention_layer": True, "attention_model": UNO}
    result = UnetGenerator(3, 3, **uno_opts)
    result.apply(weights_init_kaiming)
    return result