class ResASPP(object):
    """ResNet-style encoder whose deepest pooled map is passed through an ASPP head.

    ``forward`` builds the graph and returns nothing.  Its results are exposed
    as attributes:

    * ``enc_feat`` -- output of ``self.aspp.enc`` on the final pooled map.
    * ``skip1`` .. ``skip4`` -- the intermediate maps (conv1, pool1, conv2,
      conv3), kept for use as skip connections.
    """

    def __init__(self, aspp_type='aspp'):
        # Layer helpers (conv / maxpool / resblock) come from the project's NnKits.
        self.nn = NnKits()
        # ``aspp_type`` is forwarded unchanged to ASPP; valid values are defined there.
        self.aspp = ASPP(aspp_type)

    def forward(self, input):
        """Build the encoder graph for ``input`` and store its outputs on ``self``.

        Args:
            input: image tensor fed to the first convolution
                (presumably NHWC -- TODO confirm against callers).
        """
        kit = self.nn
        with tf.compat.v1.variable_scope('encoder'):
            # Resolution notes below are the original annotations (H = input
            # height).  They depend on NnKits' stride conventions, which are
            # not visible here.
            # NOTE(review): the original tagged conv2/conv3 as 64D/128D, while
            # Resvgg tags identical resblock calls as 256D/512D.  Confirm the
            # real channel counts against NnKits.resblock.
            stem = kit.conv(input, 64, 7, 2)          # H/2
            stem_pool = kit.maxpool(stem, 3)          # H/4
            stage1 = kit.resblock(stem_pool, 64, 3)   # H/8
            stage2 = kit.resblock(stage1, 128, 4)     # H/16
            stage2_pool = kit.maxpool(stage2, 3)      # H/32
            self.enc_feat = self.aspp.enc(stage2_pool)

        with tf.compat.v1.variable_scope('skips'):
            # Pure aliasing of tensors built above; no new ops are added here.
            self.skip1, self.skip2, self.skip3, self.skip4 = (
                stem, stem_pool, stage1, stage2)


class Resvgg(object):
    """Encoder with two residual stages followed by two plain conv blocks.

    ``forward`` returns nothing.  It stores every stage on ``self``:
    ``conv1``, ``pool1``, ``conv2``, ``conv3``, ``conv4``, and the deepest
    map ``enc_feat``.  It also stores the skip-connection aliases
    ``skip1`` .. ``skip5``, which point at conv1 .. conv4 in build order.
    """

    def __init__(self):
        # Layer helpers (conv / maxpool / resblock / conv_block) from NnKits.
        self.nn = NnKits()

    def forward(self, input):
        """Build the encoder graph for ``input`` and record each stage on ``self``.

        Args:
            input: image tensor fed to the first convolution
                (presumably NHWC -- TODO confirm against callers).
        """
        kit = self.nn
        with tf.compat.v1.variable_scope('encoder'):
            # Resolution/channel notes are the original annotations
            # (H = input height).  They rely on NnKits' stride/expansion
            # conventions -- TODO confirm against NnKits.
            self.conv1 = kit.conv(input, 64, 7, 2)              # H/2  -  64D
            self.pool1 = kit.maxpool(self.conv1, 3)             # H/4  -  64D
            self.conv2 = kit.resblock(self.pool1, 64, 3)        # H/8  - 256D
            self.conv3 = kit.resblock(self.conv2, 128, 4)       # H/16 - 512D
            self.conv4 = kit.conv_block(self.conv3, 256, 3)     # H/32
            self.enc_feat = kit.conv_block(self.conv4, 512, 3)  # H/64

        with tf.compat.v1.variable_scope('skips'):
            # Pure aliasing of stage outputs; no new ops are added here.
            (self.skip1, self.skip2, self.skip3,
             self.skip4, self.skip5) = (self.conv1, self.pool1, self.conv2,
                                        self.conv3, self.conv4)