def __init__(self, layer, buffer, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1,
             bias=True, padding=0, norm='in', activation='elu', pad_type='zero'):
    super(RConv, self).__init__()
    self.buffer = buffer
    self.layer = layer
    self.conv_1 = conv_block(in_channels, out_channels, kernel_size, stride, dilation, groups, bias,
                             padding, "none", "none", pad_type)
    self.conv_2 = conv_layer(out_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1)
    self.fusion = conv_block(out_channels * 2, out_channels, kernel_size=1, norm=norm, activation=activation)
    self.bn_act = nn.Sequential(nn.InstanceNorm2d(out_channels), nn.ELU())
def __init__(self, layer, buffer, in_channels, out_channels, kernel_size=3, stride=2, dilation=1, groups=1,
             bias=True, padding=1, output_padding=1, norm='in', activation='elu', pad_type='zero'):
    """
    :param layer: layer identifier within the stack
    :param buffer: buffered feature map from the previous pass
    :param in_channels: input channels
    :param out_channels: output channels
    :param kernel_size: kernel size of the transposed convolution
    :param stride: stride of the transposed convolution
    :param dilation: dilation of the transposed convolution
    :param groups: groups of the transposed convolution
    :param bias: whether to add a learnable bias
    :param padding: padding of the transposed convolution
    :param output_padding: output padding of the transposed convolution
    :param norm: normalization layer type
    :param activation: activation layer type
    :param pad_type: padding type of the refinement convolution
    """
    super(RDeConv, self).__init__()
    self.buffer = buffer
    self.layer = layer
    self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                                     padding=padding, output_padding=output_padding, dilation=dilation,
                                     groups=groups, bias=bias)
    self.act = _activation(activation)
    self.norm = _norm(norm, out_channels)
    self.conv = conv_block(out_channels, out_channels, kernel_size, stride=1, bias=bias, padding=1,
                           pad_type=pad_type, norm=norm, activation=activation)
    self.fusion = conv_block(out_channels * 2, out_channels, kernel_size=1, norm=norm, activation=activation)
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, output_padding=0,
             bias=True, pad_type='zero', norm='none', activation='relu'):
    super(upconv_block, self).__init__()
    # Wire the constructor arguments through instead of the previously hardcoded
    # (4, 2, 1) / 'relu' / 'in' values, and accept output_padding, which the
    # callers below pass explicitly.
    self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding,
                                     output_padding=output_padding, bias=bias)
    self.act = _activation(activation)
    self.norm = _norm(norm, out_channels)
    # Refinement convolution: stride 1 with "same" padding so the upsampled size is kept.
    self.conv = conv_block(out_channels, out_channels, kernel_size, stride=1, bias=bias,
                           padding=kernel_size // 2, pad_type=pad_type, norm=norm, activation=activation)
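# Sanity check of the transposed-convolution output size, using PyTorch's formula
# H_out = (H_in - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1.
# Both the previously hardcoded (k=4, s=2, p=1) configuration and the
# (k=3, s=2, p=1, output_padding=1) configuration used by the callers below
# double the spatial resolution (_deconv_out_size is an illustrative helper, not repo code):
def _deconv_out_size(h_in, kernel, stride, padding, output_padding=0, dilation=1):
    return (h_in - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1

assert _deconv_out_size(128, kernel=4, stride=2, padding=1) == 256
assert _deconv_out_size(128, kernel=3, stride=2, padding=1, output_padding=1) == 256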
def __init__(self, nf, res_blocks_num=8, dmfb_blocks_num=4, attention_mode="self"):
    """
    :param nf: number of channels
    :param res_blocks_num: number of residual blocks
    :param dmfb_blocks_num: number of DMFB blocks
    :param attention_mode: attention mode ("context" or "self"); only "self" is wired up here
    """
    super(FineBottleneck, self).__init__()
    self.attention = SelfAttention(nf, k=8)
    res_seq = []
    for _ in range(res_blocks_num):
        block = B.ResConv(nf, kernel_size=3, dilation=2, padding=2)  # [192, 64, 64]
        res_seq.append(block)
    self.res_seq = nn.Sequential(*res_seq)
    dmfb_seq = []
    for _ in range(dmfb_blocks_num):
        block = DMFB(nf)  # [192, 64, 64]
        dmfb_seq.append(block)
    self.dmfb_seq = nn.Sequential(*dmfb_seq)
    self.out = nn.Sequential(
        conv_block(3 * nf, nf, kernel_size=1, stride=1, padding=0, norm="in", activation="elu", pad_type="zero"),
        conv_block(nf, nf, kernel_size=3, stride=1, padding=1, norm="in", activation="elu", pad_type="zero"))
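# The 3 * nf input width of self.out implies the forward pass (not shown here)
# concatenates the three branch outputs channel-wise. A sketch, not the repo's code:
#     x = self.out(torch.cat([self.attention(x0), self.res_seq(x0), self.dmfb_seq(x0)], dim=1))
#     # [3 * nf, 64, 64] -> [nf, 64, 64]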
def __init__(self, in_channels_r, out_channels_r, in_channels_u, out_channels, kernel_size_in, kernel_size_out,
             up_stride_in, stride_out, up_padding_in, padding_out, output_padding=0,
             activation_in='lrelu', activation_out='lrelu', norm_in='bn', norm_out='none'):
    """
    u {l-1} and r {l} new must have equal spatial sizes so they can be concatenated.

    :param in_channels_r: in channels for r {l}
    :param out_channels_r: out channels for r {l} new (output of the transposed convolution)
    :param in_channels_u: in channels for u {l-1}
    :param out_channels: out channels for (u {l-1}, r {l}) -> r {l-1}
    :param kernel_size_in: kernel size for r {l} -> r {l} new
    :param kernel_size_out: kernel size for (u {l-1}, r {l}) -> r {l-1}
    :param up_stride_in: stride for transposed convolution r {l} -> r {l} new
    :param stride_out: stride for (u {l-1}, r {l}) -> r {l-1}
    :param up_padding_in: padding for transposed convolution r {l} -> r {l} new
    :param padding_out: padding for (u {l-1}, r {l}) -> r {l-1}
    :param output_padding: output padding for transposed convolution r {l} -> r {l} new
    :param activation_in: activation layer for r {l} -> r {l} new
    :param activation_out: activation layer for (u {l-1}, r {l}) -> r {l-1}
    :param norm_in: normalization layer for r {l} -> r {l} new
    :param norm_out: normalization layer for (u {l-1}, r {l}) -> r {l-1}
    """
    super(RecovecyBlock, self).__init__()
    self.in_upconv = upconv_block(
        in_channels=in_channels_r,
        out_channels=out_channels_r,
        kernel_size=kernel_size_in,
        stride=up_stride_in,
        padding=up_padding_in,
        output_padding=output_padding,
        norm=norm_in,
        activation=activation_in
    )
    self.out_conv = conv_block(
        in_channels=out_channels_r + in_channels_u,
        out_channels=out_channels,
        kernel_size=kernel_size_out,
        stride=stride_out,
        padding=padding_out,
        norm=norm_out,
        activation=activation_out
    )
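# A minimal forward sketch implied by the docstring (hypothetical; the actual
# forward of RecovecyBlock is defined elsewhere):
#     def forward(self, r_l, u_prev):
#         r_new = self.in_upconv(r_l)                               # r {l} -> r {l} new
#         return self.out_conv(torch.cat([u_prev, r_new], dim=1))   # -> r {l-1}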
def __init__(self, in_channels_m, out_channels_m, in_channels_e, out_channels_e, kernel_size, stride, padding,
             activation='relu', norm='none'):
    super(SimpleMADF, self).__init__()
    self.conv_m = conv_block(
        in_channels=in_channels_m,
        out_channels=out_channels_m,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        activation=activation,
        norm='none')
    self.conv_e = conv_block(
        in_channels=in_channels_e + in_channels_m,
        out_channels=out_channels_e,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        activation=activation,
        norm=norm)
def __init__(self, in_channels_1, in_channels_2, out_channels, kernel_size_1, kernel_size_2, stride_1,
             up_stride_2, padding_1, up_padding_2, output_padding=0, activation_in='relu',
             activation_out='lrelu', norm_in='bn', norm_out='none'):
    """
    1st convolution            - from f {l-1}{k-1}
    2nd transposed convolution - from f {l}{k}
    3rd convolution            - the one to be summed with the result of the f {l}{k} (2nd) convolution
    4th convolution            - the one to be multiplied elementwise with the result of the f {l}{k} (2nd) convolution

    :param in_channels_1: input channels of the 1st convolution layer
    :param in_channels_2: input channels of the 2nd transposed convolution layer
    :param out_channels: output channels of the 1st, 2nd, 3rd and 4th convolution layers
    :param kernel_size_1: kernel size of the 1st convolution layer
    :param kernel_size_2: kernel size of the 2nd transposed convolution layer
    :param stride_1: stride of the 1st convolution layer
    :param up_stride_2: stride of the 2nd transposed convolution layer
    :param padding_1: padding of the 1st convolution layer
    :param up_padding_2: padding of the 2nd transposed convolution layer
    :param output_padding: output padding of the 2nd transposed convolution layer
    :param activation_in: activation layer of the 1st convolution layer
    :param activation_out: output activation layer (applied by out_act)
    :param norm_in: normalization layer of the 2nd transposed convolution layer
    :param norm_out: output normalization layer (applied by out_norm)
    """
    super(RefinementBlock, self).__init__()
    self.conv_1 = conv_block(
        in_channels=in_channels_1,
        out_channels=out_channels,
        kernel_size=kernel_size_1,
        stride=stride_1,
        padding=padding_1,
        norm='none',
        activation=activation_in
    )
    self.upconv_2 = upconv_block(
        in_channels=in_channels_2,
        out_channels=out_channels,
        kernel_size=kernel_size_2,
        stride=up_stride_2,
        padding=up_padding_2,
        output_padding=output_padding,
        norm=norm_in,
        activation='none'
    )
    self.conv_3 = conv_block(
        in_channels=out_channels,
        out_channels=out_channels,
        kernel_size=3,
        stride=1,
        padding=1,
        norm='none',
        activation='none'
    )
    self.conv_4 = conv_block(
        in_channels=out_channels,
        out_channels=out_channels,
        kernel_size=3,
        stride=1,
        padding=1,
        norm='none',
        activation='none'
    )
    self.out_act = _activation(act_type=activation_out)
    self.out_norm = _norm(norm_type=norm_out, channels=out_channels)
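# One plausible wiring implied by the numbered docstring above (a guess, hedged;
# the real forward is not shown in this file):
#     def forward(self, f_prev, f_lk):
#         x = self.conv_1(f_prev)          # from f {l-1}{k-1}
#         y = self.upconv_2(f_lk)          # from f {l}{k}, upsampled to match x
#         s = self.conv_3(x) + y           # branch summed with the upsampled features
#         p = self.conv_4(x) * y           # branch multiplied elementwise with them
#         return self.out_act(self.out_norm(s + p))   # the final combination is an assumption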
def __init__(self, in_channels_m, out_channels_m, in_channels_e, out_channels_e, kernel_size_m, kernel_size_e,
             stride_m, stride_e, padding_m, padding_e, activation_m='relu', activation_e='relu',
             norm_m='none', norm_e='bn', device=torch.device('cpu')):
    """
    :param in_channels_m: input channels of the mask layer - m {l-1}
    :param out_channels_m: output channels of the mask layer - m {l}
    :param in_channels_e: input channels of the image layer - e {l-1}
    :param out_channels_e: output channels of the image layer - e {l}
    :param kernel_size_m: kernel size for the m {l-1} -> m {l} transformation
    :param kernel_size_e: kernel size for the e {l-1} -> e {l} transformation and for the kernels
                          predicted from m {l} for the e {l} convolution
    :param stride_m: stride for the m {l-1} -> m {l} transformation
    :param stride_e: stride for the e {l-1} -> e {l} transformation
    :param padding_m: padding for the m {l-1} -> m {l} transformation
    :param padding_e: padding for the e {l-1} -> e {l} transformation
    :param activation_m: activation for the m {l-1} -> m {l} transformation
    :param activation_e: activation for the e {l-1} -> e {l} transformation
    :param norm_m: normalization for the m {l-1} -> m {l} transformation
    :param norm_e: normalization for the e {l-1} -> e {l} transformation
    :param device: device on which the predicted kernels are assembled
    """
    super(MADF, self).__init__()
    self.in_channels_m = in_channels_m
    self.out_channels_m = out_channels_m
    self.in_channels_e = in_channels_e
    self.out_channels_e = out_channels_e
    self.kernel_size_e = kernel_size_e
    self.kernel_size_m = kernel_size_m
    self.padding_e = padding_e
    self.padding_m = padding_m
    self.stride_e = stride_e
    self.stride_m = stride_m
    self.activation_m = activation_m
    self.norm_m = norm_m
    self.norm_e = norm_e
    self.conv_m = conv_block(
        in_channels=in_channels_m,
        out_channels=out_channels_m,
        kernel_size=kernel_size_m,
        stride=stride_m,
        padding=padding_m,
        activation=activation_m,
        norm=norm_m)
    # Predicts one (in_channels_e x kernel_size_e x kernel_size_e) -> out_channels_e
    # convolution kernel per spatial location of the mask features.
    self.conv_filters = conv_block(
        in_channels=out_channels_m,
        out_channels=in_channels_e * out_channels_e * kernel_size_e * kernel_size_e,
        kernel_size=1,
        stride=1,
        padding=0,
        activation="none",
        norm='none')
    self.device = device
    # Stored as a module (the earlier duplicate string assignment to
    # self.activation_e was redundant and has been dropped).
    self.activation_e = _activation(activation_e)
    self.norm = _norm(norm_e, out_channels_e)
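# Sketch of how the per-pixel kernels predicted by conv_filters could be consumed.
# apply_dynamic_filters is a hypothetical helper, not the repo's forward; it assumes
# an out-channel-major filter layout and that the spatial grid of `filters` matches
# the unfolded grid of `e`.
import torch
import torch.nn.functional as F

def apply_dynamic_filters(e, filters, out_channels, kernel_size, stride=1, padding=0):
    # e:       [B, C_in, H, W]                    image features e {l-1}
    # filters: [B, C_in * C_out * k * k, H', W']  output of conv_filters
    b, c_in = e.shape[0], e.shape[1]
    k = kernel_size
    h_out, w_out = filters.shape[2], filters.shape[3]
    patches = F.unfold(e, k, padding=padding, stride=stride)        # [B, C_in*k*k, H'*W']
    f = filters.flatten(2).view(b, out_channels, c_in * k * k, -1)  # [B, C_out, C_in*k*k, H'*W']
    out = torch.einsum('bokl,bkl->bol', f, patches)                 # per-location matmul
    return out.reshape(b, out_channels, h_out, w_out)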
def __init__(self, in_nc=4, out_nc=3, res_blocks_num=8, c_num=48, dmfb_blocks_num=8, norm='in',
             activation='relu', device=None):
    super(FineGenerator, self).__init__()
    self.encoder = nn.ModuleList([
        # [4, 256, 256]
        conv_block(in_channels=in_nc, out_channels=c_num, kernel_size=5, stride=1, padding=2,
                   norm=norm, activation=activation),
        # [c_num, 256, 256] -> decoder
        conv_block(in_channels=c_num, out_channels=2 * c_num, kernel_size=3, stride=2, padding=1,
                   norm=norm, activation=activation),
        # [c_num * 2, 128, 128]
        conv_block(in_channels=2 * c_num, out_channels=2 * c_num, kernel_size=3, stride=1, padding=1,
                   norm=norm, activation=activation),
        # [c_num * 2, 128, 128] -> decoder
        conv_block(in_channels=2 * c_num, out_channels=4 * c_num, kernel_size=3, stride=2, padding=1,
                   norm=norm, activation=activation),
        # [c_num * 4, 64, 64]
        conv_block(in_channels=4 * c_num, out_channels=4 * c_num, kernel_size=3, stride=1, padding=1,
                   norm=norm, activation=activation),
        # [c_num * 4, 64, 64]
        conv_block(in_channels=4 * c_num, out_channels=4 * c_num, kernel_size=3, stride=1, padding=1,
                   norm=norm, activation=activation),
        # [c_num * 4, 64, 64]
    ])
    self.fine_bn = FineBottleneck(c_num=4 * c_num, dmfb_blocks_num=dmfb_blocks_num,
                                  res_blocks_num=res_blocks_num, device=device)
    self.decoder = nn.ModuleList([
        # [c_num * 4, 64, 64]
        conv_block(in_channels=c_num * 4, out_channels=c_num * 4, kernel_size=3, stride=1, padding=1,
                   norm=norm, activation=activation),  # norm/activation added for consistency with the siblings
        # [c_num * 4, 64, 64]
        conv_block(in_channels=c_num * 4, out_channels=c_num * 4, kernel_size=3, stride=1, padding=1,
                   norm=norm, activation=activation),
        # [c_num * 4, 64, 64]
        upconv_block(in_channels=c_num * 4, out_channels=c_num * 2, kernel_size=3, stride=2, output_padding=1,
                     padding=get_pad_tp(64, 64, [1, 1], [3, 3], [2, 2], [1, 1]),
                     norm=norm, activation=activation),
        # encoder 3 + skip [c_num * 2, 128, 128] -> decoder 4
        # [c_num * 4, 128, 128]
        conv_block(in_channels=c_num * 4, out_channels=c_num * 2, kernel_size=3, stride=1, padding=1,
                   norm=norm, activation=activation),
        # [c_num * 2, 128, 128]
        upconv_block(in_channels=c_num * 2, out_channels=c_num, kernel_size=3, stride=2, output_padding=1,
                     padding=get_pad_tp(128, 128, [1, 1], [3, 3], [2, 2], [1, 1]),
                     norm=norm, activation=activation),
        # encoder 1 + skip [c_num, 256, 256] -> decoder 6
        # [c_num, 256, 256]
        conv_block(in_channels=c_num * 2, out_channels=c_num // 2, kernel_size=3, stride=1, padding=1,
                   norm=norm, activation=activation),
        # [c_num // 2, 256, 256]
        conv_block(in_channels=c_num // 2, out_channels=out_nc, kernel_size=3, stride=1, padding=1,
                   norm='none', activation='tanh'),
        # [out_nc, 256, 256]
    ])
def __init__(self, c_num, res_blocks_num=8, dmfb_blocks_num=8, activation="relu", norm="in", device=None):
    super(FineBottleneck, self).__init__()
    dmfb_blocks = []
    for _ in range(dmfb_blocks_num):
        # [c_num, 64, 64]
        dmfb_blocks.append(DMFB(in_channels=c_num))
    self.dmfb_seq = nn.Sequential(*dmfb_blocks)
    res_blocks = []
    for _ in range(res_blocks_num):
        # [c_num, 64, 64]
        block = B.ResConv(channels=c_num, kernel_size=3, dilation=2, padding=2)
        res_blocks.append(block)
    self.res_seq = nn.Sequential(*res_blocks)
    # Contextual attention
    self.contextual_attention = ContextualAttention(ksize=3, stride=1, rate=2, fuse_k=3, softmax_scale=10,
                                                    fuse=True, device=device)
    self.hypergraph = HypergraphConv(in_channels=c_num, out_channels=c_num, filters=256, edges=256,
                                     height=64, width=64)
    self.out_1 = nn.Sequential(
        conv_block(in_channels=3 * c_num, out_channels=c_num, kernel_size=1, stride=1, padding=0,
                   norm=norm, activation=activation, pad_type="zero"),
        conv_block(in_channels=c_num, out_channels=c_num, kernel_size=3, stride=1, padding=1,
                   norm=norm, activation=activation, pad_type="zero"))
    self.out = nn.Sequential(
        conv_block(in_channels=2 * c_num, out_channels=c_num, kernel_size=1, stride=1, padding=0,
                   norm=norm, activation=activation, pad_type="zero"),
        conv_block(in_channels=c_num, out_channels=c_num, kernel_size=3, stride=1, padding=1,
                   norm=norm, activation=activation, pad_type="zero"))
def __init__(self, in_nc=4, c_num=48, out_nc=3, res_blocks_num=8, norm="in", activation="relu"):
    super(CoarseGenerator, self).__init__()
    self.encoder_coarse = nn.ModuleList([
        # [4, 256, 256]
        conv_block(in_nc, c_num, kernel_size=5, stride=1, padding=2, norm=norm, activation=activation),
        # [c_num, 256, 256] -> decoder_6
        conv_block(c_num, c_num * 2, kernel_size=3, stride=2, padding=1, norm=norm, activation=activation),
        # [c_num * 2, 128, 128]
        conv_block(c_num * 2, c_num * 2, kernel_size=3, stride=1, padding=1, norm=norm, activation=activation),
        # [c_num * 2, 128, 128] -> decoder_4
        conv_block(c_num * 2, c_num * 4, kernel_size=3, stride=2, padding=1, norm=norm, activation=activation),
        # [c_num * 4, 64, 64]
        conv_block(c_num * 4, c_num * 4, kernel_size=3, stride=1, padding=1, norm=norm, activation=activation),
        # [c_num * 4, 64, 64]
        conv_block(c_num * 4, c_num * 4, kernel_size=3, stride=1, padding=1, norm=norm, activation=activation)
        # [c_num * 4, 64, 64]
    ])
    blocks = []
    for _ in range(res_blocks_num):
        block = B.ResConv(4 * c_num, kernel_size=3, dilation=2, padding=2)  # [192, 64, 64]
        blocks.append(block)
    self.coarse_bn = nn.Sequential(*blocks)
    self.decoder_coarse = nn.ModuleList([
        conv_block(c_num * 4, c_num * 4, kernel_size=3, stride=1, padding=1, norm=norm, activation=activation),
        # [c_num * 4, 64, 64]
        conv_block(c_num * 4, c_num * 4, kernel_size=3, stride=1, padding=1, norm=norm, activation=activation),
        # [c_num * 4, 64, 64]
        upconv_block(c_num * 4, c_num * 2, kernel_size=3, stride=2, output_padding=1,
                     padding=get_pad_tp(64, 64, [1, 1], [3, 3], [2, 2], [1, 1]),
                     norm=norm, activation=activation),
        # [c_num * 2, 128, 128] + skip [c_num * 2, 128, 128]
        conv_block(c_num * 4, c_num * 2, kernel_size=3, stride=1, padding=1, norm=norm, activation=activation),
        # [c_num * 2, 128, 128]
        upconv_block(c_num * 2, c_num, kernel_size=3, stride=2, output_padding=1,
                     padding=get_pad_tp(128, 128, [1, 1], [3, 3], [2, 2], [1, 1]),
                     norm=norm, activation=activation),
        # [c_num, 256, 256] + skip [c_num, 256, 256]
        conv_block(c_num * 2, c_num, kernel_size=3, stride=1, padding=1, norm=norm, activation=activation),
        # [c_num, 256, 256]
        conv_block(c_num, c_num // 2, kernel_size=3, stride=1, padding=1, norm=norm, activation=activation)
        # [c_num // 2, 256, 256]
    ])
    self.out_coarse = nn.Sequential(
        conv_block(c_num // 2, out_nc, 3, stride=1, padding=1, norm='none', activation='tanh'))
def __init__(self, in_nc=4, out_nc=3, nf=48, norm="in", activation="relu", res_blocks_num=8,
             dmfb_block_num=4, device=torch.device('cuda')):
    """
    :param in_nc: number of input channels
    :param out_nc: number of output channels
    :param nf: number of intermediate features
    :param res_blocks_num: number of residual blocks in the fine and coarse routes
    :param dmfb_block_num: number of DMFB blocks in the fine route
    """
    super(InpaintingGenerator, self).__init__()
    # [4, 256, 256]
    self.encoder_coarse = nn.ModuleList([
        # [48, 256, 256] -> decoder_6
        conv_block(in_nc, nf, kernel_size=5, stride=1, padding=2, norm=norm, activation=activation),
        # [96, 128, 128]
        conv_block(nf, nf * 2, kernel_size=3, stride=2, padding=1, norm=norm, activation=activation),
        # [96, 128, 128] -> decoder_4
        conv_block(nf * 2, nf * 2, kernel_size=3, stride=1, padding=1, norm=norm, activation=activation),
        # [192, 64, 64]
        conv_block(nf * 2, nf * 4, kernel_size=3, stride=2, padding=1, norm=norm, activation=activation),
        # [192, 64, 64]
        conv_block(nf * 4, nf * 4, kernel_size=3, stride=1, padding=1, norm=norm, activation=activation),
        # [192, 64, 64]
        conv_block(nf * 4, nf * 4, kernel_size=3, stride=1, padding=1, norm=norm, activation=activation)
    ])
    self.encoder_fine = nn.ModuleList([
        # [48, 256, 256] -> decoder_6
        conv_block(in_nc, nf, kernel_size=5, stride=1, padding=2, norm=norm, activation=activation),
        # [96, 128, 128]
        conv_block(nf, nf * 2, kernel_size=3, stride=2, padding=1, norm=norm, activation=activation),
        # [96, 128, 128] -> decoder_4
        conv_block(nf * 2, nf * 2, kernel_size=3, stride=1, padding=1, norm=norm, activation=activation),
        # [192, 64, 64]
        conv_block(nf * 2, nf * 4, kernel_size=3, stride=2, padding=1, norm=norm, activation=activation),
        # [192, 64, 64]
        conv_block(nf * 4, nf * 4, kernel_size=3, stride=1, padding=1, norm=norm, activation=activation),
        # [192, 64, 64]
        conv_block(nf * 4, nf * 4, kernel_size=3, stride=1, padding=1, norm=norm, activation=activation)
    ])
    blocks = []
    for _ in range(res_blocks_num):
        block = B.ResConv(4 * nf, kernel_size=3, dilation=2, padding=2)  # [192, 64, 64]
        blocks.append(block)
    self.coarse = nn.Sequential(*blocks)
    self.fine = FineBottleneck(4 * nf, res_blocks_num, dmfb_block_num, attention_mode="self")
    self.decoder_coarse = nn.ModuleList([
        # [192, 64, 64]
        conv_block(nf * 4, nf * 4, kernel_size=3, stride=1, padding=1, norm=norm, activation=activation),
        # [192, 64, 64]
        conv_block(nf * 4, nf * 4, kernel_size=3, stride=1, padding=1, norm=norm, activation=activation),
        # [96, 128, 128]
        upconv_block(nf * 4, nf * 2, kernel_size=3, stride=2, output_padding=1,
                     padding=get_pad_tp(64, 64, [1, 1], [3, 3], [2, 2], [1, 1]),
                     norm=norm, activation=activation),
        # [96, 128, 128]
        conv_block(nf * 4, nf * 2, kernel_size=3, stride=1, padding=1, norm=norm, activation=activation),
        # [48, 256, 256]
        upconv_block(nf * 2, nf, kernel_size=3, stride=2, output_padding=1,
                     padding=get_pad_tp(128, 128, [1, 1], [3, 3], [2, 2], [1, 1]),
                     norm=norm, activation=activation),
        # [48, 256, 256]
        conv_block(nf * 2, nf, kernel_size=3, stride=1, padding=1, norm=norm, activation=activation),
        # [24, 256, 256]
        conv_block(nf, nf // 2, kernel_size=3, stride=1, padding=1, norm=norm, activation=activation)
    ])
    self.decoder_fine = nn.ModuleList([
        # [192, 64, 64]
        conv_block(nf * 4, nf * 4, kernel_size=3, stride=1, padding=1, norm=norm, activation=activation),
        # [192, 64, 64]
        conv_block(nf * 4, nf * 4, kernel_size=3, stride=1, padding=1, norm=norm, activation=activation),
        # [96, 128, 128]
        upconv_block(nf * 4, nf * 2, kernel_size=3, stride=2, output_padding=1,
                     padding=get_pad_tp(64, 64, [1, 1], [3, 3], [2, 2], [1, 1]),
                     norm=norm, activation=activation),
        # [96, 128, 128]
        conv_block(nf * 4, nf * 2, kernel_size=3, stride=1, padding=1, norm=norm, activation=activation),
        # [48, 256, 256]
        upconv_block(nf * 2, nf, kernel_size=3, stride=2, output_padding=1,
                     padding=get_pad_tp(128, 128, [1, 1], [3, 3], [2, 2], [1, 1]),
                     norm=norm, activation=activation),
        # [48, 256, 256]
        conv_block(nf * 2, nf, kernel_size=3, stride=1, padding=1, norm=norm, activation=activation),
        # [24, 256, 256]
        conv_block(nf, nf // 2, kernel_size=3, stride=1, padding=1, norm=norm, activation=activation)
    ])
    self.out_coarse = nn.Sequential(
        conv_block(nf // 2, out_nc, 3, stride=1, padding=1, norm='none', activation='tanh'))
    self.out_fine = nn.Sequential(
        conv_block(nf // 2, out_nc, 3, stride=1, padding=1, norm='none', activation='tanh'))
    self.device = device
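# The doubled input widths in both decoders (nf * 4 into the conv after the first
# upconv_block, nf * 2 into the conv after the second) match the encoder stages
# annotated "-> decoder_4" and "-> decoder_6" above, i.e. U-Net style skip
# connections. Sketch of one such step (hypothetical names, not the repo's forward):
#     d = self.decoder_coarse[2](d)                        # [nf * 2, 128, 128]
#     d = self.decoder_coarse[3](torch.cat([d, skip], 1))  # [nf * 4] -> [nf * 2, 128, 128]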
def __init__(self, in_nc=4, c_num=32, device=None):
    super(FineGenerator, self).__init__()
    self.conv_seq = nn.Sequential(
        # [4, 256, 256]
        conv_block(in_channels=in_nc, out_channels=c_num, kernel_size=5, stride=1, padding=2),
        # [cnum, 256, 256]
        conv_block(in_channels=c_num, out_channels=2 * c_num, kernel_size=3, stride=2, padding=1),
        # [cnum * 2, 128, 128]
        conv_block(in_channels=2 * c_num, out_channels=2 * c_num, kernel_size=3, stride=2, padding=1),
        # [cnum * 2, 64, 64]
        conv_block(in_channels=2 * c_num, out_channels=4 * c_num, kernel_size=3, stride=1, padding=1),
        # [cnum * 4, 64, 64]
        conv_block(in_channels=4 * c_num, out_channels=4 * c_num, kernel_size=3, stride=1, padding=1),
        # [cnum * 4, 64, 64]
        # dilated convolutions
        conv_block(in_channels=4 * c_num, out_channels=4 * c_num, kernel_size=3, stride=1, padding=2, dilation=2),
        # [cnum * 4, 64, 64]
        conv_block(in_channels=4 * c_num, out_channels=4 * c_num, kernel_size=3, stride=1, padding=4, dilation=4),
        # [cnum * 4, 64, 64]
        conv_block(in_channels=4 * c_num, out_channels=4 * c_num, kernel_size=3, stride=1, padding=8, dilation=8),
        # [cnum * 4, 64, 64]
        conv_block(in_channels=4 * c_num, out_channels=4 * c_num, kernel_size=3, stride=1, padding=16, dilation=16)
        # [cnum * 4, 64, 64]
    )
    self.before_attn = nn.Sequential(
        # [4, 256, 256]
        conv_block(in_channels=in_nc, out_channels=c_num, kernel_size=5, stride=1, padding=2),
        # [cnum, 256, 256]
        conv_block(in_channels=c_num, out_channels=2 * c_num, kernel_size=3, stride=2, padding=1),
        # [cnum * 2, 128, 128]
        conv_block(in_channels=2 * c_num, out_channels=2 * c_num, kernel_size=3, stride=2, padding=1),
        # [cnum * 2, 64, 64]
        conv_block(in_channels=2 * c_num, out_channels=4 * c_num, kernel_size=3, stride=1, padding=1),
        # [cnum * 4, 64, 64]
        conv_block(in_channels=4 * c_num, out_channels=4 * c_num, kernel_size=3, stride=1, padding=1),
        # [cnum * 4, 64, 64]
    )
    # Contextual attention, [cnum * 4, 64, 64]
    self.contextual_attention = ContextualAttention(ksize=3, stride=1, rate=2, fuse_k=3, softmax_scale=10,
                                                    fuse=True, device=device)
    self.after_attn = nn.Sequential(
        conv_block(in_channels=4 * c_num, out_channels=4 * c_num, kernel_size=3, stride=1, padding=1),
        # [cnum * 4, 64, 64]
        conv_block(in_channels=4 * c_num, out_channels=4 * c_num, kernel_size=3, stride=1, padding=1),
        # [cnum * 4, 64, 64]
    )
    self.all_conv = nn.Sequential(
        # concatenated input [cnum * 8, 64, 64]
        conv_block(in_channels=8 * c_num, out_channels=4 * c_num, kernel_size=3, stride=1, padding=1),
        # [cnum * 4, 64, 64]
        conv_block(in_channels=4 * c_num, out_channels=4 * c_num, kernel_size=3, stride=1, padding=1),
        # [cnum * 4, 64, 64]
        upconv_block(in_channels=4 * c_num, out_channels=2 * c_num, kernel_size=3, stride=2, padding=1,
                     norm='in', output_padding=1),
        # [cnum * 2, 128, 128]
        conv_block(in_channels=2 * c_num, out_channels=2 * c_num, kernel_size=3, stride=1, padding=1),
        # [cnum * 2, 128, 128]
        upconv_block(in_channels=2 * c_num, out_channels=c_num, kernel_size=3, stride=2, padding=1,
                     norm='in', output_padding=1),
        # [cnum, 256, 256]
        conv_block(in_channels=c_num, out_channels=c_num // 2, kernel_size=3, stride=1, padding=1),
        # [cnum // 2, 256, 256]
        conv_block(in_channels=c_num // 2, out_channels=3, kernel_size=3, stride=1, padding=1, activation='tanh'),
    )
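# The 8 * c_num input of all_conv indicates the dilated branch (conv_seq) and the
# contextual-attention branch are concatenated channel-wise. Sketch (hypothetical;
# the mask argument follows the usual ContextualAttention interface):
#     x_hallu = self.conv_seq(x)                        # [cnum * 4, 64, 64]
#     y = self.before_attn(x)                           # [cnum * 4, 64, 64]
#     y, _ = self.contextual_attention(y, y, mask)
#     y = self.after_attn(y)                            # [cnum * 4, 64, 64]
#     out = self.all_conv(torch.cat([x_hallu, y], 1))   # [cnum * 8] -> [3, 256, 256]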
def __init__(self, in_channels_m=1, in_channels_e=3, device=torch.device('cpu')):
    super(CoarseEncoder, self).__init__()
    self.madf_seq = nn.ModuleList([
        # im   - [3, 256, 256]
        # mask - [1, 256, 256]
        B.SimpleMADF(in_channels_m=in_channels_m, out_channels_m=2, in_channels_e=in_channels_e,
                     out_channels_e=16, kernel_size=5, stride=2, padding=2),
        # im   - [16, 128, 128]
        # mask - [2, 128, 128]
        B.SimpleMADF(in_channels_m=2, out_channels_m=4, in_channels_e=16, out_channels_e=32,
                     kernel_size=3, stride=2, padding=1),
        # im   - [32, 64, 64]
        # mask - [4, 64, 64]
        B.SimpleMADF(in_channels_m=4, out_channels_m=8, in_channels_e=32, out_channels_e=64,
                     kernel_size=3, stride=2, padding=1),
        # im   - [64, 32, 32]
        # mask - [8, 32, 32]
        B.SimpleMADF(in_channels_m=8, out_channels_m=16, in_channels_e=64, out_channels_e=128,
                     kernel_size=3, stride=2, padding=1),
        # im   - [128, 16, 16]
        # mask - [16, 16, 16]
    ])
    self.up_seq = nn.ModuleList([
        conv_block(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1, activation='relu', norm='none'),
        conv_block(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1, activation='relu', norm='none'),
        conv_block(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1, activation='relu', norm='none'),
        conv_block(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, activation='relu', norm='none'),
        conv_block(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, activation='relu', norm='none')
    ])
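# Standalone sanity check of the shape annotations above, using plain nn.Conv2d
# as a stand-in for the stride-2 convolutions inside SimpleMADF:
import torch
import torch.nn as nn

x = torch.randn(1, 3, 256, 256)
m = torch.randn(1, 1, 256, 256)
# conv_e of the first block sees image + mask channels (3 + 1):
print(nn.Conv2d(3 + 1, 16, 5, stride=2, padding=2)(torch.cat([x, m], 1)).shape)  # [1, 16, 128, 128]
print(nn.Conv2d(1, 2, 5, stride=2, padding=2)(m).shape)                          # [1, 2, 128, 128]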