def rascpc(**kwargs):
    """Return a UnetGenerator with RASC_PC attention layers, Kaiming-initialised.

    ``kwargs`` is accepted but ignored (the sibling factories share this
    signature, presumably so they can be called interchangeably).
    """
    attention_opts = {"is_attention_layer": True, "attention_model": RASC_PC}
    # (4, 3) are UnetGenerator's first two positional args — presumably
    # input / output channel counts; confirm against UnetGenerator.
    net = UnetGenerator(4, 3, **attention_opts)
    net.apply(weights_init_kaiming)
    return net
def nlnet(**kwargs):
    """Return a UnetGenerator whose attention layers use NLWapper, Kaiming-initialised.

    ``kwargs`` is accepted but ignored.
    """
    attention_opts = {"is_attention_layer": True, "attention_model": NLWapper}
    # (4, 3): presumably input / output channel counts — confirm in UnetGenerator.
    net = UnetGenerator(4, 3, **attention_opts)
    net.apply(weights_init_kaiming)
    return net
def cbam(**kwargs):
    """Return a UnetGenerator whose attention layers use CBAMConnect, Kaiming-initialised.

    ``kwargs`` is accepted but ignored.
    """
    attention_opts = {"is_attention_layer": True, "attention_model": CBAMConnect}
    # (4, 3): presumably input / output channel counts — confirm in UnetGenerator.
    net = UnetGenerator(4, 3, **attention_opts)
    net.apply(weights_init_kaiming)
    return net
def rascv1(**kwargs):
    """Return a UnetGenerator with RASC attention layers, Kaiming-initialised.

    ``kwargs`` is accepted but ignored.

    The original author's sketch of the region-aware flow inside RASC
    (not verifiable from this factory alone):

        splicing region:   features -> GlobalAttentionModule -> CNN -> * SplicingSmoothMask ->
        mixed region:      features -> GlobalAttentionModule --^ (feeds into the splicing path)
        background region: features -> GlobalAttentionModule -> * ReversedSmoothMask ->
    """
    attention_opts = {"is_attention_layer": True, "attention_model": RASC}
    # (4, 3): presumably input / output channel counts — confirm in UnetGenerator.
    net = UnetGenerator(4, 3, **attention_opts)
    net.apply(weights_init_kaiming)
    return net
def rascv2(**kwargs):
    """Return a RASC-attention UnetGenerator built on MinimalUnetV2 blocks, Kaiming-initialised.

    ``kwargs`` is accepted but ignored.
    """
    generator_opts = {
        "is_attention_layer": True,
        "attention_model": RASC,
        "basicblock": MinimalUnetV2,
    }
    # (4, 3): presumably input / output channel counts — confirm in UnetGenerator.
    net = UnetGenerator(4, 3, **generator_opts)
    net.apply(weights_init_kaiming)
    return net
def pconv(**kwargs):
    """Return a UnetGenerator whose attention layers use PCBlock, Kaiming-initialised.

    ``kwargs`` is accepted but ignored.
    """
    # NOTE(review): this factory used to be commented "unet without mask", yet it
    # passes 4 as the first positional arg, like the other attention factories in
    # this file. The factories explicitly described as mask-free (`uno`,
    # `maskedurasc`) pass 3. Confirm which input layout is intended.
    attention_opts = {"is_attention_layer": True, "attention_model": PCBlock}
    net = UnetGenerator(4, 3, **attention_opts)
    net.apply(weights_init_kaiming)
    return net
def maskedurasc(**kwargs):
    """Return a MaskedURASC-attention UnetGenerator on MaskedMinimalUnetV2 blocks.

    Per the original author, this variant learns without a mask input and
    builds on the RASCV2 design. It passes 3 (not 4) as the first
    positional argument. The network is Kaiming-initialised before being
    returned; ``kwargs`` is accepted but ignored.
    """
    generator_opts = {
        "is_attention_layer": True,
        "attention_model": MaskedURASC,
        "basicblock": MaskedMinimalUnetV2,
    }
    net = UnetGenerator(3, 3, **generator_opts)
    net.apply(weights_init_kaiming)
    return net
def __init__(self, in_channels=3, depth=5, shared_depth=0, use_vm_decoder=False, blocks=1,
             out_channels_image=3, out_channels_mask=1, start_filters=32, residual=True, batch_norm=True,
             transpose=True, concat=True, transfer_data=True, long_skip=False, s2am='unet', use_coarser=True,
             no_stage2=False):
    """Build the encoder, the per-task decoders and the second-stage ``s2am`` module.

    Args:
        in_channels: channel count passed to the encoder.
        depth: encoder depth. Decoders get ``depth - shared_depth``, and feature
            widths are derived as ``start_filters * 2 ** (...)``.
        shared_depth: depth of the optional shared decoder. When non-zero,
            ``self.shared_decoder`` is built and ``self._forward`` is
            ``self.shared_forward``; otherwise ``self.unshared_forward`` is used.
        use_vm_decoder: if True, ``self.vm_decoder`` is built; otherwise it is None.
        blocks: an int, or a tuple indexed as encoder [0], image decoder [1],
            mask decoder [2], vm decoder [3], shared decoder [4]. Any non-tuple
            value is replicated into a 5-tuple.
        out_channels_image: output channels of the image decoder and the vm decoder.
        out_channels_mask: output channels of the mask decoder.
        start_filters: base feature width used in every width computation.
        residual, batch_norm: forwarded to the encoder and all decoders.
        transpose: forwarded to the decoders only.
        concat: forwarded to the decoders; forced to False when ``transfer_data``
            is False.
        transfer_data: stored on ``self``; False also disables ``concat``.
        long_skip, use_coarser, no_stage2: only stored on ``self`` here;
            presumably read by the forward methods — confirm.
        s2am: selects ``self.s2am``: 'unet', 'vm' or 'vms2am'.
    """
    super(UnetVMS2AMv4, self).__init__()
    self.transfer_data = transfer_data
    self.shared = shared_depth
    # Optimizer slots start empty; presumably filled by a separate setup
    # method of this class (not visible here).
    self.optimizer_encoder, self.optimizer_image, self.optimizer_vm = None, None, None
    self.optimizer_mask, self.optimizer_shared = None, None
    # Broadcast a single block count to all five sub-networks.
    if type(blocks) is not tuple:
        blocks = (blocks, blocks, blocks, blocks, blocks)
    # Without data transfer there is nothing to concatenate in the decoders.
    if not transfer_data:
        concat = False
    # NOTE(review): every sub-network receives both `batch_norm` and
    # norm=nn.InstanceNorm2d; how the two interact is decided inside
    # UnetEncoderD / UnetDecoderD, which are not visible here.
    self.encoder = UnetEncoderD(in_channels=in_channels, depth=depth, blocks=blocks[0], start_filters=start_filters,
                                residual=residual, batch_norm=batch_norm,norm=nn.InstanceNorm2d,act=F.leaky_relu)
    self.image_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - shared_depth - 1),
                                      out_channels=out_channels_image, depth=depth - shared_depth, blocks=blocks[1],
                                      residual=residual, batch_norm=batch_norm, transpose=transpose, concat=concat,
                                      norm=nn.InstanceNorm2d)
    self.mask_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - shared_depth - 1),
                                     out_channels=out_channels_mask, depth=depth - shared_depth, blocks=blocks[2],
                                     residual=residual, batch_norm=batch_norm, transpose=transpose, concat=concat,
                                     norm=nn.InstanceNorm2d)
    self.vm_decoder = None
    if use_vm_decoder:
        self.vm_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - shared_depth - 1),
                                       out_channels=out_channels_image, depth=depth - shared_depth, blocks=blocks[3],
                                       residual=residual, batch_norm=batch_norm, transpose=transpose, concat=concat,
                                       norm=nn.InstanceNorm2d)
    self.shared_decoder = None
    self.use_coarser = use_coarser
    self.long_skip = long_skip
    self.no_stage2 = no_stage2
    # Pick the forward implementation once, at construction time.
    self._forward = self.unshared_forward
    if self.shared != 0:
        self._forward = self.shared_forward
        # Maps start_filters * 2 ** (depth - 1) channels (presumably the deepest
        # encoder width) down to the width the task decoders expect.
        self.shared_decoder = UnetDecoderDatt(in_channels=start_filters * 2 ** (depth - 1),
                                              out_channels=start_filters * 2 ** (depth - shared_depth - 1),
                                              depth=shared_depth, blocks=blocks[4], residual=residual,
                                              batch_norm=batch_norm, transpose=transpose, concat=concat, is_final=False,
                                              norm=nn.InstanceNorm2d)
    # Second-stage module selection.
    # NOTE(review): assumed to run regardless of shared_depth; confirm
    # against the original indentation.
    # NOTE(review): any other `s2am` value leaves self.s2am unassigned, so a
    # typo presumably surfaces only as an AttributeError on first use.
    # Consider raising ValueError here.
    if s2am == 'unet':
        self.s2am = UnetGenerator(4,3,is_attention_layer=True,attention_model=RASC,basicblock=MinimalUnetV2)
    elif s2am == 'vm':
        self.s2am = VMSingle(4)
    elif s2am == 'vms2am':
        self.s2am = VMSingleS2AM(4,down=ResDownNew,up=ResUpNew)
def unet(**kwargs):
    """Return a Kaiming-initialised UnetGenerator(3, 3) with default options.

    No attention arguments are passed, so UnetGenerator's own defaults
    apply. ``kwargs`` is accepted but ignored.
    """
    net = UnetGenerator(3, 3)
    net.apply(weights_init_kaiming)
    return net
def unet72(**kwargs):
    """Return a plain (no attention arguments) UnetGenerator(4, 3) with ngf=72.

    ``ngf`` is presumably the base filter count — confirm in UnetGenerator.
    The network is Kaiming-initialised; ``kwargs`` is accepted but ignored.
    """
    generator_opts = {"ngf": 72}
    net = UnetGenerator(4, 3, **generator_opts)
    net.apply(weights_init_kaiming)
    return net
def uno(**kwargs):
    """Return a UnetGenerator(3, 3) with UNO attention layers, Kaiming-initialised.

    Described by the original author as a U-Net without a mask input (it
    passes 3 rather than 4 as the first positional argument). ``kwargs`` is
    accepted but ignored.
    """
    attention_opts = {"is_attention_layer": True, "attention_model": UNO}
    net = UnetGenerator(3, 3, **attention_opts)
    net.apply(weights_init_kaiming)
    return net