def __init__(self, upscale_factor, num_blocks=36, **kwargs):
    """Assemble the EDSR layers for one fixed upscaling factor.

    Args:
        upscale_factor: Overall upscaling factor; must be 1 or a multiple
            of 2. It is also stored as the single entry of ``self.scales``.
        num_blocks: Number of ``ResidualBlock`` modules in the backbone.
        **kwargs: Accepted but not used.

    Raises:
        SystemExit: If ``upscale_factor`` is neither 1 nor a multiple of 2.
    """
    super(EDSR, self).__init__()

    if not (upscale_factor % 2 == 0 or upscale_factor == 1):
        error("Upscaling factor must be 1 or a multiple of 2")
        raise SystemExit(1)

    self.num_blocks = num_blocks
    self.upscale_factor = upscale_factor
    self.scales = [self.upscale_factor]

    # Head: project the 3-channel input image into 256 feature channels.
    # Use Conv2d(4, 256, 3) instead to accept an alpha channel as well.
    self.add_module('init_conv', Conv2d(3, 256, 3))

    # Body: num_blocks residual blocks followed by a closing 256->256 conv.
    body = OrderedDict(
        ('resblock_%d' % idx,
         ResidualBlock(block_type.CRC, 'RELU', 256, res_factor=0.1))
        for idx in range(1, self.num_blocks + 1)
    )
    body['final_conv'] = Conv2d(256, 256, 3)
    self.add_module('residual', nn.Sequential(body))

    # Tail: a PixelShuffleUpsampler built for upscale_factor, then a conv
    # mapping the 256 features back to 3 channels.
    # Use Conv2d(256, 4, 3) instead to emit an alpha channel as well.
    self.add_module('upsampler', PixelShuffleUpsampler(upscale_factor, 256))
    tail = OrderedDict([('reconst_conv0', Conv2d(256, 3, 3))])
    self.add_module('reconst', nn.Sequential(tail))
def forward(self, x, scale=None, blend=1):
    """Super-resolve ``x`` by the model's single fixed ``upscale_factor``.

    Pipeline: ``init_conv`` -> ``residual`` (with a global skip connection
    around it) -> ``upsampler`` -> ``reconst``.

    Args:
        x: Input passed to ``self.init_conv``.
        scale: Optional requested upscaling factor. Only ``None`` or a value
            equal to ``self.upscale_factor`` is accepted.
        blend: Not used by this model.

    Returns:
        The output of ``self.reconst``.

    Raises:
        SystemExit: If ``scale`` is given and differs from
            ``self.upscale_factor``.
    """
    if scale is not None and scale != self.upscale_factor:
        error("Invalid upscaling factor: choose one of: {}".format(
            [self.upscale_factor]))
        # Stop here (as the constructor's validation does) instead of
        # silently upscaling by a factor the caller did not ask for.
        raise SystemExit(1)

    init_conv = self.init_conv(x)
    # Global residual connection around the whole backbone.
    output = init_conv + self.residual(init_conv)
    output = self.upsampler(output)
    return self.reconst(output)
def forward(self, x, upscale_factor=None, blend=1.0):
    """Progressively upscale ``x`` through pyramid levels 1..log2(upscale_factor).

    Each level ``s`` applies ``pyramid_residual_<s>`` (plus an identity skip
    when ``self.residual_denseblock`` is set) and then
    ``pyramid_residual_<s>_residual_upsampler``. The result is produced by
    ``reconst_<s>`` at the level where ``2**s == upscale_factor``.

    Args:
        x: Input passed to the initial convolution returned by
            ``self.get_init_conv(log2(upscale_factor))``.
        upscale_factor: Target factor; defaults to ``self.max_scale``. When
            given explicitly it must be ``2**k`` for some ``k`` in
            ``1..self.n_pyramids``.
        blend: When not 1.0 and the final level equals
            ``self.current_scale_idx + 1``, the output becomes
            ``blend * final + (1 - blend) * base_img``. Here ``base_img`` is
            the previous level's ``reconst_<s-1>`` output, upsampled 2x
            bilinearly.

    Returns:
        The reconstruction at the requested scale, blended as described above.

    Raises:
        SystemExit: If an explicit ``upscale_factor`` is not a valid factor.
    """
    if upscale_factor is None:
        upscale_factor = self.max_scale
    else:
        # Valid factors are 2, 4, ..., 2**n_pyramids.
        valid_upscale_factors = [
            2**(i + 1) for i in range(self.n_pyramids)
        ]
        if upscale_factor not in valid_upscale_factors:
            error("Invalid upscaling factor {}: choose one of: {}".format(
                upscale_factor, valid_upscale_factors))
            raise SystemExit(1)
    feats = self.get_init_conv(log2(upscale_factor))(x)
    for s in range(1, int(log2(upscale_factor)) + 1):
        if self.residual_denseblock:
            feats = getattr(self, 'pyramid_residual_%d' % s)(feats) + feats
        else:
            feats = getattr(self, 'pyramid_residual_%d' % s)(feats)
        feats = getattr(self, 'pyramid_residual_%d_residual_upsampler' % s)(feats)
        # Reconstruct an image at this level when either:
        #   - this is the requested scale, or
        #   - blending is on and this is the level just below the requested
        #     scale (its reconstruction becomes the blend base).
        if 2**s == upscale_factor or (blend != 1.0 and 2**(s + 1) == upscale_factor):
            tmp = getattr(self, 'reconst_%d' % s)(feats)
            # Blend base: this level's reconstruction, bilinearly upsampled
            # 2x so it matches the next level's output size.
            # NOTE(review): nn.functional.upsample is deprecated in PyTorch in
            # favour of nn.functional.interpolate (same arguments).
            if (blend != 1.0 and s == self.current_scale_idx):
                base_img = nn.functional.upsample(tmp, scale_factor=2, mode='bilinear', align_corners=True)
            if 2**s == upscale_factor:
                # NOTE(review): base_img is only assigned at s == current_scale_idx
                # and s starts at 1. If current_scale_idx < 1 while blend != 1.0,
                # the line below raises UnboundLocalError; confirm callers
                # never combine those.
                if (blend != 1.0) and s == self.current_scale_idx + 1:
                    tmp = tmp * blend + (1 - blend) * base_img
                output = tmp
    # NOTE(review): output is unbound if no level satisfies
    # 2**s == upscale_factor (e.g. a max_scale that is not a power of 2);
    # presumably max_scale is always 2**n_pyramids — verify.
    return output
def parse_args():
    """Parse and validate the command-line arguments for evaluation.

    ``--input`` and ``--target`` each accept a list of images or a folder.
    Both are expanded with ``get_filenames`` (filtered by ``IMG_EXTENSIONS``).

    Returns:
        The parsed namespace. ``input`` and ``target`` are replaced by the
        resolved filename lists, and ``scale`` is an int.

    Raises:
        SystemExit: If argument parsing fails, no input or target images are
            found, or the two lists have different lengths.
    """
    parser = ArgumentParser(description='Evaluation')
    parser.add_argument(
        '-i', '--input',
        help='High-resolution images, either list or path to folder',
        type=str, nargs='*', required=True, default=[])
    parser.add_argument(
        '-t', '--target',
        help='Super-resolution images, either list or path to folder',
        type=str, nargs='*', required=True, default=[])
    parser.add_argument(
        '-s', '--scale',
        help='upscale ratio e.g. 2, 4 or 8',
        type=int, required=True)
    args = parser.parse_args()

    # Keep what the user passed so error messages can show it: args.input
    # and args.target are overwritten with the resolved file lists below.
    # (Formatting args.input after resolution always printed "[]".)
    input_paths, target_paths = args.input, args.target
    args.input = get_filenames(input_paths, IMG_EXTENSIONS)
    args.target = get_filenames(target_paths, IMG_EXTENSIONS)

    # Exit after each error, as the rest of this module does, rather than
    # returning an empty or mismatched file list to the caller.
    if not args.input:
        error("Did not find images in: {}".format(input_paths))
        raise SystemExit(1)
    if not args.target:
        error("Did not find images in: {}".format(target_paths))
        raise SystemExit(1)
    if len(args.input) != len(args.target):
        error("Inconsistent number of images between 'input' and 'target'")
        raise SystemExit(1)
    return args