Example #1
0
    def __init__(self, **kwargs):
        """Construct the cascading-residual network from keyword options.

        Keyword Args:
            channel_cnt: feature width used by every internal layer
                (defaults to 64).
            scale: forwarded as ``scale`` to ``ops.UpsampleBlock``. No default
                is supplied, so it is ``None`` when omitted.
            multi_scale: forwarded as ``multi_scale`` to ``ops.UpsampleBlock``
                (``None`` when omitted).
            group: forwarded as ``group`` to ``ops.UpsampleBlock``
                (defaults to 1).
        """
        super(Net, self).__init__()

        width = kwargs.get("channel_cnt", 64)
        up_scale = kwargs.get("scale")
        up_multi = kwargs.get("multi_scale")
        up_group = kwargs.get("group", 1)

        # Same per-channel mean for both shift layers: one subtracts it, one adds it back.
        rgb_mean = (0.4488, 0.4371, 0.4040)
        self.sub_mean = ops.MeanShift(rgb_mean, sub=True)
        self.add_mean = ops.MeanShift(rgb_mean, sub=False)

        # 3 input channels -> `width` features.
        self.entry = nn.Conv2d(3, width, kernel_size=3, stride=1, padding=1)

        # Construction order b1, b2, b3 is preserved (affects RNG-driven init
        # and submodule registration order).
        self.b1, self.b2, self.b3 = (Block(width, width) for _ in range(3))

        # 1x1 fusion layers whose input width grows by `width` at each stage
        # (presumably concatenated earlier features — forward() not shown).
        self.c1 = ops.BasicBlock(width * 2, width, 1, 1, 0)
        self.c2 = ops.BasicBlock(width * 3, width, 1, 1, 0)
        self.c3 = ops.BasicBlock(width * 4, width, 1, 1, 0)

        self.upsample = ops.UpsampleBlock(
            width,
            scale=up_scale,
            multi_scale=up_multi,
            group=up_group,
        )
        # `width` features -> 3 output channels.
        self.exit = nn.Conv2d(width, 3, kernel_size=3, stride=1, padding=1)
    def __init__(self, upscale=4, channel_cnt=64, multi_scale=True, group=1):
        """Build CARN with a BlancedAttention module after each fusion layer.

        All new parameters default to the values that used to be hard-coded,
        so ``CARN_blanced_attention()`` / ``CARN_blanced_attention(upscale)``
        build exactly the same network as before.

        Args:
            upscale: stored as ``self.scale`` and forwarded as ``scale`` to
                ``ops.UpsampleBlock``.
            channel_cnt: feature width of every internal layer (default 64),
                matching the ``channel_cnt`` option of ``Net``.
            multi_scale: forwarded to ``ops.UpsampleBlock`` (default True).
            group: forwarded to ``ops.UpsampleBlock`` (default 1).
        """
        super(CARN_blanced_attention, self).__init__()

        self.scale = upscale

        # Same per-channel mean for both shift layers: one subtracts it, one adds it back.
        self.sub_mean = ops.MeanShift((0.4488, 0.4371, 0.4040), sub=True)
        self.add_mean = ops.MeanShift((0.4488, 0.4371, 0.4040), sub=False)

        # 3 input channels -> `channel_cnt` features.
        self.entry = nn.Conv2d(3, channel_cnt, 3, 1, 1)

        self.b1 = Block(channel_cnt, channel_cnt)
        self.b2 = Block(channel_cnt, channel_cnt)
        self.b3 = Block(channel_cnt, channel_cnt)

        # 1x1 fusion layers (input width grows by `channel_cnt` per stage),
        # each followed by an attention module.
        # NOTE(review): assumes BlancedAttention's first argument is its channel
        # count (original passed 64 alongside 64-wide features) — confirm any
        # constraints it places on non-default widths.
        self.c1 = nn.Sequential(
            ops.BasicBlock(channel_cnt * 2, channel_cnt, 1, 1, 0),
            BlancedAttention(channel_cnt),
        )
        self.c2 = nn.Sequential(
            ops.BasicBlock(channel_cnt * 3, channel_cnt, 1, 1, 0),
            BlancedAttention(channel_cnt),
        )
        self.c3 = nn.Sequential(
            ops.BasicBlock(channel_cnt * 4, channel_cnt, 1, 1, 0),
            BlancedAttention(channel_cnt),
        )

        self.upsample = ops.UpsampleBlock(channel_cnt, scale=upscale,
                                          multi_scale=multi_scale,
                                          group=group)
        # `channel_cnt` features -> 3 output channels.
        self.exit = nn.Conv2d(channel_cnt, 3, 3, 1, 1)