def write_laplace_summary(model, model_input, gt, model_output, writer, total_steps, prefix='train_'):
    """Write TensorBoard summaries for a model fitted to an image laplacian.

    Logs side-by-side GT-vs-prediction comparisons for the image and its
    laplacian, the predicted image gradient, and per-tensor min/max scalars.

    Args:
        model: unused here (kept for a uniform summary-fn signature).
        model_input: dict with 'coords' (network input coordinates).
        gt: dict with 'img' and 'laplace' ground-truth tensors in "lin" layout.
        model_output: dict with 'model_in' and 'model_out'; 'model_in' must
            require grad so laplacian/gradient can be taken w.r.t. it.
        writer: TensorBoard SummaryWriter.
        total_steps: global step for all logged entries.
        prefix: tag prefix, e.g. 'train_' or 'val_'.

    Fix vs. previous revision: the image gradient was computed twice
    (two identical diff_operators.gradient calls) and logged twice under the
    same 'pred_grad' tag, the second write overwriting the first. It is now
    computed and logged exactly once.
    """
    # Plot comparison images (GT | prediction), concatenated along width.
    gt_img = dataio.lin2img(gt['img'])
    pred_img = dataio.lin2img(model_output['model_out'])

    output_vs_gt = torch.cat((dataio.rescale_img(gt_img), dataio.rescale_img(pred_img, perc=1e-2)), dim=-1)
    writer.add_image(prefix + 'comp_gt_vs_pred', make_grid(output_vs_gt, scale_each=False, normalize=True),
                     global_step=total_steps)

    # Plot comparison laplacians (this is what has been fitted).
    gt_laplace = dataio.lin2img(gt['laplace'])
    pred_laplace = diff_operators.laplace(model_output['model_out'], model_output['model_in'])
    pred_laplace = dataio.lin2img(pred_laplace)

    output_vs_gt_laplace = torch.cat((gt_laplace, pred_laplace), dim=-1)
    writer.add_image(prefix + 'comp_gt_vs_pred_laplace',
                     make_grid(output_vs_gt_laplace, scale_each=False, normalize=True),
                     global_step=total_steps)

    # Plot predicted image gradient (computed once, logged once).
    pred_gradients = diff_operators.gradient(model_output['model_out'], model_output['model_in'])
    pred_grads_img = dataio.grads2img(dataio.lin2img(pred_gradients))
    writer.add_image(prefix + 'pred_grad', make_grid(pred_grads_img, scale_each=False, normalize=True),
                     global_step=total_steps)

    # Plot gt image.
    writer.add_image(prefix + 'gt_img', make_grid(gt_img, scale_each=False, normalize=True),
                     global_step=total_steps)

    # Plot gt laplacian, colormapped with a diverging RdBu palette
    # (cv2 returns BGR, hence the conversion to RGB before logging).
    gt_laplace_img = dataio.to_uint8(dataio.to_numpy(dataio.rescale_img(gt_laplace, 'scale', 1)))
    gt_laplace_img = cv2.applyColorMap(gt_laplace_img.squeeze(), cmapy.cmap('RdBu'))
    gt_laplace_img = cv2.cvtColor(gt_laplace_img, cv2.COLOR_BGR2RGB)
    writer.add_image(prefix + 'gt_lapl', torch.from_numpy(gt_laplace_img).permute(2, 0, 1),
                     global_step=total_steps)

    # Plot pred image.
    writer.add_image(prefix + 'pred_img', make_grid(pred_img, scale_each=False, normalize=True),
                     global_step=total_steps)

    # Plot pred laplacian, same colormap treatment as the GT laplacian.
    pred_laplace_img = dataio.to_uint8(dataio.to_numpy(dataio.rescale_img(pred_laplace, 'scale', 1)))
    pred_laplace_img = cv2.applyColorMap(pred_laplace_img.squeeze(), cmapy.cmap('RdBu'))
    pred_laplace_img = cv2.cvtColor(pred_laplace_img, cv2.COLOR_BGR2RGB)
    writer.add_image(prefix + 'pred_lapl', torch.from_numpy(pred_laplace_img).permute(2, 0, 1),
                     global_step=total_steps)

    # Scalar min/max diagnostics.
    min_max_summary(prefix + 'coords', model_input['coords'], writer, total_steps)
    min_max_summary(prefix + 'gt_laplace', gt_laplace, writer, total_steps)
    min_max_summary(prefix + 'pred_laplace', pred_laplace, writer, total_steps)
    min_max_summary(prefix + 'pred_img', pred_img, writer, total_steps)
    min_max_summary(prefix + 'gt_img', gt_img, writer, total_steps)
def gradients_color_mse(model_output, gt):
    """Weighted MSE between per-channel predicted spatial gradients and gt['gradients'].

    Returns a dict with a single entry 'gradients_loss'.
    """
    out = model_output['model_out']
    coords = model_output['model_in']

    # Spatial gradient of each color channel (R, G, B), stacked along the
    # last dimension -> 6 gradient components in total.
    per_channel = [diff_operators.gradient(out[..., c], coords) for c in range(3)]
    gradients = torch.cat(per_channel, dim=-1)

    # One weight per gradient component.
    weights = torch.tensor([1e1, 1e1, 1., 1., 1e1, 1e1]).cuda()

    # NOTE(review): gradients[0:2] slices the *batch* dimension, while the
    # 6-element weight vector matches the channel dimension — confirm this
    # isn't meant to be gradients[..., 0:2] or the full tensor.
    gradients_loss = torch.mean((weights * (gradients[0:2] - gt['gradients']).pow(2)).sum(-1))
    return {'gradients_loss': gradients_loss}
def image_mse_TV_prior(mask, k1, model, model_output, gt):
    """Image MSE loss plus a total-variation prior at random coordinates.

    Args:
        mask: optional per-pixel weighting for the image loss (None = uniform).
        k1: scale of the TV prior term.
        model: the network, re-evaluated at random coordinates for the prior.
        model_output: dict with 'model_in' and 'model_out' at the data coords.
        gt: dict with 'img' ground truth.

    Returns a dict with 'img_loss' and 'prior_loss'.
    """
    batch, n_points, n_dims = model_output['model_in'].shape

    # Half as many random coordinates, uniform in [-1, 1).
    coords_rand = 2 * (torch.rand((batch, n_points // 2, n_dims)).cuda() - 0.5)
    rand_output = model({'coords': coords_rand})

    # TV prior: mean absolute gradient of the model at the random coords.
    # (Identical in both mask branches of the original, so computed once.)
    prior_loss = k1 * (torch.abs(diff_operators.gradient(
        rand_output['model_out'], rand_output['model_in']))).mean()

    squared_err = (model_output['model_out'] - gt['img']) ** 2
    if mask is not None:
        squared_err = mask * squared_err

    return {'img_loss': squared_err.mean(), 'prior_loss': prior_loss}
def sdf(model_output, gt):
    '''
    x: batch of input coordinates
    y: usually the output of the trial_soln function

    SDF fitting loss with four terms (weights tuned empirically, see the
    inline "# 1e4 # 3e3"-style comments for previously tried values):
      - 'sdf':    on-surface points must have zero predicted SDF;
      - 'inter':  off-surface points are pushed away from zero via
                  exp(-100*|sdf|);
      - 'normal_constraint': on-surface gradient should align with the
                  ground-truth normal (1 - cosine similarity);
      - 'grad_constraint': Eikonal term, |grad| should be 1 everywhere.
    Off-surface samples are marked by gt['sdf'] == -1.
    '''
    gt_sdf = gt['sdf']
    gt_normals = gt['normals']

    coords = model_output['model_in']
    pred_sdf = model_output['model_out']

    # Gradient of the predicted SDF w.r.t. the input coordinates.
    gradient = diff_operators.gradient(pred_sdf, coords)

    # Wherever boundary_values is not equal to zero, we interpret it as a boundary constraint.
    # gt_sdf != -1 selects on-surface samples; the complementary mask (== -1)
    # selects off-surface samples.
    sdf_constraint = torch.where(gt_sdf != -1, pred_sdf, torch.zeros_like(pred_sdf))
    inter_constraint = torch.where(gt_sdf != -1, torch.zeros_like(pred_sdf),
                                   torch.exp(-1e2 * torch.abs(pred_sdf)))
    normal_constraint = torch.where(
        gt_sdf != -1,
        1 - F.cosine_similarity(gradient, gt_normals, dim=-1)[..., None],
        torch.zeros_like(gradient[..., :1]))
    # Eikonal constraint applied at all samples.
    grad_constraint = torch.abs(gradient.norm(dim=-1) - 1)
    # Exp      # Lapl
    # -----------------
    return {
        'sdf': torch.abs(sdf_constraint).mean() * 3e3,  # 1e4      # 3e3
        'inter': inter_constraint.mean() * 1e2,  # 1e2                   # 1e3
        'normal_constraint': normal_constraint.mean() * 1e2,  # 1e2
        'grad_constraint': grad_constraint.mean() * 5e1
    }  # 1e1      # 5e1
def gradients_mse(model_output, gt):
    """MSE between the model's spatial gradient and gt['gradients'].

    Returns a dict with a single entry 'gradients_loss'.
    """
    # Differentiate the network output w.r.t. its input coordinates.
    predicted = diff_operators.gradient(model_output['model_out'], model_output['model_in'])

    # Sum squared error over gradient components, then average.
    squared_err = (predicted - gt['gradients']).pow(2).sum(-1)
    return {'gradients_loss': squared_err.mean()}
def write_image_summary(image_resolution, model, model_input, gt, model_output, writer, total_steps,
                        prefix='train_'):
    """Write TensorBoard summaries comparing a fitted image with ground truth.

    Logs: a GT|prediction side-by-side grid, the predicted and GT images,
    gradients, laplacians (colormapped with RdBu), and PSNR.

    Args:
        image_resolution: (H, W) used by dataio.lin2img to reshape flat pixels.
        model: unused here (kept for a uniform summary-fn signature).
        model_input: unused here.
        gt: dict with 'img', 'gradients', 'laplace' ground-truth tensors.
        model_output: dict with 'model_in'/'model_out'; 'model_in' must require
            grad so gradient/laplacian can be taken w.r.t. it.
        writer: TensorBoard SummaryWriter.
        total_steps: global step for all logged entries.
        prefix: tag prefix, e.g. 'train_' or 'val_'.
    """
    gt_img = dataio.lin2img(gt['img'], image_resolution)
    pred_img = dataio.lin2img(model_output['model_out'], image_resolution)

    # Derivatives of the prediction w.r.t. input coordinates.
    img_gradient = diff_operators.gradient(model_output['model_out'], model_output['model_in'])
    img_laplace = diff_operators.laplace(model_output['model_out'], model_output['model_in'])

    # Side-by-side GT | prediction, concatenated along width.
    output_vs_gt = torch.cat((gt_img, pred_img), dim=-1)
    writer.add_image(prefix + 'gt_vs_pred', make_grid(output_vs_gt, scale_each=False, normalize=True),
                     global_step=total_steps)

    # Convert prediction to HWC numpy in [0, 1]; (x+1)/2 suggests the model
    # output lives in [-1, 1] (NOTE(review): confirm against the dataio range).
    pred_img = dataio.rescale_img((pred_img+1)/2, mode='clamp').permute(0,2,3,1).squeeze(0).detach().cpu().numpy()
    pred_grad = dataio.grads2img(dataio.lin2img(img_gradient)).permute(1,2,0).squeeze().detach().cpu().numpy()
    # Laplacian is percentile-rescaled, mapped through a diverging RdBu
    # colormap, and converted BGR->RGB (cv2 colormaps return BGR).
    pred_lapl = cv2.cvtColor(cv2.applyColorMap(dataio.to_uint8(dataio.rescale_img(
        dataio.lin2img(img_laplace), perc=2).permute(0,2,3,1).squeeze(0).detach().cpu().numpy()),
        cmapy.cmap('RdBu')), cv2.COLOR_BGR2RGB)

    # Same treatment for the ground-truth image/gradient/laplacian.
    gt_img = dataio.rescale_img((gt_img+1) / 2, mode='clamp').permute(0, 2, 3, 1).squeeze(0).detach().cpu().numpy()
    gt_grad = dataio.grads2img(dataio.lin2img(gt['gradients'])).permute(1, 2, 0).squeeze().detach().cpu().numpy()
    gt_lapl = cv2.cvtColor(cv2.applyColorMap(dataio.to_uint8(dataio.rescale_img(
        dataio.lin2img(gt['laplace']), perc=2).permute(0, 2, 3, 1).squeeze(0).detach().cpu().numpy()),
        cmapy.cmap('RdBu')), cv2.COLOR_BGR2RGB)

    # Log everything as CHW tensors.
    writer.add_image(prefix + 'pred_img', torch.from_numpy(pred_img).permute(2, 0, 1), global_step=total_steps)
    writer.add_image(prefix + 'pred_grad', torch.from_numpy(pred_grad).permute(2, 0, 1), global_step=total_steps)
    writer.add_image(prefix + 'pred_lapl', torch.from_numpy(pred_lapl).permute(2,0,1), global_step=total_steps)
    writer.add_image(prefix + 'gt_img', torch.from_numpy(gt_img).permute(2,0,1), global_step=total_steps)
    writer.add_image(prefix + 'gt_grad', torch.from_numpy(gt_grad).permute(2, 0, 1), global_step=total_steps)
    writer.add_image(prefix + 'gt_lapl', torch.from_numpy(gt_lapl).permute(2, 0, 1), global_step=total_steps)

    # PSNR between prediction and ground truth.
    write_psnr(dataio.lin2img(model_output['model_out'], image_resolution),
               dataio.lin2img(gt['img'], image_resolution), writer, total_steps, prefix+'img_')
def write_gradcomp_summary(model, model_input, gt, model_output, writer, total_steps, prefix='train_'):
    """Write TensorBoard summaries for gradient-composition experiments.

    Logs the composite GT gradients vs. predicted gradients, the two source
    gradient fields and images being composed, and the predicted composite
    image, plus min/max scalar diagnostics.

    Args:
        model: unused here (kept for a uniform summary-fn signature).
        model_input: dict with 'coords'.
        gt: dict with 'gradients' (composite target), 'grads1'/'grads2' and
            'img1'/'img2' (the two sources being composed).
        model_output: dict with 'model_in'/'model_out'; 'model_in' must require
            grad for the gradient computation.
        writer: TensorBoard SummaryWriter.
        total_steps: global step for all logged entries.
        prefix: tag prefix, e.g. 'train_' or 'val_'.

    Fix vs. previous revision: the min/max scalar for the composite GT
    gradients was logged under the copy-pasted tag 'gt_laplace'; it is now
    logged as 'gt_gradients'.
    """
    # Plot gt gradients (this is what has been fitted) next to the prediction.
    gt_gradients = gt['gradients']
    gt_grads_img = dataio.grads2img(dataio.lin2img(gt_gradients))

    pred_gradients = diff_operators.gradient(model_output['model_out'], model_output['model_in'])
    pred_grads_img = dataio.grads2img(dataio.lin2img(pred_gradients))

    output_vs_gt_gradients = torch.cat((gt_grads_img, pred_grads_img), dim=-1)
    writer.add_image(prefix + 'comp_gt_vs_pred_gradients',
                     make_grid(output_vs_gt_gradients, scale_each=False, normalize=True),
                     global_step=total_steps)

    # Plot the two source gradient fields being composed.
    gt_grads1 = gt['grads1']
    gt_grads1_img = dataio.grads2img(dataio.lin2img(gt_grads1))
    gt_grads2 = gt['grads2']
    gt_grads2_img = dataio.grads2img(dataio.lin2img(gt_grads2))

    writer.add_image(prefix + 'gt_grads1', make_grid(gt_grads1_img, scale_each=False, normalize=True),
                     global_step=total_steps)
    writer.add_image(prefix + 'gt_grads2', make_grid(gt_grads2_img, scale_each=False, normalize=True),
                     global_step=total_steps)
    writer.add_image(prefix + 'gt_gradcomp', make_grid(gt_grads_img, scale_each=False, normalize=True),
                     global_step=total_steps)
    writer.add_image(prefix + 'pred_gradcomp', make_grid(pred_grads_img, scale_each=False, normalize=True),
                     global_step=total_steps)

    # Plot the two source images.
    gt_img1 = dataio.lin2img(gt['img1'])
    gt_img2 = dataio.lin2img(gt['img2'])
    writer.add_image(prefix + 'gt_img1', make_grid(gt_img1, scale_each=False, normalize=True),
                     global_step=total_steps)
    writer.add_image(prefix + 'gt_img2', make_grid(gt_img2, scale_each=False, normalize=True),
                     global_step=total_steps)

    # Plot the predicted composite image.
    pred_img = dataio.rescale_img(dataio.lin2img(model_output['model_out']))
    writer.add_image(prefix + 'pred_comp_img', make_grid(pred_img, scale_each=False, normalize=True),
                     global_step=total_steps)

    # Scalar min/max diagnostics.
    min_max_summary(prefix + 'coords', model_input['coords'], writer, total_steps)
    # Fixed tag: these are the composite gradients, not a laplacian.
    min_max_summary(prefix + 'gt_gradients', gt_gradients, writer, total_steps)
    min_max_summary(prefix + 'pred_img', pred_img, writer, total_steps)