def shift_cPSNR(sr, hr, hr_map, border_w=3):
    """
    Registration-tolerant cPSNR.

    `sr` is cropped by `border_w` pixels via `get_patch`. That crop is then
    scored with `cPSNR` against every same-sized window of `hr` / `hr_map`
    whose offset lies in range(2 * border_w + 1) along both axes. The best
    score across all offsets is kept, so small misalignments between the
    super-resolved and ground-truth images are not penalized.

    Args:
        sr: np.ndarray (n, m), super-resolved image
        hr: np.ndarray (n, m), high-res ground-truth image
        hr_map: np.ndarray (n, m), high-res status map
        border_w: int, width of the trimming border around `hr` and `hr_map`
    Returns:
        max_cPSNR: float, score of the super-resolved image (maximum of the
            per-offset scores taken along axis 0)
    """
    # Number of candidate offsets along each axis.
    span = 2 * border_w + 1
    # The patch side is derived from sr.shape[1] only.
    # NOTE(review): this assumes square images — confirm for non-square inputs.
    patch_size = sr.shape[1] - (2 * border_w)
    sr_center = get_patch(img=sr, x=border_w, y=border_w, size=patch_size)

    offsets = list(itertools.product(range(span), range(span)))
    hr_windows = patch_iterator(img=hr, positions=offsets, size=patch_size)
    map_windows = patch_iterator(img=hr_map, positions=offsets, size=patch_size)

    # The progress bar is hidden when the cropped sr is 2-D. It is presumably
    # shown for batched (3-D) input — TODO confirm against callers.
    quiet = len(sr_center.shape) == 2

    scores = []
    for hr_window, map_window in tqdm(zip(hr_windows, map_windows), disable=quiet):
        scores.append(cPSNR(sr_center, hr_window, map_window))

    return np.max(np.array(scores), axis=0)
def patch_iterator(img, positions, size):
    """
    Lazily yield patches of `img`, one per (x, y) pair in `positions`.

    Each patch is produced by `get_patch` with the shared side length `size`.
    Nothing is extracted until the generator is advanced.
    """
    yield from (
        get_patch(img=img, x=px, y=py, size=size)
        for px, py in positions
    )