def write_image_maml(image_resolution, mask, model, model_input, gt, model_output, writer, total_steps,
                     inner=False, prefix='train_'):
    """Log a ground-truth vs. prediction grid plus PSNR and min/max scalars for a MAML step.

    Args:
        image_resolution: Passed to ``dataio.lin2img`` to fold flat per-pixel tensors into images.
        mask: Optional tensor multiplied into the ground-truth image; ``None`` disables masking.
        model: Unused; kept for a uniform summary-function signature.
        model_input: Dict holding at least ``'coords'``.
        gt: Dict holding at least ``'img'``.
        model_output: Dict holding at least ``'model_out'``.
        writer: Summary writer exposing ``add_image``.
        total_steps: Global step used for every logged summary.
        inner: Selects the ``'inner_'`` (True) or ``'outer_'`` (False) tag marker.
        prefix: Tag prefix placed after the inner/outer marker. With the default
            ``'train_'`` the tags start with ``'inner_train_'`` / ``'outer_train_'``.
    """
    gt_img = dataio.lin2img(gt['img'], image_resolution)
    if mask is not None:
        gt_img = gt_img * mask
    pred_img = dataio.lin2img(model_output['model_out'], image_resolution)

    # Compose the marker with the caller's prefix instead of discarding the argument.
    prefix = ('inner_' if inner else 'outer_') + prefix

    output_vs_gt = torch.cat((gt_img, pred_img), dim=-1)
    writer.add_image(prefix + 'gt_vs_pred', make_grid(output_vs_gt, scale_each=False, normalize=True),
                     global_step=total_steps)

    write_psnr(pred_img, gt_img, writer, total_steps, prefix + 'img_dense_')
    min_max_summary(prefix + 'coords', model_input['coords'], writer, total_steps)
    min_max_summary(prefix + 'pred_img', pred_img, writer, total_steps)
    min_max_summary(prefix + 'gt_img', gt_img, writer, total_steps)
def write_image_summary_small(image_resolution, mask, model, model_input, gt, model_output, writer, total_steps,
                              prefix='train_'):
    """Write gradient-norm, gt-vs-pred, PSNR, min/max and hypernet-activation summaries.

    The ground-truth image is optionally multiplied by ``mask``. The gradient of
    ``model_output['model_out']`` w.r.t. ``model_output['model_in']`` is taken with
    ``retain_graph=True`` so the forward graph remains usable after this call.
    """
    gt_img = dataio.lin2img(gt['img'], image_resolution)
    if mask is not None:
        gt_img = gt_img * mask
    pred_img = dataio.lin2img(model_output['model_out'], image_resolution)

    out, inp = model_output['model_out'], model_output['model_in']
    with torch.no_grad():
        grads = torch.autograd.grad(out, [inp], grad_outputs=torch.ones_like(out),
                                    create_graph=True, retain_graph=True)
        norm_img = dataio.lin2img(grads[0].norm(dim=-1, keepdim=True), image_resolution)
        writer.add_image(prefix + 'pred_grad_norm', make_grid(norm_img, scale_each=False, normalize=True),
                         global_step=total_steps)

    side_by_side = torch.cat((gt_img, pred_img), dim=-1)
    writer.add_image(prefix + 'gt_vs_pred', make_grid(side_by_side, scale_each=False, normalize=True),
                     global_step=total_steps)

    write_psnr(pred_img, gt_img, writer, total_steps, prefix + 'img_dense_')
    for tag, value in (('coords', model_input['coords']), ('pred_img', pred_img), ('gt_img', gt_img)):
        min_max_summary(prefix + tag, value, writer, total_steps)
    hypernet_activation_summary(model, model_input, gt, model_output, writer, total_steps, prefix)
def getTestMSE(dataloader, subdir):
    """Evaluate the model on ``dataloader``, save PNGs, and return the per-sample MSEs.

    For sample ``i`` this writes ``<root_path>/<subdir>/<i>_sparse.png`` (the sparse
    input, with pixels that are exactly zero in all three channels painted white),
    ``<root_path>/<subdir>/<i>.png`` (prediction) and ``<root_path>/ground_truth/<i>.png``
    (target). Every image is mapped with ``(x + 1) / 2`` and clipped to [0, 1]
    before saving; the MSE is computed on those clipped images.

    Relies on the module-level names ``root_path``, ``model``, ``image_resolution``
    and ``to_uint8``.
    """
    pred_dir = os.path.join(root_path, subdir)
    gt_dir = os.path.join(root_path, 'ground_truth')
    utils.cond_mkdir(pred_dir)
    utils.cond_mkdir(gt_dir)

    def to_unit_hwc(flat):
        # Fold to an image, move channels last, then shift/scale and clip to [0, 1].
        img = dataio.lin2img(flat, image_resolution).squeeze().permute(1, 2, 0).detach().cpu().numpy()
        img += 1
        img /= 2.
        return np.clip(img, 0., 1.)

    errors = []
    with tqdm(total=len(dataloader)) as pbar:
        for sample_idx, (model_input, gt) in enumerate(dataloader):
            model_input['idx'] = torch.Tensor([model_input['idx']]).long()
            model_input = {k: v.cuda() for k, v in model_input.items()}
            gt = {k: v.cuda() for k, v in gt.items()}

            with torch.no_grad():
                model_output = model(model_input)

            out_img = to_unit_hwc(model_output['model_out'])
            gt_img = to_unit_hwc(gt['img'])

            sparse_img = model_input['img_sparse'].squeeze().detach().cpu().permute(1, 2, 0).numpy()
            # Unobserved pixels are all-zero before rescaling; remember them to paint white.
            empty = (sparse_img == 0).sum(axis=2) == 3
            sparse_img += 1
            sparse_img /= 2.
            sparse_img = np.clip(sparse_img, 0., 1.)
            sparse_img[empty, ...] = 1.

            stem = str(sample_idx)
            imageio.imwrite(os.path.join(pred_dir, stem + '_sparse.png'), to_uint8(sparse_img))
            imageio.imwrite(os.path.join(pred_dir, stem + '.png'), to_uint8(out_img))
            imageio.imwrite(os.path.join(gt_dir, stem + '.png'), to_uint8(gt_img))

            errors.append(np.mean((out_img - gt_img) ** 2))
            pbar.update(1)
    return errors
def write_sdf_summary(model, model_input, gt, model_output, writer, total_steps, prefix='train_'):
    """Plot SDF contour images on three axis-aligned slices and log min/max scalars.

    Slices (all built from ``dataio.get_mgrid(512)`` and evaluated under
    ``torch.no_grad()``): first coordinate fixed at 0 (``yz``), second fixed at 0
    (``xz``), third fixed at -0.75 (``xy``). ``gt`` is unused.
    """
    grid = dataio.get_mgrid(512)

    with torch.no_grad():
        first_col = grid[:, :1]
        last_col = grid[:, -1:]
        zero_col = torch.zeros_like(first_col)
        slices = (
            ('yz_sdf_slice', torch.cat((zero_col, grid), dim=-1)),
            ('xz_sdf_slice', torch.cat((first_col, zero_col, last_col), dim=-1)),
            ('xy_sdf_slice', torch.cat((grid[:, :2], -0.75 * torch.ones_like(first_col)), dim=-1)),
        )
        for tag, slice_coords in slices:
            sdf = model({'coords': slice_coords.cuda()[None, ...]})['model_out']
            sdf = dataio.lin2img(sdf).squeeze().cpu().numpy()
            writer.add_figure(prefix + tag, make_contour_plot(sdf), global_step=total_steps)

        min_max_summary(prefix + 'model_out_min_max', model_output['model_out'], writer, total_steps)
        min_max_summary(prefix + 'coords', model_input['coords'], writer, total_steps)
def write_image_summary(image_resolution, model, model_input, gt, model_output, writer, total_steps,
                        prefix='train_'):
    """Log prediction/target images with their gradients, Laplacians and PSNR.

    Writes a side-by-side ``gt_vs_pred`` grid, then ``pred_img`` / ``pred_grad`` /
    ``pred_lapl`` and ``gt_img`` / ``gt_grad`` / ``gt_lapl`` images, and finally the
    PSNR via ``write_psnr``. Prediction gradients and Laplacians are taken w.r.t.
    ``model_output['model_in']`` with ``diff_operators``; Laplacians are rendered
    through the ``RdBu`` colormap.

    Args:
        image_resolution: Passed to every ``dataio.lin2img`` call so all flat
            per-pixel tensors (image, gradients, Laplacians) fold to the same shape.
        model: Unused; kept for a uniform summary-function signature.
        model_input: Unused.
        gt: Dict with ``'img'``, ``'gradients'`` and ``'laplace'``.
        model_output: Dict with ``'model_out'`` and ``'model_in'``.
        writer: Summary writer exposing ``add_image``.
        total_steps: Global step for every summary.
        prefix: Tag prefix.
    """

    def to_hwc_image(img):
        # Shift with (x + 1) / 2, clamp via rescale_img, and return a channels-last array.
        return (dataio.rescale_img((img + 1) / 2, mode='clamp')
                .permute(0, 2, 3, 1).squeeze(0).detach().cpu().numpy())

    def to_hwc_gradients(flat_grads):
        # Fold with the same resolution as the images, then visualize with grads2img.
        return (dataio.grads2img(dataio.lin2img(flat_grads, image_resolution))
                .permute(1, 2, 0).squeeze().detach().cpu().numpy())

    def to_rgb_laplacian(flat_lapl):
        # Percentile-rescale, convert to uint8, apply the RdBu colormap (OpenCV yields BGR).
        scaled = dataio.rescale_img(dataio.lin2img(flat_lapl, image_resolution), perc=2)
        as_uint8 = dataio.to_uint8(scaled.permute(0, 2, 3, 1).squeeze(0).detach().cpu().numpy())
        return cv2.cvtColor(cv2.applyColorMap(as_uint8, cmapy.cmap('RdBu')), cv2.COLOR_BGR2RGB)

    pred_flat = model_output['model_out']
    coords = model_output['model_in']

    gt_img = dataio.lin2img(gt['img'], image_resolution)
    pred_img = dataio.lin2img(pred_flat, image_resolution)
    img_gradient = diff_operators.gradient(pred_flat, coords)
    img_laplace = diff_operators.laplace(pred_flat, coords)

    output_vs_gt = torch.cat((gt_img, pred_img), dim=-1)
    writer.add_image(prefix + 'gt_vs_pred', make_grid(output_vs_gt, scale_each=False, normalize=True),
                     global_step=total_steps)

    hwc_images = (
        ('pred_img', to_hwc_image(pred_img)),
        ('pred_grad', to_hwc_gradients(img_gradient)),
        ('pred_lapl', to_rgb_laplacian(img_laplace)),
        ('gt_img', to_hwc_image(gt_img)),
        ('gt_grad', to_hwc_gradients(gt['gradients'])),
        ('gt_lapl', to_rgb_laplacian(gt['laplace'])),
    )
    for tag, array in hwc_images:
        writer.add_image(prefix + tag, torch.from_numpy(array).permute(2, 0, 1), global_step=total_steps)

    # The folded tensors above are reused; recomputing lin2img would give the same values.
    write_psnr(pred_img, gt_img, writer, total_steps, prefix + 'img_')
def write_summaries(model_output, model_input, gt, writer, total_steps, prefix):
    """Log min/max scalars of the ground-truth and predicted ``sds`` values, then flush.

    Args:
        model_output: Predicted values in the flat layout accepted by ``dataio.lin2img``.
        model_input: Unused; kept for a uniform summary-function signature.
        gt: Dict holding ``'sds'`` in the same flat layout.
        writer: Summary writer exposing ``add_scalar`` and ``flush``.
        total_steps: Global step for every scalar.
        prefix: Tag prefix.
    """
    gt_sds = dataio.lin2img(gt['sds']).squeeze().cpu()
    pred_sds = dataio.lin2img(model_output).squeeze().detach().cpu()

    scalars = (
        ('gt_min', gt_sds.min()),
        ('gt_max', gt_sds.max()),
        ('pred_min', pred_sds.min()),
        ('pred_max', pred_sds.max()),
        # NOTE: despite the tag names, these two track gt['sds'], not coordinates.
        ('dense_coords_min', gt['sds'].min()),
        ('dense_coords_max', gt['sds'].max()),
    )
    for tag, value in scalars:
        writer.add_scalar(prefix + tag, value.detach().cpu().numpy(), total_steps)
    writer.flush()
def write_meta_summaries(model_output, meta_batch, writer, total_steps, prefix):
    """Log level-set scatter plots, a gt-vs-pred image and min/max scalars for a meta batch.

    Args:
        model_output: Predicted values in the flat layout accepted by ``dataio.lin2img``.
        meta_batch: Dict with ``'train'`` and ``'test'`` entries indexed as ``[0]`` and
            ``[1]``. ``meta_batch['train'][0]`` is scattered as 2-D points;
            ``meta_batch['train'][1]`` is compared against 0 to decide whether any
            level-set points exist; ``meta_batch['test'][1]`` is the ground truth image.
        writer: Summary writer exposing ``add_figure``, ``add_image``, ``add_scalar``.
        total_steps: Global step for every summary.
        prefix: Tag prefix.
    """
    gt_sds = dataio.lin2img(meta_batch['test'][1]).squeeze().cpu()
    pred_sds = dataio.lin2img(model_output).squeeze().detach().cpu()

    train_points, train_values = meta_batch['train'][0], meta_batch['train'][1]
    if (train_values == 0.).any():
        batch_size = train_points.shape[0]
        n_cols = 8
        # Ceil division so partial rows get axes; at least one row.
        n_rows = max(1, -(-batch_size // n_cols))
        # squeeze=False keeps ``axes`` 2-D even when there is a single row.
        fig, axes = plt.subplots(n_rows, n_cols, squeeze=False)
        levelset_points = train_points.detach().cpu().numpy()
        for i in range(batch_size):
            # NOTE(review): ``.shape[0]`` is the total number of entries for sample i,
            # not the number of zero-valued ones, so every point is plotted — confirm intent.
            num_level_set_points = (train_values[i] == 0.).shape[0]
            digit = levelset_points[i, :num_level_set_points, :]
            ax = axes[i // n_cols, i % n_cols]
            ax.scatter(digit[:, 1], -digit[:, 0])
            ax.axis('off')
        # Hide the unused cells of a partially filled last row.
        for ax in axes.flat[batch_size:]:
            ax.axis('off')
        plt.tight_layout()
        writer.add_figure(prefix + 'levelset', fig, global_step=total_steps)

    output_vs_gt = torch.cat((gt_sds, pred_sds), dim=-1)[:, None, ...]
    writer.add_image(prefix + 'gt_vs_pred', make_grid(output_vs_gt, scale_each=False, normalize=True),
                     global_step=total_steps)

    dense_values = meta_batch['test'][0]
    scalars = (
        ('gt_min', gt_sds.min()),
        ('gt_max', gt_sds.max()),
        ('pred_min', pred_sds.min()),
        ('pred_max', pred_sds.max()),
        ('dense_coords_min', dense_values.min()),
        ('dense_coords_max', dense_values.max()),
    )
    for tag, value in scalars:
        writer.add_scalar(prefix + tag, value.detach().cpu().numpy(), total_steps)
def write_laplace_summary(model, model_input, gt, model_output, writer, total_steps, prefix='train_'):
    """Log images and min/max scalars for a model fitted to an image Laplacian.

    Logs the gt-vs-pred image and Laplacian comparisons, the predicted gradient
    (computed once), ground-truth and predicted images, and ``RdBu``-colormapped
    Laplacians.

    Args:
        model: Unused; kept for a uniform summary-function signature.
        model_input: Dict holding ``'coords'``.
        gt: Dict holding ``'img'`` and ``'laplace'``.
        model_output: Dict holding ``'model_out'`` and ``'model_in'``.
        writer: Summary writer exposing ``add_image``.
        total_steps: Global step for every summary.
        prefix: Tag prefix.
    """

    def log_grid(tag, img):
        writer.add_image(prefix + tag, make_grid(img, scale_each=False, normalize=True),
                         global_step=total_steps)

    def colorized_laplacian(lapl_img):
        # Rescale, convert to uint8, apply RdBu (OpenCV returns BGR), return channels-first.
        arr = dataio.to_uint8(dataio.to_numpy(dataio.rescale_img(lapl_img, 'scale', 1)))
        arr = cv2.applyColorMap(arr.squeeze(), cmapy.cmap('RdBu'))
        arr = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)
        return torch.from_numpy(arr).permute(2, 0, 1)

    pred_flat = model_output['model_out']
    coords = model_output['model_in']

    # Comparison images
    gt_img = dataio.lin2img(gt['img'])
    pred_img = dataio.lin2img(pred_flat)
    log_grid('comp_gt_vs_pred',
             torch.cat((dataio.rescale_img(gt_img), dataio.rescale_img(pred_img, perc=1e-2)), dim=-1))

    # Laplacian comparison (the quantity the model is fitted to)
    gt_laplace = dataio.lin2img(gt['laplace'])
    pred_laplace = dataio.lin2img(diff_operators.laplace(pred_flat, coords))
    log_grid('comp_gt_vs_pred_laplace', torch.cat((gt_laplace, pred_laplace), dim=-1))

    # Predicted image gradient; computed and logged once.
    log_grid('pred_grad', dataio.grads2img(dataio.lin2img(diff_operators.gradient(pred_flat, coords))))

    log_grid('gt_img', gt_img)
    writer.add_image(prefix + 'gt_lapl', colorized_laplacian(gt_laplace), global_step=total_steps)
    log_grid('pred_img', pred_img)
    writer.add_image(prefix + 'pred_lapl', colorized_laplacian(pred_laplace), global_step=total_steps)

    min_max_summary(prefix + 'coords', model_input['coords'], writer, total_steps)
    min_max_summary(prefix + 'gt_laplace', gt_laplace, writer, total_steps)
    min_max_summary(prefix + 'pred_laplace', pred_laplace, writer, total_steps)
    min_max_summary(prefix + 'pred_img', pred_img, writer, total_steps)
    min_max_summary(prefix + 'gt_img', gt_img, writer, total_steps)
def write_gradcomp_summary(model, model_input, gt, model_output, writer, total_steps, prefix='train_'):
    """Log images and min/max scalars for a model fitted to a composite gradient field.

    Args:
        model: Unused; kept for a uniform summary-function signature.
        model_input: Dict holding ``'coords'``.
        gt: Dict holding ``'gradients'`` (the fitted target), ``'grads1'``, ``'grads2'``,
            ``'img1'`` and ``'img2'``.
        model_output: Dict holding ``'model_out'`` and ``'model_in'``.
        writer: Summary writer exposing ``add_image``.
        total_steps: Global step for every summary.
        prefix: Tag prefix.
    """

    def log_grid(tag, img):
        writer.add_image(prefix + tag, make_grid(img, scale_each=False, normalize=True),
                         global_step=total_steps)

    def grads_to_img(flat_grads):
        return dataio.grads2img(dataio.lin2img(flat_grads))

    # Target vs. predicted gradients (the target is what the model is fitted to)
    gt_gradients = gt['gradients']
    gt_grads_img = grads_to_img(gt_gradients)
    pred_grads_img = grads_to_img(diff_operators.gradient(model_output['model_out'], model_output['model_in']))
    log_grid('comp_gt_vs_pred_gradients', torch.cat((gt_grads_img, pred_grads_img), dim=-1))

    # Source gradients and their composite
    log_grid('gt_grads1', grads_to_img(gt['grads1']))
    log_grid('gt_grads2', grads_to_img(gt['grads2']))
    log_grid('gt_gradcomp', gt_grads_img)
    log_grid('pred_gradcomp', pred_grads_img)

    # Source images
    log_grid('gt_img1', dataio.lin2img(gt['img1']))
    log_grid('gt_img2', dataio.lin2img(gt['img2']))

    # Predicted composite image
    pred_img = dataio.rescale_img(dataio.lin2img(model_output['model_out']))
    log_grid('pred_comp_img', pred_img)

    min_max_summary(prefix + 'coords', model_input['coords'], writer, total_steps)
    # Tag names what is logged; it was mislabeled 'gt_laplace' (copied from the Laplace summary).
    min_max_summary(prefix + 'gt_gradients', gt_gradients, writer, total_steps)
    min_max_summary(prefix + 'pred_img', pred_img, writer, total_steps)
def write_helmholtz_summary(model, model_input, gt, model_output, writer, total_steps, prefix='train_'):
    """Log the predicted complex field, optional ground-truth field, and velocity images.

    The model is re-evaluated under ``torch.no_grad()`` on ``dataio.get_mgrid(256)``
    coordinates moved to CUDA. When ``model_input`` contains ``'coords_sub'``, at most
    the first two conditioning sets (``'coords_sub'`` / ``'img_sub'``) are used.
    Even output channels are treated as real parts and odd channels as imaginary parts.

    Side effect: when ``gt`` contains ``'pretrain'``, ``gt['squared_slowness_grid']`` is
    overwritten with the model's last output channel plus one, and that channel is
    dropped from the plotted field. If ``gt['pretrain']`` is all -1, the channel is
    first clamped to >= -0.999 and the value is forced to 1 wherever either of the
    first two coordinates exceeds 0.75 in magnitude.

    ``gt['squared_slowness_grid']`` must be present (it is read unconditionally).
    ``model_output`` is only forwarded to ``hypernet_activation_summary``.
    """
    sl = 256
    coords = dataio.get_mgrid(sl)[None, ...].cuda()

    def scale_percentile(pred, min_perc=1, max_perc=99):
        """Clip ``pred`` to its [min_perc, max_perc] percentiles and map that range to [0, 1]."""
        values = pred.cpu().numpy()
        lo = np.percentile(values, min_perc)
        hi = np.percentile(values, max_perc)
        pred = torch.clamp(pred, lo, hi)
        if hi <= lo:
            # Constant input: return zeros rather than 0/0 NaNs.
            return torch.zeros_like(pred)
        return (pred - lo) / (hi - lo)

    def log_grid(tag, img):
        writer.add_image(prefix + tag, make_grid(img, scale_each=False, normalize=True),
                         global_step=total_steps)

    with torch.no_grad():
        if 'coords_sub' in model_input:
            summary_model_input = {
                'coords': coords.repeat(min(2, model_input['coords_sub'].shape[0]), 1, 1),
                'coords_sub': model_input['coords_sub'][:2, ...],
                'img_sub': model_input['img_sub'][:2, ...],
            }
            pred = model(summary_model_input)['model_out']
        else:
            pred = model({'coords': coords})['model_out']

        if 'pretrain' in gt:
            gt['squared_slowness_grid'] = pred[..., -1, None].clone() + 1.
            if torch.all(gt['pretrain'] == -1):
                gt['squared_slowness_grid'] = torch.clamp(pred[..., -1, None].clone(), min=-0.999) + 1.
                border = (torch.abs(coords[..., 0, None]) > 0.75) | (torch.abs(coords[..., 1, None]) > 0.75)
                gt['squared_slowness_grid'] = torch.where(
                    border, torch.ones_like(gt['squared_slowness_grid']), gt['squared_slowness_grid'])
            pred = pred[..., :-1]

        pred = dataio.lin2img(pred)
        real = pred[..., 0::2, :, :]
        imag = pred[..., 1::2, :, :]
        pred_cmpl = real.cpu().numpy() + 1j * imag.cpu().numpy()
        pred_angle = torch.from_numpy(np.angle(pred_cmpl))
        pred_mag = torch.from_numpy(np.abs(pred_cmpl))

        min_max_summary(prefix + 'coords', model_input['coords'], writer, total_steps)
        min_max_summary(prefix + 'pred_real', real, writer, total_steps)
        min_max_summary(prefix + 'pred_abs', torch.sqrt(real ** 2 + imag ** 2), writer, total_steps)
        min_max_summary(prefix + 'squared_slowness', gt['squared_slowness_grid'], writer, total_steps)

        pred = scale_percentile(pred).permute(1, 0, 2, 3)
        pred_angle = scale_percentile(pred_angle).permute(1, 0, 2, 3)
        pred_mag = scale_percentile(pred_mag).permute(1, 0, 2, 3)

    log_grid('pred_real', pred[0::2, :, :, :])
    log_grid('pred_imaginary', pred[1::2, :, :, :])
    log_grid('pred_angle', pred_angle)
    log_grid('pred_mag', pred_mag)

    if 'gt' in gt:
        gt_field = dataio.lin2img(gt['gt'])
        gt_field_cmpl = gt_field[..., 0, :, :].cpu().numpy() + 1j * gt_field[..., 1, :, :].cpu().numpy()
        gt_angle = scale_percentile(torch.from_numpy(np.angle(gt_field_cmpl)))
        gt_mag = scale_percentile(torch.from_numpy(np.abs(gt_field_cmpl)))
        gt_field = scale_percentile(gt_field)

        log_grid('gt_real', gt_field[..., 0, :, :])
        log_grid('gt_imaginary', gt_field[..., 1, :, :])
        log_grid('gt_angle', gt_angle)
        log_grid('gt_mag', gt_mag)
        # Logged after percentile scaling, as before.
        min_max_summary(prefix + 'gt_real', gt_field[..., 0, :, :], writer, total_steps)

    velocity = torch.sqrt(1 / dataio.lin2img(gt['squared_slowness_grid']))[:1]
    min_max_summary(prefix + 'velocity', velocity[..., 0, :, :], writer, total_steps)
    velocity = scale_percentile(velocity)
    log_grid('velocity', velocity[..., 0, :, :])

    # NOTE(review): the key was already read unconditionally above, so this guard is always true.
    if 'squared_slowness_grid' in gt:
        log_grid('squared_slowness', dataio.lin2img(gt['squared_slowness_grid'])[:2, :1])

    if 'img_sub' in model_input:
        log_grid('img', dataio.lin2img(model_input['img_sub'])[:2, :1])

    if isinstance(model, meta_modules.NeuralProcessImplicit2DHypernetBVP):
        hypernet_activation_summary(model, model_input, gt, model_output, writer, total_steps, prefix)
root_path = os.path.join(opt.logging_root, opt.experiment_name) utils.cond_mkdir(root_path) # Load checkpoint model.load_state_dict(torch.load(opt.checkpoint_path)) # First experiment: Upsample training image model_input = { 'coords': dataio.get_mgrid(image_resolution)[None, :].cuda(), 'img_sparse': coord_dataset_train[0][0]['img_sparse'].unsqueeze(0).cuda() } model_output = model(model_input) out_img = dataio.lin2img(model_output['model_out'], image_resolution).squeeze().permute( 1, 2, 0).detach().cpu().numpy() out_img += 1 out_img /= 2. out_img = np.clip(out_img, 0., 1.) imageio.imwrite(os.path.join(root_path, 'upsampled_train.png'), out_img) # Second experiment: sample larger range model_input = { 'coords': dataio.get_mgrid(image_resolution)[None, :].cuda() * 5, 'img_sparse': coord_dataset_train[0][0]['img_sparse'].unsqueeze(0).cuda() } model_output = model(model_input) out_img = dataio.lin2img(model_output['model_out'],
def getTestMSE(dataloader, subdir):
    """Evaluate the model on ``dataloader``, save PNGs, and return per-sample MSEs and PSNRs.

    For sample ``i`` this writes ``<root_path>/<subdir>/<i>_sparse.png`` (a white canvas
    with each context pixel from ``'coords_sub'`` / ``'img_sub'`` painted in),
    ``<root_path>/<subdir>/<i>.png`` (prediction) and ``<root_path>/ground_truth/<i>.png``
    (target). Images are mapped with ``(x + 1) / 2`` and clipped to [0, 1] before
    saving and before the metrics are computed.

    Relies on the module-level names ``root_path``, ``model``, ``image_resolution``
    and ``to_uint8``.

    Returns:
        ``(MSEs, PSNRs)``: two lists with one entry per sample. PSNR uses a data
        range of 1.
    """
    pred_dir = os.path.join(root_path, subdir)
    gt_dir = os.path.join(root_path, 'ground_truth')
    utils.cond_mkdir(pred_dir)
    utils.cond_mkdir(gt_dir)

    height, width = image_resolution[0], image_resolution[1]

    def to_unit_hwc(flat):
        # Fold to an image, move channels last, then shift/scale and clip to [0, 1].
        img = dataio.lin2img(flat, image_resolution).squeeze().permute(1, 2, 0).detach().cpu().numpy()
        img += 1
        img /= 2.
        return np.clip(img, 0., 1.)

    mses = []
    psnrs = []
    with tqdm(total=len(dataloader)) as pbar:
        for step, (model_input, gt) in enumerate(dataloader):
            model_input['idx'] = torch.Tensor([model_input['idx']]).long()
            model_input = {key: value.cuda() for key, value in model_input.items()}
            gt = {key: value.cuda() for key, value in gt.items()}

            with torch.no_grad():
                model_output = model(model_input)

            out_img = to_unit_hwc(model_output['model_out'])
            gt_img = to_unit_hwc(gt['img'])

            sparse_img = np.ones((height, width, 3))
            coords_sub = model_input['coords_sub'].squeeze().detach().cpu().numpy()
            rgb_sub = model_input['img_sub'].squeeze().detach().cpu().numpy()
            for coord, rgb in zip(coords_sub, rgb_sub):
                # Map coordinates in [-1, 1] onto pixel indices of the actual resolution
                # (previously hard-coded to a 32x32 grid via "* 31").
                r = int(round((coord[0] + 1) / 2 * (height - 1)))
                c = int(round((coord[1] + 1) / 2 * (width - 1)))
                sparse_img[r, c, :] = np.clip((rgb + 1) / 2, 0., 1.)

            stem = str(step)
            imageio.imwrite(os.path.join(pred_dir, stem + '_sparse.png'), to_uint8(sparse_img))
            imageio.imwrite(os.path.join(pred_dir, stem + '.png'), to_uint8(out_img))
            imageio.imwrite(os.path.join(gt_dir, stem + '.png'), to_uint8(gt_img))

            sq_err = (out_img - gt_img) ** 2
            mses.append(np.mean(sq_err))
            # Same formula as skimage's PSNR with data_range=1 (MSE accumulated in float64);
            # skimage.measure.compare_psnr was removed in scikit-image 0.18.
            with np.errstate(divide='ignore'):
                psnrs.append(10. * np.log10(1. / np.mean(sq_err, dtype=np.float64)))
            pbar.update(1)
    return mses, psnrs