def __init__(self, channel=128, reduction=2, ksize=3, scale=3, stride=1,
             softmax_scale=10, average=True, conv=common.default_conv):
    """Build the non-local attention embedding branches.

    Two 1x1 blocks (``conv_match1``/``conv_match2``) embed the input into
    ``channel // reduction`` features for similarity matching, while
    ``conv_assembly`` keeps the full channel count for value assembly.

    NOTE(review): ``ksize``, ``scale``, ``stride``, ``softmax_scale`` and
    ``average`` are accepted but never stored here — presumably kept for
    signature parity with ``CrossScaleAttention``; confirm before removing.
    """
    super(NonLocalAttention, self).__init__()
    reduced = channel // reduction
    # 1x1 matching embeddings: no batch-norm, PReLU activation.
    self.conv_match1 = common.BasicBlock(conv, channel, reduced, 1,
                                         bn=False, act=nn.PReLU())
    self.conv_match2 = common.BasicBlock(conv, channel, reduced, 1,
                                         bn=False, act=nn.PReLU())
    # Full-channel embedding used to assemble the attended output.
    self.conv_assembly = common.BasicBlock(conv, channel, channel, 1,
                                           bn=False, act=nn.PReLU())
def __init__(self, depth=12, rgb_range=255, n_colors=3, n_feats=64, scale=4,
             conv=common.default_conv, **kwargs):
    """Assemble the CSNLN super-resolution network.

    Pipeline: mean-shift normalisation -> two-conv head -> ``depth``
    recurrent applications of the Self-Exemplar Mining (SEM) cell -> a
    single conv tail that fuses the ``depth`` stacked feature maps back
    down to ``n_colors`` output channels.
    """
    super(CSNLN_Model, self).__init__()
    self.depth = depth
    kernel_size = 3

    # RGB mean/std used to normalise inputs and denormalise outputs.
    rgb_mean = (0.4488, 0.4371, 0.4040)
    rgb_std = (1.0, 1.0, 1.0)
    self.sub_mean = common.MeanShift(rgb_range, rgb_mean, rgb_std)

    # Head: lift n_colors -> n_feats, then refine once more at n_feats.
    head_blocks = (
        common.BasicBlock(conv, n_colors, n_feats, kernel_size,
                          stride=1, bias=True, bn=False, act=nn.PReLU()),
        common.BasicBlock(conv, n_feats, n_feats, kernel_size,
                          stride=1, bias=True, bn=False, act=nn.PReLU()),
    )

    # Self-Exemplar Mining cell, applied recurrently ``depth`` times.
    self.SEM = RecurrentProjection(n_feats, scale=scale)

    # Tail: fuse the concatenated per-step features back to RGB.
    tail_blocks = (
        nn.Conv2d(n_feats * self.depth, n_colors, kernel_size,
                  padding=kernel_size // 2),
    )

    self.add_mean = common.MeanShift(rgb_range, rgb_mean, rgb_std, 1)
    self.head = nn.Sequential(*head_blocks)
    self.tail = nn.Sequential(*tail_blocks)
def basic_block(in_channels, out_channels, act):
    """Factory for a bias-enabled, batch-norm-free conv block.

    NOTE(review): relies on free variables ``conv`` and ``kernel_size`` —
    presumably bound in an enclosing scope not visible here; confirm at
    the call site.
    """
    block = common.BasicBlock(conv, in_channels, out_channels, kernel_size,
                              bias=True, bn=False, act=act)
    return block
def __init__(self, channel=128, reduction=2, ksize=3, scale=3, stride=1,
             softmax_scale=10, average=True, conv=common.default_conv):
    """Configure the cross-scale attention module.

    Args:
        channel: number of input/output feature channels.
        reduction: channel reduction factor for the matching embeddings.
        ksize: patch size, stored for use by the forward pass.
        scale: cross-scale ratio, stored for use by the forward pass.
        stride: patch-extraction stride, stored for the forward pass.
        softmax_scale: temperature applied to similarity scores.
        average: flag stored for the forward pass (patch overlap handling).
        conv: conv-layer constructor used by the embedding blocks.
    """
    super(CrossScaleAttention, self).__init__()
    self.ksize = ksize
    self.stride = stride
    self.softmax_scale = softmax_scale
    self.scale = scale
    self.average = average

    # Small epsilon guarding against division by zero when normalising;
    # registered as a buffer so it follows the module across .to()/.cuda().
    # torch.tensor replaces the legacy torch.FloatTensor constructor while
    # producing an identical float32 CPU tensor.
    self.register_buffer('escape_NaN',
                         torch.tensor([1e-4], dtype=torch.float32))

    reduced = channel // reduction
    # 1x1 matching embeddings (reduced channels) and assembly (full width).
    self.conv_match_1 = common.BasicBlock(conv, channel, reduced, 1,
                                          bn=False, act=nn.PReLU())
    self.conv_match_2 = common.BasicBlock(conv, channel, reduced, 1,
                                          bn=False, act=nn.PReLU())
    self.conv_assembly = common.BasicBlock(conv, channel, channel, 1,
                                           bn=False, act=nn.PReLU())
def __init__(self, in_channel, kernel_size=3, scale=2, conv=common.default_conv):
    """Recurrent up/down projection cell used by the SEM stage.

    For scale 2 or 3 a single strided-conv / transposed-conv pair performs
    the down/up sampling. For scale 4 the projection is factored into two
    x2 stages (hence the (6, 2, 2) geometry shared with scale 2) plus
    dedicated x4 sampler branches built in the ``scale == 4`` section.

    Args:
        in_channel: number of feature channels (preserved throughout).
        kernel_size: kernel size of the post-projection conv block.
        scale: upscaling factor; must be 2, 3 or 4.
        conv: conv-layer constructor used by ``post_conv``.

    Raises:
        ValueError: if ``scale`` is not one of 2, 3 or 4.
    """
    super(RecurrentProjection, self).__init__()
    self.scale = scale

    # (kernel, stride, padding) of the strided sampling convs per scale.
    geometry = {2: (6, 2, 2), 3: (9, 3, 3), 4: (6, 2, 2)}
    if scale not in geometry:
        # Fail with a clear message instead of an opaque KeyError.
        raise ValueError(
            f'unsupported scale {scale!r}; expected one of 2, 3 or 4')
    stride_conv_ksize, stride, padding = geometry[scale]

    self.multi_source_projection = MultisourceProjection(
        in_channel, kernel_size=kernel_size, scale=scale, conv=conv)
    self.down_sample_1 = nn.Sequential(
        nn.Conv2d(in_channel, in_channel, stride_conv_ksize,
                  stride=stride, padding=padding),
        nn.PReLU(),
    )
    # For scale 4 the second down-sampler is superseded by the x4
    # branches below, so it is only built for scales 2 and 3.
    if scale != 4:
        self.down_sample_2 = nn.Sequential(
            nn.Conv2d(in_channel, in_channel, stride_conv_ksize,
                      stride=stride, padding=padding),
            nn.PReLU(),
        )
    self.error_encode = nn.Sequential(
        nn.ConvTranspose2d(in_channel, in_channel, stride_conv_ksize,
                           stride=stride, padding=padding),
        nn.PReLU(),
    )
    self.post_conv = common.BasicBlock(conv, in_channel, in_channel,
                                       kernel_size, stride=1, bias=True,
                                       act=nn.PReLU())
    if scale == 4:
        self.multi_source_projection_2 = MultisourceProjection(
            in_channel, kernel_size=kernel_size, scale=scale, conv=conv)
        # x4 samplers: 8-tap conv/deconv with stride 4, padding 2.
        self.down_sample_3 = nn.Sequential(
            nn.Conv2d(in_channel, in_channel, 8, stride=4, padding=2),
            nn.PReLU(),
        )
        self.down_sample_4 = nn.Sequential(
            nn.Conv2d(in_channel, in_channel, 8, stride=4, padding=2),
            nn.PReLU(),
        )
        self.error_encode_2 = nn.Sequential(
            nn.ConvTranspose2d(in_channel, in_channel, 8,
                               stride=4, padding=2),
            nn.PReLU(),
        )