def synthesized_met(self, request):
    """Build a ``po.synth.Metamer`` on a tiny image and run two iterations.

    ``request.param`` is split on ``'-'`` into an image tag and a model tag
    (presumably supplied by pytest fixture parametrization -- the decorator
    is not visible here).

    * image tag ``'rgb'`` loads ``color_wheel.jpg``; anything else loads
      ``nuts.pgm``.
    * model tag ``'class'`` uses a steerable-pyramid subclass; anything else
      uses a function wrapping ``po.metric.ssim``.

    Returns the ``Metamer`` object after
    ``synthesize(max_iter=2, store_progress=True)`` has been called on it.
    """
    img_tag, model_tag = request.param.split('-')

    if img_tag == 'rgb':
        loaded = po.load_images(op.join(DATA_DIR, 'color_wheel.jpg'), False)
    else:
        loaded = po.load_images(op.join(DATA_DIR, 'nuts.pgm'))
    # keep only a 16x16 corner so that synthesis stays fast
    img = loaded.to(DEVICE)[..., :16, :16]

    if model_tag == 'class':
        # height=1 and order=0 keep the runtime low. forward() hands back
        # only one of the tensors so the plotting code has an easy time
        # (if we downsampled and were on an RGB image, we'd have a tensor
        # of shape [1, 9, h, w] -- the residuals and one filter output for
        # each channel -- and our code doesn't know how to handle that).
        class SPyr(po.simul.Steerable_Pyramid_Freq):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)

            def forward(self, *args, **kwargs):
                return super().forward(*args, **kwargs)[(0, 0)]

        model = SPyr(img.shape[-2:], height=1, order=0).to(DEVICE)
    else:
        # a metric has to return a single value, but SSIM returns a
        # separate value for each RGB channel, so average them
        def rgb_ssim(*args, **kwargs):
            return po.metric.ssim(*args, **kwargs).mean()

        model = rgb_ssim

    met = po.synth.Metamer(img, model)
    met.synthesize(max_iter=2, store_progress=True)
    return met
def test_msssim_analysis(self, msssim_images):
    """Compare ``po.metric.ms_ssim`` against stored reference values.

    ``samp0.tiff`` in the ``msssim_images`` directory is compared with each
    of ``samp0.tiff`` .. ``samp4.tiff`` and the results must match the
    reference values under ``torch.allclose`` default tolerances.
    """
    # Reference values are defined by
    # https://ece.uwaterloo.ca/~z70wang/research/iwssim/msssim.zip
    true_values = torch.tensor(
        [1.0000000, 0.9112161, 0.7699084, 0.8785111, 0.9488805],
        device=DEVICE,
    )
    base_img = po.load_images(op.join(msssim_images, "samp0.tiff")).to(DEVICE)

    # one slot per reference value, filled in by the loop below
    computed_values = torch.zeros_like(true_values)
    n_samples = true_values.shape[0]
    for i in range(n_samples):
        sample_path = op.join(msssim_images, f"samp{i}.tiff")
        other_img = po.load_images(sample_path).to(DEVICE)
        computed_values[i] = po.metric.ms_ssim(base_img, other_img)

    assert torch.allclose(true_values, computed_values)
def img(self, request):
    """Load a ``.pgm`` image from ``DATA_DIR`` and optionally crop it.

    ``request.param`` is split on ``'-'`` into an image name and a shape
    tag (presumably supplied by pytest fixture parametrization -- the
    decorator is not visible here). Recognized shape tags:

    * ``'224'``   -- first 224 entries of each of the last two dimensions
    * ``'128_1'`` -- first 128 entries of the second-to-last dimension
    * ``'128_2'`` -- first 128 entries of the last dimension

    Any other tag returns the image uncropped.
    """
    im, shape = request.param.split('-')
    loaded = po.load_images(op.join(DATA_DIR, f'{im}.pgm')).to(DEVICE)

    # index tuples equivalent to [..., :224, :224], [..., :128, :], [..., :128]
    crops = {
        '224': (Ellipsis, slice(None, 224), slice(None, 224)),
        '128_1': (Ellipsis, slice(None, 128), slice(None)),
        '128_2': (Ellipsis, slice(None, 128)),
    }
    crop = crops.get(shape)
    if crop is None:
        return loaded
    return loaded[crop]
def synthesized_mad(self, request):
    """Build a ``po.synth.MADCompetition`` on a tiny image and run it briefly.

    ``request.param`` selects the image (presumably supplied by pytest
    fixture parametrization -- the decorator is not visible here):
    ``'rgb'`` loads ``color_wheel.jpg``, anything else loads ``nuts.pgm``.

    The two competing models are ``po.simul.Identity`` and a function
    wrapping ``po.metric.ssim``. Returns the ``MADCompetition`` object after
    ``synthesize('model_1_min', max_iter=2, store_progress=True)``.
    """
    if request.param == 'rgb':
        load_args = (op.join(DATA_DIR, 'color_wheel.jpg'), False)
    else:
        load_args = (op.join(DATA_DIR, 'nuts.pgm'),)
    # keep only a 16x16 corner so that synthesis stays fast
    img = po.load_images(*load_args).to(DEVICE)[..., :16, :16]

    model1 = po.simul.Identity().to(DEVICE)

    # a metric has to return a single value, but SSIM returns a separate
    # value for each RGB channel, so average them
    def rgb_ssim(*args, **kwargs):
        return po.metric.ssim(*args, **kwargs).mean()

    model2 = rgb_ssim

    mad = po.synth.MADCompetition(img, model1, model2)
    mad.synthesize('model_1_min', max_iter=2, store_progress=True)
    return mad
def test_ssim_analysis(self, weighted, other_img, ssim_images, ssim_analysis,
                       ssim_base_img):
    """Compare ``po.metric.ssim`` with stored MATLAB results.

    ``weighted`` picks the ``'weighted'`` or ``'standard'`` entry of
    ``ssim_analysis``; ``other_img`` picks which ``samp<N>.tif`` image in
    ``ssim_images`` is compared against ``ssim_base_img``. The two values
    must agree to ``atol=1e-5``.
    """
    label_for = {True: 'weighted', False: 'standard'}
    mat_type = label_for[weighted]

    other_path = op.join(ssim_images, f"samp{other_img}.tif")
    other = po.load_images(other_path).to(DEVICE)

    # dynamic range is 1 for these images, because po.load_images
    # automatically re-ranges them. They were computed with
    # dynamic_range=255 in MATLAB, and by correctly setting this value,
    # that should be corrected for
    plen_val = po.metric.ssim(ssim_base_img, other, weighted)

    matlab_result = ssim_analysis[mat_type][f'samp{other_img}']
    mat_val = torch.tensor(matlab_result.astype(np.float32), device=DEVICE)

    # float32 precision is ~1e-6 (see `np.finfo(np.float32)`), and the
    # errors increase through multiplication and other operations.
    print(plen_val - mat_val, plen_val, mat_val)
    assert torch.allclose(plen_val, mat_val.view_as(plen_val), atol=1e-5)
def color_img():
    """Return ``color_wheel.jpg`` on ``DEVICE``, cropped to at most 256x256.

    The image is loaded with ``as_gray=False`` and the last two dimensions
    are sliced with ``:256``.
    """
    wheel_path = op.join(DATA_DIR, 'color_wheel.jpg')
    wheel = po.load_images(wheel_path, as_gray=False)
    wheel = wheel.to(DEVICE)
    return wheel[..., :256, :256]
def einstein_img():
    """Load ``einstein.pgm`` from ``DATA_DIR`` and move it to ``DEVICE``."""
    einstein_path = op.join(DATA_DIR, 'einstein.pgm')
    einstein = po.load_images(einstein_path)
    return einstein.to(DEVICE)
def curie_img():
    """Load ``curie.pgm`` from ``DATA_DIR`` and move it to ``DEVICE``."""
    curie_path = op.join(DATA_DIR, 'curie.pgm')
    curie = po.load_images(curie_path)
    return curie.to(DEVICE)
def ssim_base_img(self, ssim_images, ssim_analysis):
    """Load the base image for the SSIM comparison onto ``DEVICE``.

    The file name is read from ``ssim_analysis['base_img']`` and joined
    onto the ``ssim_images`` directory.
    """
    base_name = ssim_analysis['base_img']
    base_path = op.join(ssim_images, base_name)
    base = po.load_images(base_path)
    return base.to(DEVICE)