def __init__(self, layers=5):
    """Build the multi-branch network.

    Args:
        layers: number of branches to build. Must lie between 1 and the number
            of available branch classes (``Branch1`` .. ``Branch6``, i.e. 6).

    Raises:
        ValueError: if ``layers`` is outside the supported range.
    """
    super(Net, self).__init__()
    branch_classes = (Branch1, Branch2, Branch3, Branch4, Branch5, Branch6)
    # Reject counts the available branch classes cannot satisfy. Otherwise the
    # slice below would silently build fewer branches than there are ``scales``
    # and ``down2`` entries.
    if not 1 <= layers <= len(branch_classes):
        raise ValueError(
            'layers must be between 1 and {}, got {!r}'.format(
                len(branch_classes), layers))
    # Down2 stages: Down2(3, 64) first, then (layers - 2) x Down2(64, 64).
    # NOTE(review): presumably (in_channels, out_channels) with 2x downsampling
    # — confirm against Down2's definition.
    self._down2 = nn.ModuleList(
        [Down2(3, 64)] + [Down2(64, 64) for _ in range(layers - 2)])
    # Plain-list view with ``identity`` appended as the final step. A plain
    # list does not register modules; they are registered via ``self._down2``.
    self.down2 = list(self._down2) + [identity]
    # Branches: one instance of each of the first ``layers`` branch classes.
    self.branches = nn.ModuleList(f() for f in branch_classes[:layers])
    # One ScaleLayer per branch.
    self.scales = nn.ModuleList(ScaleLayer() for _ in range(layers))
    # Called after every submodule above has been registered.
    initConvParameters(self)
def __init__(self, deepFs, combine=None, cat=True, u2=identity, in_channels=64):
    """Build a branch: input conv, optional shallow path, deep path, optional combine stage.

    Args:
        deepFs: iterable of modules forming the deep path; they are wrapped in
            an ``nn.Sequential`` stored as ``self.deepF``.
        combine: optional iterable of arguments for ``namedSequential``
            (presumably ``(name, module)`` pairs, like those used for
            ``inputF`` below). When falsy, ``self.combineF`` is ``None``.
        cat: when true, ``CAT(128)`` is prepended to the deep path and
            ``self.shallowF`` is a chain of five ``CARB(64)`` blocks. When
            false, ``self.shallowF`` is ``None``.
        u2: stored unchanged as ``self.u2``; defaults to ``identity``.
        in_channels: input channel count of the first ``Conv3x3``.
    """
    # Zero-argument super() binds to the class that defines this method, so it
    # keeps working even if the module-level name ``Branch`` is later rebound
    # (e.g. by a second class of the same name). super(Branch, self) would
    # resolve that name at call time instead.
    super().__init__()
    self.inputF = namedSequential(
        ('conv_input', Conv3x3(in_channels, 64)),
        ('relu', nn.PReLU()),
    )
    if cat:
        # NOTE(review): CAT(128) presumably merges two 64-channel feature maps
        # (shallow + deep) — confirm against CAT's definition.
        deepFs = [CAT(128)] + list(deepFs)
        self.shallowF = nn.Sequential(*(CARB(64) for _ in range(5)))
    else:
        self.shallowF = None
    self.deepF = nn.Sequential(*deepFs)
    self.combineF = namedSequential(*combine) if combine else None
    self.u2 = u2
    # Called after every submodule above has been registered.
    initConvParameters(self)
def __init__(self, scaleLayers, strides, non_local=True):
    """Build a branch of per-stride blocks, an optional non-local stage and an upsampler.

    Args:
        scaleLayers: number of ``upsample_block(64, 256)`` stages chained in
            ``self.u``.
        strides: stride values. One ``CARB(64)`` (in ``convt_F``) and one
            ``Conv3x3(64, 64, stride=k)`` (in ``s_conv``) are built per entry.
            Any iterable, including a generator, is accepted.
        non_local: when true, ``self.non_local`` is a ``Nonlocal_CA`` block;
            otherwise it is ``identity``.
    """
    # Zero-argument super() binds to the defining class rather than to
    # whatever the global name ``Branch`` refers to at call time.
    super().__init__()
    # ``strides`` is iterated twice below. Materialize it once so a one-shot
    # iterator cannot leave ``s_conv`` empty while ``convt_F`` is populated.
    strides = tuple(strides)
    self.conv_input = Conv3x3(64, 64)
    self.relu = nn.PReLU()
    self.convt_F = nn.ModuleList(CARB(64) for _ in strides)
    # style encode
    self.s_conv = nn.ModuleList(Conv3x3(64, 64, stride=k) for k in strides)
    if non_local:
        self.non_local = Nonlocal_CA(in_feat=64, inter_feat=64 // 8, reduction=8,
                                     sub_sample=False, bn_layer=False)
    else:
        self.non_local = identity
    self.u = nn.Sequential(*(upsample_block(64, 256) for _ in range(scaleLayers)))
    # Final conv from 64 feature channels down to 3 (presumably image
    # channels — confirm Conv3x3's argument order is (in, out)).
    self.convt_shape1 = Conv3x3(64, 3)
    # Called after every submodule above has been registered.
    initConvParameters(self)