Example #1
    def __init__(self,
                 layer,
                 buffer,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 dilation=1,
                 groups=1,
                 bias=True,
                 padding=0,
                 norm='in',
                 activation='elu',
                 pad_type='zero'):
        super(RConv, self).__init__()
        self.buffer = buffer
        self.layer = layer
        self.conv_1 = conv_block(in_channels, out_channels, kernel_size,
                                 stride, dilation, groups, bias, padding,
                                 "none", "none", pad_type)
        self.conv_2 = conv_layer(out_channels,
                                 out_channels,
                                 kernel_size,
                                 stride=1,
                                 dilation=1,
                                 groups=1)
        self.fusion = conv_block(out_channels * 2,
                                 out_channels,
                                 kernel_size=1,
                                 norm=norm,
                                 activation=activation)
        self.bn_act = nn.Sequential(nn.InstanceNorm2d(out_channels), nn.ELU())
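All of these examples lean on a shared conv_block helper plus _activation and _norm factories that are not shown (conv_layer is presumably just the bare padded Conv2d without norm or activation). A minimal sketch of what such helpers could look like, inferred from the positional call above and the string conventions used throughout ('in', 'elu', 'zero', 'none'); the repository's actual versions may differ:

import torch.nn as nn

def _activation(act_type):
    # Map an activation name to a module; 'none' yields an identity.
    return {'relu': nn.ReLU(inplace=True),
            'lrelu': nn.LeakyReLU(0.2, inplace=True),
            'elu': nn.ELU(inplace=True),
            'tanh': nn.Tanh(),
            'none': nn.Identity()}[act_type]

def _norm(norm_type, channels):
    # Map a normalization name to a module; 'none' yields an identity.
    return {'bn': nn.BatchNorm2d(channels),
            'in': nn.InstanceNorm2d(channels),
            'none': nn.Identity()}[norm_type]

def conv_block(in_channels, out_channels, kernel_size, stride=1, dilation=1,
               groups=1, bias=True, padding=0, norm='none', activation='relu',
               pad_type='zero'):
    # Pad -> Conv2d -> Norm -> Activation; an explicit padding layer lets
    # pad_type swap zero padding for reflection padding.
    pad = {'zero': nn.ZeroPad2d(padding),
           'reflect': nn.ReflectionPad2d(padding)}[pad_type]
    return nn.Sequential(pad,
                         nn.Conv2d(in_channels, out_channels, kernel_size,
                                   stride=stride, dilation=dilation,
                                   groups=groups, bias=bias),
                         _norm(norm, out_channels),
                         _activation(activation))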
Example #2
    def __init__(self, layer, buffer, in_channels, out_channels, kernel_size=3, stride=2, dilation=1, groups=1,
                 bias=True, padding=1, output_padding=1, norm='in', activation='elu', pad_type='zero'):
        """

        :param layer:
        :param buffer:
        :param in_channels:
        :param out_channels:
        :param kernel_size:
        :param stride:
        :param dilation:
        :param groups:
        :param bias:
        :param padding:
        :param output_padding:
        :param norm:
        :param activation:
        :param pad_type:
        """
        super(RDeConv, self).__init__()

        self.buffer = buffer
        self.layer = layer
        self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                                         padding=padding, output_padding=output_padding, dilation=dilation, groups=groups, bias=bias)
        self.act = _activation(activation)
        self.norm = _norm(norm, out_channels)
        self.conv = conv_block(out_channels, out_channels, kernel_size, stride=1, bias=bias,
                               padding=1, pad_type=pad_type, norm=norm, activation=activation)
        self.fusion = conv_block(out_channels * 2, out_channels, kernel_size=1, norm=norm,
                                 activation=activation)
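With the defaults above (kernel_size=3, stride=2, padding=1, output_padding=1) the transposed convolution exactly doubles the spatial size, per PyTorch's ConvTranspose2d shape formula. A quick check:

import torch
import torch.nn as nn

# H_out = (H_in - 1)*stride - 2*padding + dilation*(kernel_size - 1) + output_padding + 1
#       = (H_in - 1)*2 - 2 + 2 + 1 + 1 = 2*H_in with the defaults above.
deconv = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
x = torch.randn(1, 64, 64, 64)
print(deconv(x).shape)  # torch.Size([1, 32, 128, 128])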
Example #3
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=0,
                 bias=True,
                 pad_type='zero',
                 norm='none',
                 activation='relu'):
        super(upconv_block, self).__init__()

        # NOTE: the deconv is hardcoded to kernel 4, stride 2, padding 1
        # (2x upsampling) with fixed 'relu'/'in'; the constructor's
        # kernel_size, stride, norm and activation only affect self.conv below.
        self.deconv = nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1)
        self.act = _activation('relu')
        self.norm = _norm('in', out_channels)

        self.conv = conv_block(out_channels,
                               out_channels,
                               kernel_size,
                               stride,
                               bias=bias,
                               padding=padding,
                               pad_type=pad_type,
                               norm=norm,
                               activation=activation)
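Later examples call upconv_block with an output_padding argument, which this signature does not accept, so they presumably use a fuller variant. A sketch of what that variant might look like, reusing the _activation/_norm helpers sketched under Example #1 (an assumption, not the repository's actual code):

import torch.nn as nn

class upconv_block(nn.Module):
    # Hypothetical variant whose transposed convolution honours the
    # kernel_size/stride/padding/output_padding arguments instead of
    # hardcoding (4, 2, 1).
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2,
                 padding=0, output_padding=0, bias=True, pad_type='zero',
                 norm='none', activation='relu'):
        super(upconv_block, self).__init__()
        self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size,
                                         stride, padding, output_padding, bias=bias)
        self.act = _activation(activation)
        self.norm = _norm(norm, out_channels)

    def forward(self, x):
        # Upsample, then normalize and activate.
        return self.act(self.norm(self.deconv(x)))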
Example #4
    def __init__(self,
                 nf,
                 res_blocks_num=8,
                 dmfb_blocks_num=4,
                 attention_mode="self"):
        """
        :param nf: number of channels
        :param res_blocks_num: number of Residual blocks
        :param dmfb_blocks_num: number of DMFB blocks
        :param attention_mode: Attention mode ("context" or "self")
        """

        super(FineBottleneck, self).__init__()
        self.attention = SelfAttention(nf, k=8)  # attention_mode is not used here; self-attention is always applied
        res_seq = []
        for _ in range(res_blocks_num):
            block = B.ResConv(nf, kernel_size=3, dilation=2,
                              padding=2)  # [192, 64, 64]
            res_seq.append(block)
        self.res_seq = nn.Sequential(*res_seq)

        dmfb_seq = []
        for _ in range(dmfb_blocks_num):
            block = DMFB(nf)  # [192, 64, 64]
            dmfb_seq.append(block)
        self.dmfb_seq = nn.Sequential(*dmfb_seq)

        self.out = nn.Sequential(
            conv_block(3 * nf,
                       nf,
                       kernel_size=1,
                       stride=1,
                       padding=0,
                       norm="in",
                       activation="elu",
                       pad_type="zero"),
            conv_block(nf,
                       nf,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       norm="in",
                       activation="elu",
                       pad_type="zero"))
Example #5
    def __init__(self,
                 in_channels_r, out_channels_r, in_channels_u, out_channels,
                 kernel_size_in, kernel_size_out,
                 up_stride_in, stride_out,
                 up_padding_in, padding_out, output_padding=0,
                 activation_in='lrelu', activation_out='lrelu',
                 norm_in='bn', norm_out='none'):
        """
        u {l-1} and r {l} new should be of equal sizes to be concatenated

        :param in_channels_r: in channels for r {l}
        :param out_channels_r: out channels for r {l-1}
        :param in_channels_u: in channels for u {l-1}
        :param out_channels: out channels for (u {l-1}, r {l}) -> r {l-1}

        :param kernel_size_in: kernel size for r {l}) -> r {l} new
        :param kernel_size_out: kernel size for (u {l-1}, r {l}) -> r {l-1}

        :param up_stride_in: stride for transposed conv r {l}) -> r {l} new
        :param stride_out: stride  for (u {l-1}, r {l}) -> r {l-1}

        :param up_padding_in: padding for transposed convolution r {l} -> r {l} new
        :param output_padding: output padding for transposed convolution r {l} -> r {l} new
        :param padding_out: padding for (u {l-1}, r {l}) -> r {l-1}

        :param activation_in: activation layer for r {l}) -> r {l} new
        :param activation_out: activation layer for (u {l-1}, r {l}) -> r {l-1}

        :param norm_in: normalization layer for r {l}) -> r {l} new
        :param norm_out: normalization layer for (u {l-1}, r {l}) -> r {l-1}
        """
        super(RecovecyBlock, self).__init__()

        self.in_upconv = upconv_block(
            in_channels=in_channels_r,
            out_channels=out_channels_r,
            kernel_size=kernel_size_in,
            stride=up_stride_in,
            padding=up_padding_in,
            output_padding=output_padding,
            norm=norm_in,
            activation=activation_in
        )

        self.out_conv = conv_block(
            in_channels=out_channels_r + in_channels_u,
            out_channels=out_channels,
            kernel_size=kernel_size_out,
            stride=stride_out,
            padding=padding_out,
            norm=norm_out,
            activation=activation_out
        )
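Per the docstring, r {l} is upsampled to match u {l-1} and the two are concatenated, which also matches out_conv's out_channels_r + in_channels_u input width. A forward sketch under that assumption:

import torch

def forward(self, u_prev, r):
    # r {l} -> r {l} new via the transposed convolution, then fuse with
    # u {l-1} to produce r {l-1}.
    r_new = self.in_upconv(r)
    return self.out_conv(torch.cat([u_prev, r_new], dim=1))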
Example #6
    def __init__(self,
                 in_channels_m, out_channels_m,
                 in_channels_e, out_channels_e,
                 kernel_size, stride, padding,
                 activation='relu', norm='none'):
        super(SimpleMADF, self).__init__()

        self.conv_m = conv_block(
            in_channels=in_channels_m,
            out_channels=out_channels_m,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            activation=activation,
            norm='none')

        self.conv_e = conv_block(
            in_channels=in_channels_e + in_channels_m,
            out_channels=out_channels_e,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            activation=activation,
            norm=norm)
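conv_e takes in_channels_e + in_channels_m channels, which suggests the incoming mask features are concatenated with the image features before the image convolution. A forward sketch under that assumption:

import torch

def forward(self, m, e):
    # Advance the mask branch, and condition the image branch on the
    # incoming mask features via channel concatenation.
    m_out = self.conv_m(m)
    e_out = self.conv_e(torch.cat([e, m], dim=1))
    return m_out, e_out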
Example #7
    def __init__(self,
                 in_channels_1, in_channels_2, out_channels,
                 kernel_size_1, kernel_size_2,
                 stride_1, up_stride_2,
                 padding_1, up_padding_2, output_padding=0,
                 activation_in='relu', activation_out='lrelu',
                 norm_in='bn', norm_out='none'):
        """
        1 convolution - from f {l-1}{r-1}
        2 transposed convolution - from f {l} {k}
        3 convolution - the one to be summed with result of f {l}{k} (2) convolution
        4 convolution - the one to be producted with result of f {l}{k} (2) convolution

        :param in_channels_1: Input channels of the 1st convolution layer
        :param in_channels_2: Input channels of the 2st convolution layer
        :param out_channels: Output channels of the 1st, 2st, 3rd, 4th convolution layers

        :param kernel_size_1: Kernel size of 1st convolution layer
        :param kernel_size_2: Kernel size of 2nd transposed convolution layer

        :param stride_1: Stride of 1st convolution layer
        :param up_stride_2: Stride of 2nd transposed convolution layer

        :param padding_1: Padding of 1st convolution layer
        :param up_padding_2: Padding of 2nd transposed convolution layer
        :param output_padding: Output padding of 2nd transposed convolution layer

        :param activation_in: Activation layer of 1st convolution layer
        :param activation_out: Activation layer 2nd transposed convolution layer

        :param norm_in: Normalization layer of 1st convolution layer
        :param norm_out: Normalization layer of 2nd transposed convolution layer
        """

        super(RefinementBlock, self).__init__()

        self.conv_1 = conv_block(
            in_channels=in_channels_1,
            out_channels=out_channels,
            kernel_size=kernel_size_1,
            stride=stride_1,
            padding=padding_1,
            norm='none',
            activation=activation_in
        )

        self.upconv_2 = upconv_block(
            in_channels=in_channels_2,
            out_channels=out_channels,
            kernel_size=kernel_size_2,
            stride=up_stride_2,
            padding=up_padding_2,
            output_padding=output_padding,
            norm=norm_in,
            activation='none'
        )

        self.conv_3 = conv_block(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            norm='none',
            activation='none'
        )

        self.conv_4 = conv_block(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            norm='none',
            activation='none'
        )

        self.out_act = _activation(act_type=activation_out)
        self.out_norm = _norm(norm_type=norm_out, channels=out_channels)
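The docstring says conv_3's output is summed with, and conv_4's output multiplied with, the upsampled f {l}{k}. One forward pass consistent with that, applying conv_3 and conv_4 to conv_1's output (this pairing is an assumption):

def forward(self, f_prev, f_lk):
    # Project f {l-1}{k-1} and upsample f {l}{k} to a common shape.
    x = self.conv_1(f_prev)
    up = self.upconv_2(f_lk)
    # Scale-and-shift combination: multiply by conv_4, add conv_3.
    out = up * self.conv_4(x) + self.conv_3(x)
    return self.out_act(self.out_norm(out))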
Example #8
    def __init__(self,
                 in_channels_m, out_channels_m,
                 in_channels_e, out_channels_e,
                 kernel_size_m, kernel_size_e,
                 stride_m, stride_e,
                 padding_m, padding_e,
                 activation_m='relu', activation_e='relu',
                 norm_m='none', norm_e='bn',
                 device=torch.device('cpu')):
        """
        :param in_channels_m: input channels of mask layer - m {l-1}
        :param out_channels_m: output channels of mask layer - m {l}
        :param in_channels_e: input channels of image layer - e {l-1}
        :param out_channels_e: output channels of image layer - e {l}

        :param kernel_size_m: kernel size for transformation m {l-1} -> m {l}
        :param kernel_size_e: kernel size for transformation e {l-1} -> e {l}
            and for kernel creation m {l} -> kernel for e {l} convolution

        :param stride_m: stride for m {l-1} -> m {l} transformation
        :param stride_e: stride for e {l-1} -> e {l} transformation

        :param padding_m: padding for m {l-1} -> m {l} transformation
        :param padding_e: padding for e {l-1} -> e {l} transformation

        :param activation_m: activation_m for m {l-1} -> m {l} transformation
        :param activation_e: activation_e for e {l-1} -> e {l} transformation
        """
        super(MADF, self).__init__()
        self.in_channels_m = in_channels_m
        self.out_channels_m = out_channels_m
        self.in_channels_e = in_channels_e
        self.out_channels_e = out_channels_e

        self.kernel_size_e = kernel_size_e
        self.kernel_size_m = kernel_size_m

        self.padding_e = padding_e
        self.padding_m = padding_m

        self.stride_e = stride_e
        self.stride_m = stride_m

        self.activation_m = activation_m
        self.activation_e = activation_e

        self.norm_m = norm_m
        self.norm_e = norm_e

        self.conv_m = conv_block(
            in_channels=in_channels_m,
            out_channels=out_channels_m,
            kernel_size=kernel_size_m,
            stride=stride_m,
            padding=padding_m,
            activation=activation_m,
            norm=norm_m)

        self.conv_filters = conv_block(
            in_channels=out_channels_m,
            out_channels=in_channels_e * out_channels_e * kernel_size_e * kernel_size_e,
            kernel_size=1,
            stride=1,
            padding=0,
            activation="none",
            norm='none')

        self.device = device
        self.activation_e = _activation(activation_e)  # note: replaces the activation name stored above with the module itself
        self.norm = _norm(norm_e, out_channels_e)
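conv_filters predicts in_channels_e * out_channels_e * kernel_size_e^2 values per spatial location, i.e. a full convolution kernel for every pixel of the mask features. A minimal sketch of how such per-pixel kernels could be applied with unfold (an illustration of the idea, not the repository's forward):

import torch
import torch.nn.functional as F

def apply_dynamic_filters(e, filters, k, stride, padding):
    # e: [B, C_in, H, W]; filters: [B, C_out * C_in * k * k, H_out, W_out],
    # where (H_out, W_out) matches the unfold grid over e.
    b, c_in = e.shape[:2]
    patches = F.unfold(e, kernel_size=k, stride=stride, padding=padding)  # [B, C_in*k*k, L]
    n_loc = patches.shape[-1]
    c_out = filters.shape[1] // (c_in * k * k)
    w = filters.view(b, c_out, c_in * k * k, n_loc)
    out = torch.einsum('bocl,bcl->bol', w, patches)  # per-location convolution
    h_out, w_out = filters.shape[2:]
    return out.view(b, c_out, h_out, w_out)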
Example #9
    def __init__(self,
                 in_nc=4,
                 out_nc=3,
                 res_blocks_num=8,
                 c_num=48,
                 dmfb_blocks_num=8,
                 norm='in',
                 activation='relu',
                 device=None):
        super(FineGenerator, self).__init__()

        self.encoder = nn.ModuleList([
            # [4, 256, 256]
            conv_block(in_channels=in_nc,
                       out_channels=c_num,
                       kernel_size=5,
                       stride=1,
                       padding=2,
                       norm=norm,
                       activation=activation),
            # [c_num, 256, 256] -> decoder 6
            conv_block(in_channels=c_num,
                       out_channels=2 * c_num,
                       kernel_size=3,
                       stride=2,
                       padding=1,
                       norm=norm,
                       activation=activation),
            # [c_num * 2, 128, 128]
            conv_block(in_channels=2 * c_num,
                       out_channels=2 * c_num,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       norm=norm,
                       activation=activation),
            # [c_num * 2, 128, 128] -> decoder 4
            conv_block(in_channels=2 * c_num,
                       out_channels=4 * c_num,
                       kernel_size=3,
                       stride=2,
                       padding=1,
                       norm=norm,
                       activation=activation),
            # [c_num * 4, 64, 64]
            conv_block(in_channels=4 * c_num,
                       out_channels=4 * c_num,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       norm=norm,
                       activation=activation),
            # [c_num * 4, 64, 64]
            conv_block(in_channels=4 * c_num,
                       out_channels=4 * c_num,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       norm=norm,
                       activation=activation),
            # [c_num * 4, 64, 64]
        ])

        self.fine_bn = FineBottleneck(c_num=4 * c_num,
                                      dmfb_blocks_num=dmfb_blocks_num,
                                      res_blocks_num=res_blocks_num,
                                      device=device)

        self.decoder = nn.ModuleList([
            # [c_num * 4, 64, 64]
            conv_block(in_channels=c_num * 4,
                       out_channels=c_num * 4,
                       kernel_size=3,
                       stride=1,
                       padding=1),
            # [c_num * 4, 64, 64]
            conv_block(in_channels=c_num * 4,
                       out_channels=c_num * 4,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       norm=norm,
                       activation=activation),
            # [c_num * 4, 64, 64]
            upconv_block(in_channels=c_num * 4,
                         out_channels=c_num * 2,
                         kernel_size=3,
                         stride=2,
                         output_padding=1,
                         padding=get_pad_tp(64, 64, [1, 1], [3, 3], [2, 2],
                                            [1, 1]),
                         norm=norm,
                         activation=activation),
            # encoder 3 + skip [c_num * 2, 128, 128] -> decoder 4
            # [c_num * 4, 128, 128]
            conv_block(in_channels=c_num * 4,
                       out_channels=c_num * 2,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       norm=norm,
                       activation=activation),
            # [c_num * 2, 128, 128]
            upconv_block(in_channels=c_num * 2,
                         out_channels=c_num,
                         kernel_size=3,
                         stride=2,
                         output_padding=1,
                         padding=get_pad_tp(128, 128, [1, 1], [3, 3], [2, 2],
                                            [1, 1]),
                         norm=norm,
                         activation=activation),
            # encoder 1 + skip [c_num, 256, 256] -> decoder 6
            # [c_num, 256, 256]
            conv_block(in_channels=c_num * 2,
                       out_channels=c_num // 2,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       norm=norm,
                       activation=activation),
            # [c_num//2, 256, 256]
            conv_block(in_channels=c_num // 2,
                       out_channels=out_nc,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       norm='none',
                       activation='tanh'),
            # [out_nc, 256, 256]
        ])
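get_pad_tp is not shown; given how it is called above (input height/width, then per-dimension dilation, kernel, stride and output_padding lists), it plausibly inverts the ConvTranspose2d shape formula to find the padding that makes the upconv double the spatial size. A hypothetical reconstruction:

def get_pad_tp(h_in, w_in, dilation, kernel, stride, output_padding):
    # Solve out = (in - 1)*stride - 2*pad + dilation*(k - 1) + output_padding + 1
    # for the pad that gives out == stride * in (exact upsampling).
    pads = []
    for i, size in enumerate((h_in, w_in)):
        target = stride[i] * size
        pad2 = ((size - 1) * stride[i] + dilation[i] * (kernel[i] - 1)
                + output_padding[i] + 1 - target)
        pads.append(pad2 // 2)
    return tuple(pads)

# get_pad_tp(64, 64, [1, 1], [3, 3], [2, 2], [1, 1]) == (1, 1), so the
# upconv_block above maps [c_num * 4, 64, 64] to [c_num * 2, 128, 128].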
Example #10
    def __init__(self,
                 c_num,
                 res_blocks_num=8,
                 dmfb_blocks_num=8,
                 activateion="relu",
                 norm="in",
                 device=None):
        super(FineBottleneck, self).__init__()

        dmfb_blocks = []
        for _ in range(dmfb_blocks_num):
            # [c_num, 64, 64]
            dmfb_blocks.append(DMFB(in_channels=c_num))

        self.dmfb_seq = nn.Sequential(*dmfb_blocks)

        res_blocks = []
        for _ in range(res_blocks_num):
            # [c_num, 64, 64]
            block = B.ResConv(channels=c_num,
                              kernel_size=3,
                              dilation=2,
                              padding=2)
            res_blocks.append(block)

        self.res_seq = nn.Sequential(*res_blocks)

        # Contextual attention
        self.contextual_attention = ContextualAttention(ksize=3,
                                                        stride=1,
                                                        rate=2,
                                                        fuse_k=3,
                                                        softmax_scale=10,
                                                        fuse=True,
                                                        device=device)

        self.hypergraph = HypergraphConv(in_channels=c_num,
                                         out_channels=c_num,
                                         filters=256,
                                         edges=256,
                                         height=64,
                                         width=64)

        self.out_1 = nn.Sequential(
            conv_block(in_channels=3 * c_num,
                       out_channels=c_num,
                       kernel_size=1,
                       stride=1,
                       padding=0,
                       norm=norm,
                       activation=activation,
                       pad_type="zero"),
            conv_block(in_channels=c_num,
                       out_channels=c_num,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       norm=norm,
                       activation=activation,
                       pad_type="zero"))

        self.out = nn.Sequential(
            conv_block(in_channels=2 * c_num,
                       out_channels=c_num,
                       kernel_size=1,
                       stride=1,
                       padding=0,
                       norm=norm,
                       activation=activation,
                       pad_type="zero"),
            conv_block(in_channels=c_num,
                       out_channels=c_num,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       norm=norm,
                       activation=activation,
                       pad_type="zero"))
Example #11
    def __init__(self,
                 in_nc=4,
                 c_num=48,
                 out_nc=3,
                 res_blocks_num=8,
                 norm="in",
                 activation="relu"):
        super(CoarseGenerator, self).__init__()

        self.encoder_coarse = nn.ModuleList([
            # [4, 256, 256]
            conv_block(in_nc,
                       c_num,
                       kernel_size=5,
                       stride=1,
                       padding=2,
                       norm=norm,
                       activation=activation),
            # [c_num, 256, 256] -> decoder_6
            conv_block(c_num,
                       c_num * 2,
                       kernel_size=3,
                       stride=2,
                       padding=1,
                       norm=norm,
                       activation=activation),
            # [c_num * 2, 128, 128]
            conv_block(c_num * 2,
                       c_num * 2,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       norm=norm,
                       activation=activation),
            # [c_num * 2, 128, 128] -> decoder_4
            conv_block(c_num * 2,
                       c_num * 4,
                       kernel_size=3,
                       stride=2,
                       padding=1,
                       norm=norm,
                       activation=activation),
            # [c_num * 4, 64, 64]
            conv_block(c_num * 4,
                       c_num * 4,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       norm=norm,
                       activation=activation),
            # [c_num * 4, 64, 64]
            conv_block(c_num * 4,
                       c_num * 4,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       norm=norm,
                       activation=activation)
            # [c_num * 4, 64, 64]
        ])

        blocks = []
        for _ in range(res_blocks_num):
            block = B.ResConv(4 * c_num, kernel_size=3, dilation=2,
                              padding=2)  # [192, 64, 64]
            blocks.append(block)

        self.coarse_bn = nn.Sequential(*blocks)

        self.decoder_coarse = nn.ModuleList([
            conv_block(c_num * 4,
                       c_num * 4,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       norm=norm,
                       activation=activation),
            # [c_num * 4, 64, 64]
            conv_block(c_num * 4,
                       c_num * 4,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       norm=norm,
                       activation=activation),
            # [c_num * 4, 64, 64]
            upconv_block(c_num * 4,
                         c_num * 2,
                         kernel_size=3,
                         stride=2,
                         output_padding=1,
                         padding=get_pad_tp(64, 64, [1, 1], [3, 3], [2, 2],
                                            [1, 1]),
                         norm=norm,
                         activation=activation),
            # [c_num * 2, 128, 128] + skip [c_num * 2, 128, 128]
            conv_block(c_num * 4,
                       c_num * 2,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       norm=norm,
                       activation=activation),
            # [c_num * 2, 128, 128]
            upconv_block(c_num * 2,
                         c_num,
                         kernel_size=3,
                         stride=2,
                         output_padding=1,
                         padding=get_pad_tp(128, 128, [1, 1], [3, 3], [2, 2],
                                            [1, 1]),
                         norm=norm,
                         activation=activation),
            # [c_num, 256, 256] + skip [c_num, 256, 256]
            conv_block(c_num * 2,
                       c_num,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       norm=norm,
                       activation=activation),
            # [c_num, 256, 256]
            conv_block(c_num,
                       c_num // 2,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       norm=norm,
                       activation=activation)
            # [c_num // 2, 256, 256]
        ])

        self.out_coarse = nn.Sequential(
            conv_block(c_num // 2,
                       out_nc,
                       3,
                       stride=1,
                       padding=1,
                       norm='none',
                       activation='tanh'))
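The "-> decoder_6" and "-> decoder_4" comments mark skip connections: the outputs of encoder blocks 1 and 3 (1-indexed) are concatenated in front of decoder blocks 6 and 4, which is exactly what the doubled in_channels of those decoder convolutions require. A forward sketch of that wiring (an assumption; the original forward is not shown):

import torch

def forward(self, x):
    skips = {}
    for i, block in enumerate(self.encoder_coarse):
        x = block(x)
        if i in (0, 2):  # outputs routed to decoder blocks 6 and 4
            skips[i] = x
    x = self.coarse_bn(x)
    for i, block in enumerate(self.decoder_coarse):
        if i == 3:       # decoder_4: concat the 128x128 encoder skip
            x = torch.cat([x, skips[2]], dim=1)
        elif i == 5:     # decoder_6: concat the 256x256 encoder skip
            x = torch.cat([x, skips[0]], dim=1)
        x = block(x)
    return self.out_coarse(x)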
Example #12
    def __init__(self,
                 in_nc=4,
                 out_nc=3,
                 nf=48,
                 norm="in",
                 activation="relu",
                 res_blocks_num=8,
                 dmfb_block_num=4,
                 device=torch.device('cuda')):
        """
        :param in_nc: in channels number
        :param out_nc: out channels number
        :param nf: number of intermediate features
        :param res_blocks_num: number of Residual blocks in the fine and coarse routes
        :param dmfb_block_num: number of DMFB blocks in the fine route
        """

        super(InpaintingGenerator, self).__init__()
        # [4, 256, 256]
        self.encoder_coarse = nn.ModuleList([
            # [48, 256, 256] -> decoder_6
            conv_block(in_nc,
                       nf,
                       kernel_size=5,
                       stride=1,
                       padding=2,
                       norm=norm,
                       activation=activation),
            # [96, 128, 128]
            conv_block(nf,
                       nf * 2,
                       kernel_size=3,
                       stride=2,
                       padding=1,
                       norm=norm,
                       activation=activation),
            # [96, 128, 128] -> decoder_4
            conv_block(nf * 2,
                       nf * 2,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       norm=norm,
                       activation=activation),
            # [192, 64, 64]
            conv_block(nf * 2,
                       nf * 4,
                       kernel_size=3,
                       stride=2,
                       padding=1,
                       norm=norm,
                       activation=activation),
            # [192, 64, 64]
            conv_block(nf * 4,
                       nf * 4,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       norm=norm,
                       activation=activation),
            # [192, 64, 64]
            conv_block(nf * 4,
                       nf * 4,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       norm=norm,
                       activation=activation)
        ])
        self.encoder_fine = nn.ModuleList([
            # [48, 256, 256] -> decoder_6
            conv_block(in_nc,
                       nf,
                       kernel_size=5,
                       stride=1,
                       padding=2,
                       norm=norm,
                       activation=activation),
            # [96, 128, 128]
            conv_block(nf,
                       nf * 2,
                       kernel_size=3,
                       stride=2,
                       padding=1,
                       norm=norm,
                       activation=activation),
            # [96, 128, 128] -> decoder_4
            conv_block(nf * 2,
                       nf * 2,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       norm=norm,
                       activation=activation),
            # [192, 64, 64]
            conv_block(nf * 2,
                       nf * 4,
                       kernel_size=3,
                       stride=2,
                       padding=1,
                       norm=norm,
                       activation=activation),
            # [192, 64, 64]
            conv_block(nf * 4,
                       nf * 4,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       norm=norm,
                       activation=activation),
            # [192, 64, 64]
            conv_block(nf * 4,
                       nf * 4,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       norm=norm,
                       activation=activation)
        ])

        blocks = []
        for _ in range(res_blocks_num):
            block = B.ResConv(4 * nf, kernel_size=3, dilation=2,
                              padding=2)  # [192, 64, 64]
            blocks.append(block)

        self.coarse = nn.Sequential(*blocks)
        self.fine = FineBottleneck(4 * nf,
                                   res_blocks_num,
                                   dmfb_block_num,
                                   attention_mode="self")

        self.decoder_coarse = nn.ModuleList([
            # [192, 64, 64]
            conv_block(nf * 4,
                       nf * 4,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       norm=norm,
                       activation=activation),
            # [192, 64, 64]
            conv_block(nf * 4,
                       nf * 4,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       norm=norm,
                       activation=activation),
            # [96, 128, 128]
            upconv_block(nf * 4,
                         nf * 2,
                         kernel_size=3,
                         stride=2,
                         output_padding=1,
                         padding=get_pad_tp(64, 64, [1, 1], [3, 3], [2, 2],
                                            [1, 1]),
                         norm=norm,
                         activation=activation),
            # [96, 128, 128]
            conv_block(nf * 4,
                       nf * 2,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       norm=norm,
                       activation=activation),
            # [48, 256, 256]
            upconv_block(nf * 2,
                         nf,
                         kernel_size=3,
                         stride=2,
                         output_padding=1,
                         padding=get_pad_tp(128, 128, [1, 1], [3, 3], [2, 2],
                                            [1, 1]),
                         norm=norm,
                         activation=activation),
            # [48, 256, 256]
            conv_block(nf * 2,
                       nf,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       norm=norm,
                       activation=activation),
            # [24, 256, 256]
            conv_block(nf,
                       nf // 2,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       norm=norm,
                       activation=activation)
        ])
        self.decoder_fine = nn.ModuleList([
            # [192, 64, 64]
            conv_block(nf * 4,
                       nf * 4,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       norm=norm,
                       activation=activation),
            # [192, 64, 64]
            conv_block(nf * 4,
                       nf * 4,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       norm=norm,
                       activation=activation),
            # [96, 128, 128]
            upconv_block(nf * 4,
                         nf * 2,
                         kernel_size=3,
                         stride=2,
                         output_padding=1,
                         padding=get_pad_tp(64, 64, [1, 1], [3, 3], [2, 2],
                                            [1, 1]),
                         norm=norm,
                         activation=activation),
            # [96, 128, 128]
            conv_block(nf * 4,
                       nf * 2,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       norm=norm,
                       activation=activation),
            # [48, 256, 256]
            upconv_block(nf * 2,
                         nf,
                         kernel_size=3,
                         stride=2,
                         output_padding=1,
                         padding=get_pad_tp(128, 128, [1, 1], [3, 3], [2, 2],
                                            [1, 1]),
                         norm=norm,
                         activation=activation),
            # [48, 256, 256]
            conv_block(nf * 2,
                       nf,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       norm=norm,
                       activation=activation),
            # [24, 256, 256]
            conv_block(nf,
                       nf // 2,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       norm=norm,
                       activation=activation)
        ])

        self.out_coarse = nn.Sequential(
            conv_block(nf // 2,
                       out_nc,
                       3,
                       stride=1,
                       padding=1,
                       norm='none',
                       activation='tanh'))
        self.out_fine = nn.Sequential(
            conv_block(nf // 2,
                       out_nc,
                       3,
                       stride=1,
                       padding=1,
                       norm='none',
                       activation='tanh'))

        self.device = device
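The two routes mirror the CoarseGenerator/FineGenerator pair above; what is new is the staging. A plausible top-level forward, where _run is a hypothetical helper performing the same skip-connection walk sketched under Example #11, and the compositing step is the usual coarse-to-fine convention (both are assumptions):

import torch

def forward(self, x, mask):
    # x: [B, 4, 256, 256] = masked RGB plus mask channel; mask: [B, 1, 256, 256].
    coarse = self.out_coarse(self._run(self.encoder_coarse, self.coarse,
                                       self.decoder_coarse, x))
    # Composite the coarse prediction into the holes, then run the fine route.
    rgb = x[:, :3]
    x_fine = torch.cat([coarse * mask + rgb * (1 - mask), mask], dim=1)
    fine = self.out_fine(self._run(self.encoder_fine, self.fine,
                                   self.decoder_fine, x_fine))
    return coarse, fine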
Example #13
    def __init__(self, in_nc=4, c_num=32, device=None):
        super(FineGenerator, self).__init__()

        self.conv_seq = nn.Sequential(
            # [4, 256, 256]
            conv_block(in_channels=in_nc,
                       out_channels=c_num,
                       kernel_size=5,
                       stride=1,
                       padding=2),
            # [cnum, 256, 256]
            conv_block(in_channels=c_num,
                       out_channels=2 * c_num,
                       kernel_size=3,
                       stride=2,
                       padding=1),
            # [cnum * 2, 128, 128]
            conv_block(in_channels=2 * c_num,
                       out_channels=2 * c_num,
                       kernel_size=3,
                       stride=2,
                       padding=1),
            # [cnum * 2, 64, 64]
            conv_block(in_channels=2 * c_num,
                       out_channels=4 * c_num,
                       kernel_size=3,
                       stride=1,
                       padding=1),
            # [cnum * 4, 64, 64]
            conv_block(in_channels=4 * c_num,
                       out_channels=4 * c_num,
                       kernel_size=3,
                       stride=1,
                       padding=1),
            # [cnum * 4, 64, 64]

            # dilation
            conv_block(in_channels=4 * c_num,
                       out_channels=4 * c_num,
                       kernel_size=3,
                       stride=1,
                       padding=2,
                       dilation=2),
            # [cnum * 4, 64, 64]
            conv_block(in_channels=4 * c_num,
                       out_channels=4 * c_num,
                       kernel_size=3,
                       stride=1,
                       padding=4,
                       dilation=4),
            # [cnum * 4, 64, 64]
            conv_block(in_channels=4 * c_num,
                       out_channels=4 * c_num,
                       kernel_size=3,
                       stride=1,
                       padding=8,
                       dilation=8),
            # [cnum * 4, 64, 64]
            conv_block(in_channels=4 * c_num,
                       out_channels=4 * c_num,
                       kernel_size=3,
                       stride=1,
                       padding=16,
                       dilation=16)
            # [cnum * 4, 64, 64]
        )

        self.before_attn = nn.Sequential(
            # [4, 256, 256]
            conv_block(in_channels=in_nc,
                       out_channels=c_num,
                       kernel_size=5,
                       stride=1,
                       padding=2),
            # [cnum, 256, 256]
            conv_block(in_channels=c_num,
                       out_channels=2 * c_num,
                       kernel_size=3,
                       stride=2,
                       padding=1),
            # [cnum * 2, 128, 128]
            conv_block(in_channels=2 * c_num,
                       out_channels=2 * c_num,
                       kernel_size=3,
                       stride=2,
                       padding=1),
            # [cnum * 2, 64, 64]
            conv_block(in_channels=2 * c_num,
                       out_channels=4 * c_num,
                       kernel_size=3,
                       stride=1,
                       padding=1),
            # [cnum * 4, 64, 64]
            conv_block(in_channels=4 * c_num,
                       out_channels=4 * c_num,
                       kernel_size=3,
                       stride=1,
                       padding=1),
            # [cnum * 4, 64, 64]
        )

        # Contextual attention
        self.contextual_attention = ContextualAttention(ksize=3,
                                                        stride=1,
                                                        rate=2,
                                                        fuse_k=3,
                                                        softmax_scale=10,
                                                        fuse=True,
                                                        device=device)
        # [cnum * 4, 64, 64]

        self.after_attn = nn.Sequential(
            conv_block(in_channels=4 * c_num,
                       out_channels=4 * c_num,
                       kernel_size=3,
                       stride=1,
                       padding=1),
            # [cnum * 4, 64, 64]
            conv_block(in_channels=4 * c_num,
                       out_channels=4 * c_num,
                       kernel_size=3,
                       stride=1,
                       padding=1),
            # [cnum * 4, 64, 64]
        )

        self.all_conv = nn.Sequential(
            # concatenated input [cnum * 8, 64, 64]
            conv_block(in_channels=8 * c_num,
                       out_channels=4 * c_num,
                       kernel_size=3,
                       stride=1,
                       padding=1),
            # [cnum * 4, 64, 64]
            conv_block(in_channels=4 * c_num,
                       out_channels=4 * c_num,
                       kernel_size=3,
                       stride=1,
                       padding=1),
            # [cnum * 4, 64, 64]
            upconv_block(in_channels=4 * c_num,
                         out_channels=2 * c_num,
                         kernel_size=3,
                         stride=2,
                         padding=1,
                         norm='in',
                         output_padding=1),
            # [cnum * 2, 128, 128]
            conv_block(in_channels=2 * c_num,
                       out_channels=2 * c_num,
                       kernel_size=3,
                       stride=1,
                       padding=1),
            # [cnum * 2, 128, 128]
            upconv_block(in_channels=2 * c_num,
                         out_channels=c_num,
                         kernel_size=3,
                         stride=2,
                         padding=1,
                         norm='in',
                         output_padding=1),
            # [cnum, 256, 256]
            conv_block(in_channels=c_num,
                       out_channels=c_num // 2,
                       kernel_size=3,
                       stride=1,
                       padding=1),
            # [cnum // 2, 256, 256]
            conv_block(in_channels=c_num // 2,
                       out_channels=3,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       activation='tanh'),
        )
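The "# concatenated input [cnum * 8, 64, 64]" comment pins down the wiring: the dilated branch (conv_seq) and the attention branch (before_attn -> contextual attention -> after_attn) run in parallel and are concatenated before all_conv. A forward sketch (the contextual-attention call convention is an assumption):

import torch

def forward(self, x, mask=None):
    x_dil = self.conv_seq(x)
    x_att = self.before_attn(x)
    x_att, _ = self.contextual_attention(x_att, x_att, mask)  # assumed (f, b, mask) signature
    x_att = self.after_attn(x_att)
    # [cnum * 4] + [cnum * 4] -> [cnum * 8] channels at 64x64.
    return self.all_conv(torch.cat([x_dil, x_att], dim=1))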
Example #14
    def __init__(self,
                 in_channels_m=1,
                 in_channels_e=3,
                 device=torch.device('cpu')):
        super(CoarseEncoder, self).__init__()

        self.madf_seq = nn.ModuleList([
            # im - [3, 256, 256]
            # mask - [1, 256, 256]
            B.SimpleMADF(in_channels_m=in_channels_m,
                         out_channels_m=2,
                         in_channels_e=in_channels_e,
                         out_channels_e=16,
                         kernel_size=5,
                         stride=2,
                         padding=2),
            # im - [16, 128, 128]
            # mask - [2, 128, 128]
            B.SimpleMADF(in_channels_m=2,
                         out_channels_m=4,
                         in_channels_e=16,
                         out_channels_e=32,
                         kernel_size=3,
                         stride=2,
                         padding=1),
            # im - [32, 64, 64]
            # mask - [4, 64, 64]
            B.SimpleMADF(in_channels_m=4,
                         out_channels_m=8,
                         in_channels_e=32,
                         out_channels_e=64,
                         kernel_size=3,
                         stride=2,
                         padding=1),
            # im - [64, 32, 32]
            # mask - [8, 32, 32]
            B.SimpleMADF(in_channels_m=8,
                         out_channels_m=16,
                         in_channels_e=64,
                         out_channels_e=128,
                         kernel_size=3,
                         stride=2,
                         padding=1),
            # im - [128, 16, 16]
            # mask - [16, 16, 16]
        ])

        self.up_seq = nn.ModuleList([
            conv_block(in_channels=3,
                       out_channels=16,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       activation='relu',
                       norm='none'),
            conv_block(in_channels=16,
                       out_channels=16,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       activation='relu',
                       norm='none'),
            conv_block(in_channels=32,
                       out_channels=32,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       activation='relu',
                       norm='none'),
            conv_block(in_channels=64,
                       out_channels=64,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       activation='relu',
                       norm='none'),
            conv_block(in_channels=128,
                       out_channels=128,
                       kernel_size=3,
                       stride=1,
                       padding=1,
                       activation='relu',
                       norm='none')
        ])
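The up_seq channel widths (3, 16, 32, 64 and 128 in) line up with the raw image followed by the image features after each MADF stage, suggesting up_seq produces one u feature per scale. A forward sketch under that assumption:

def forward(self, image, mask):
    # Each SimpleMADF stage downsamples the (mask, image) feature pair;
    # up_seq turns the raw image and each stage's image features into the
    # per-scale u features consumed by the decoder (wiring assumed).
    u_feats = [self.up_seq[0](image)]
    m, e = mask, image
    for madf, up in zip(self.madf_seq, self.up_seq[1:]):
        m, e = madf(m, e)
        u_feats.append(up(e))
    return m, e, u_feats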