Example #1
0
def _laplace_colormap_tensor(laplace_img):
    """Colorize a Laplacian image with the 'RdBu' colormap.

    The image is rescaled with ``dataio.rescale_img(..., 'scale', 1)``,
    converted to uint8, colorized by OpenCV, converted from BGR to RGB and
    returned as a tensor with the channel axis moved to the front.
    """
    scaled = dataio.to_uint8(dataio.to_numpy(dataio.rescale_img(laplace_img, 'scale', 1)))
    colored = cv2.applyColorMap(scaled.squeeze(), cmapy.cmap('RdBu'))
    colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)
    return torch.from_numpy(colored).permute(2, 0, 1)


def write_laplace_summary(model, model_input, gt, model_output, writer, total_steps, prefix='train_'):
    """Log image, gradient and Laplacian summaries for a Laplacian-fitting run.

    Args:
        model: Unused; kept so the signature matches the other summary writers.
        model_input: Mapping; only ``'coords'`` is read (for min/max logging).
        gt: Mapping with the ground-truth ``'img'`` and ``'laplace'`` entries.
        model_output: Mapping with ``'model_out'`` and ``'model_in'``; the
            gradient and Laplacian of ``'model_out'`` with respect to
            ``'model_in'`` are computed through ``diff_operators``.
        writer: Object exposing ``add_image(tag, img, global_step=...)``
            (presumably a TensorBoard ``SummaryWriter`` -- confirm with caller).
        total_steps: Step index attached to every logged entry.
        prefix: String prepended to every tag.
    """
    def add_grid(tag, images):
        # Every grid in this summary uses the same normalization settings.
        writer.add_image(prefix + tag, make_grid(images, scale_each=False, normalize=True),
                         global_step=total_steps)

    # Side-by-side comparison of ground-truth and predicted image.
    gt_img = dataio.lin2img(gt['img'])
    pred_img = dataio.lin2img(model_output['model_out'])
    add_grid('comp_gt_vs_pred',
             torch.cat((dataio.rescale_img(gt_img), dataio.rescale_img(pred_img, perc=1e-2)), dim=-1))

    # Side-by-side comparison of the Laplacians (this is what has been fitted).
    gt_laplace = dataio.lin2img(gt['laplace'])
    pred_laplace = dataio.lin2img(
        diff_operators.laplace(model_output['model_out'], model_output['model_in']))
    add_grid('comp_gt_vs_pred_laplace', torch.cat((gt_laplace, pred_laplace), dim=-1))

    # Predicted gradient: computed and logged exactly once under 'pred_grad'.
    pred_gradients = diff_operators.gradient(model_output['model_out'], model_output['model_in'])
    add_grid('pred_grad', dataio.grads2img(dataio.lin2img(pred_gradients)))

    add_grid('gt_img', gt_img)
    writer.add_image(prefix + 'gt_lapl', _laplace_colormap_tensor(gt_laplace), global_step=total_steps)

    add_grid('pred_img', pred_img)
    writer.add_image(prefix + 'pred_lapl', _laplace_colormap_tensor(pred_laplace), global_step=total_steps)

    min_max_summary(prefix + 'coords', model_input['coords'], writer, total_steps)
    min_max_summary(prefix + 'gt_laplace', gt_laplace, writer, total_steps)
    min_max_summary(prefix + 'pred_laplace', pred_laplace, writer, total_steps)
    min_max_summary(prefix + 'pred_img', pred_img, writer, total_steps)
    min_max_summary(prefix + 'gt_img', gt_img, writer, total_steps)
Example #2
0
def gradients_color_mse(model_output, gt):
    """Weighted squared error between per-channel model gradients and ``gt['gradients']``.

    The gradient of each of the first three output channels with respect to
    ``model_output['model_in']`` is computed separately and the results are
    concatenated along the last axis before being compared with the ground
    truth.

    Args:
        model_output: Mapping with ``'model_out'`` (last axis indexed 0..2)
            and ``'model_in'``.
        gt: Mapping with a ``'gradients'`` entry.

    Returns:
        ``{'gradients_loss': <scalar tensor>}``.
    """
    model_out = model_output['model_out']
    model_in = model_output['model_in']

    # One autograd pass per color channel, concatenated channel after channel.
    gradients = torch.cat(
        [diff_operators.gradient(model_out[..., channel], model_in) for channel in range(3)],
        dim=-1)

    # Build the weights on the gradients' device rather than forcing the
    # current CUDA device, so the loss also works on CPU or a non-default GPU.
    weights = torch.tensor([1e1, 1e1, 1., 1., 1e1, 1e1], device=gradients.device)

    # NOTE(review): `gradients[0:2]` slices the leading axis, not the feature
    # axis. It is a no-op when that axis has at most two entries; kept as-is
    # because the intended behavior could not be confirmed.
    gradients_loss = torch.mean((weights * (gradients[0:2] - gt['gradients']).pow(2)).sum(-1))
    return {'gradients_loss': gradients_loss}
Example #3
0
def image_mse_TV_prior(mask, k1, model, model_output, gt):
    """Image MSE plus a gradient-magnitude prior evaluated at random coordinates.

    Args:
        mask: Optional multiplicative weight applied to the squared error;
            ``None`` disables masking.
        k1: Scale factor applied to the prior term.
        model: Callable taking ``{'coords': tensor}`` and returning a mapping
            with ``'model_out'`` and ``'model_in'``.
        model_output: Mapping with ``'model_out'`` and ``'model_in'``; the
            latter is indexed on three axes to size the random sample.
        gt: Mapping with an ``'img'`` entry.

    Returns:
        ``{'img_loss': ..., 'prior_loss': ...}`` as scalar tensors.
    """
    model_in = model_output['model_in']

    # Uniform samples in [-1, 1), half as many along axis 1 as the input has.
    # They are drawn on the CPU generator (as before) and then moved to the
    # input's device instead of unconditionally to the current CUDA device,
    # so the loss also runs on CPU or on a non-default GPU.
    sample_shape = (model_in.shape[0], model_in.shape[1] // 2, model_in.shape[2])
    coords_rand = 2 * (torch.rand(sample_shape).to(model_in.device) - 0.5)
    rand_output = model({'coords': coords_rand})

    # The prior does not depend on the mask, so it is computed once.
    prior_loss = k1 * torch.abs(diff_operators.gradient(
        rand_output['model_out'], rand_output['model_in'])).mean()

    squared_error = (model_output['model_out'] - gt['img']) ** 2
    if mask is not None:
        squared_error = mask * squared_error

    return {'img_loss': squared_error.mean(), 'prior_loss': prior_loss}
Example #4
0
def sdf(model_output, gt):
    """Compute the weighted loss terms used to fit a signed distance function.

    Args:
        model_output: Mapping with ``'model_in'`` (coordinates) and
            ``'model_out'`` (predicted SDF values).
        gt: Mapping with ``'sdf'`` and ``'normals'`` entries. Entries of
            ``gt['sdf']`` equal to -1 are treated differently from all
            others (presumably -1 marks off-surface samples -- confirm
            against the dataset).

    Returns:
        Dict with the scalar terms ``'sdf'``, ``'inter'``,
        ``'normal_constraint'`` and ``'grad_constraint'``, each already
        multiplied by its fixed weight.
    """
    target_sdf = gt['sdf']
    target_normals = gt['normals']
    points = model_output['model_in']
    predicted = model_output['model_out']

    grad = diff_operators.gradient(predicted, points)

    # Samples whose ground-truth value differs from the -1 sentinel.
    constrained = target_sdf != -1
    no_penalty = torch.zeros_like(predicted)

    # Constrained samples: penalize the predicted value itself.
    sdf_term = torch.where(constrained, predicted, no_penalty)
    # Remaining samples: penalty that decays with |prediction|.
    inter_term = torch.where(constrained, no_penalty,
                             torch.exp(-1e2 * torch.abs(predicted)))
    # Constrained samples: misalignment between gradient and given normal.
    misalignment = 1 - F.cosine_similarity(grad, target_normals, dim=-1)[..., None]
    normal_term = torch.where(constrained, misalignment,
                              torch.zeros_like(grad[..., :1]))
    # All samples: deviation of the gradient norm from one.
    unit_norm_term = torch.abs(grad.norm(dim=-1) - 1)

    losses = {}
    losses['sdf'] = torch.abs(sdf_term).mean() * 3e3
    losses['inter'] = inter_term.mean() * 1e2
    losses['normal_constraint'] = normal_term.mean() * 1e2
    losses['grad_constraint'] = unit_norm_term.mean() * 5e1
    return losses
Example #5
0
def gradients_mse(model_output, gt):
    """Squared error between the model's gradient and ``gt['gradients']``.

    The gradient of ``model_output['model_out']`` with respect to
    ``model_output['model_in']`` is compared with the ground truth; squared
    differences are summed over the last axis and then averaged.

    Returns:
        ``{'gradients_loss': <scalar tensor>}``.
    """
    predicted = diff_operators.gradient(model_output['model_out'],
                                        model_output['model_in'])
    squared_error = (predicted - gt['gradients']).pow(2)
    per_sample = squared_error.sum(-1)
    return {'gradients_loss': torch.mean(per_sample)}
Example #6
0
def _to_hwc_numpy(img):
    """Move a 4-D image tensor to the CPU as a channels-last NumPy array.

    The leading axis is squeezed away when it has size 1.
    """
    return img.permute(0, 2, 3, 1).squeeze(0).detach().cpu().numpy()


def _laplace_to_rgb(laplace, image_resolution):
    """Render flattened Laplacian values as an RGB array using the 'RdBu' colormap."""
    rescaled = dataio.rescale_img(dataio.lin2img(laplace, image_resolution), perc=2)
    colored = cv2.applyColorMap(dataio.to_uint8(_to_hwc_numpy(rescaled)), cmapy.cmap('RdBu'))
    # OpenCV colormaps are produced in BGR order; the writer is given RGB.
    return cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)


def write_image_summary(image_resolution, model, model_input, gt,
                        model_output, writer, total_steps, prefix='train_'):
    """Log predicted and ground-truth image, gradient and Laplacian, plus PSNR.

    Args:
        image_resolution: Passed to ``dataio.lin2img`` for every flattened
            tensor so that images, gradients and Laplacians are all reshaped
            with the same resolution.
        model: Unused; kept for a uniform summary-function signature.
        model_input: Unused here.
        gt: Mapping with ``'img'``, ``'gradients'`` and ``'laplace'`` entries.
        model_output: Mapping with ``'model_out'`` and ``'model_in'``.
        writer: Object exposing ``add_image(tag, img, global_step=...)``
            (presumably a TensorBoard ``SummaryWriter`` -- confirm with caller).
        total_steps: Step index attached to every logged entry.
        prefix: String prepended to every tag.
    """
    gt_img = dataio.lin2img(gt['img'], image_resolution)
    pred_img = dataio.lin2img(model_output['model_out'], image_resolution)

    img_gradient = diff_operators.gradient(model_output['model_out'], model_output['model_in'])
    img_laplace = diff_operators.laplace(model_output['model_out'], model_output['model_in'])

    output_vs_gt = torch.cat((gt_img, pred_img), dim=-1)
    writer.add_image(prefix + 'gt_vs_pred', make_grid(output_vs_gt, scale_each=False, normalize=True),
                     global_step=total_steps)

    def image_to_hwc(img):
        # (img + 1) / 2 followed by a clamp, as in the comparison panels.
        return _to_hwc_numpy(dataio.rescale_img((img + 1) / 2, mode='clamp'))

    def grads_to_hwc(grads):
        grads_img = dataio.grads2img(dataio.lin2img(grads, image_resolution))
        return grads_img.permute(1, 2, 0).squeeze().detach().cpu().numpy()

    # The image tensors are kept under their own names (not rebound to NumPy
    # arrays) so they can be reused for the PSNR computation below.
    panels = (
        ('pred_img', image_to_hwc(pred_img)),
        ('pred_grad', grads_to_hwc(img_gradient)),
        ('pred_lapl', _laplace_to_rgb(img_laplace, image_resolution)),
        ('gt_img', image_to_hwc(gt_img)),
        ('gt_grad', grads_to_hwc(gt['gradients'])),
        ('gt_lapl', _laplace_to_rgb(gt['laplace'], image_resolution)),
    )
    for tag, hwc in panels:
        # add_image is given channels-first data, so move the channel axis back.
        writer.add_image(prefix + tag, torch.from_numpy(hwc).permute(2, 0, 1), global_step=total_steps)

    write_psnr(pred_img, gt_img, writer, total_steps, prefix + 'img_')
Example #7
0
def write_gradcomp_summary(model, model_input, gt, model_output, writer, total_steps, prefix='train_'):
    """Log summaries for a gradient-composition run.

    Args:
        model: Unused; kept for a uniform summary-function signature.
        model_input: Mapping; only ``'coords'`` is read (for min/max logging).
        gt: Mapping with ``'gradients'``, ``'grads1'``, ``'grads2'``,
            ``'img1'`` and ``'img2'`` entries.
        model_output: Mapping with ``'model_out'`` and ``'model_in'``.
        writer: Object exposing ``add_image(tag, img, global_step=...)``
            (presumably a TensorBoard ``SummaryWriter`` -- confirm with caller).
        total_steps: Step index attached to every logged entry.
        prefix: String prepended to every tag.
    """
    def add_grid(tag, images):
        # Every image in this summary goes through the same grid settings.
        writer.add_image(prefix + tag, make_grid(images, scale_each=False, normalize=True),
                         global_step=total_steps)

    # Ground-truth composite gradients (this is what has been fitted) next to
    # the gradients of the model output.
    gt_gradients = gt['gradients']
    gt_grads_img = dataio.grads2img(dataio.lin2img(gt_gradients))

    pred_gradients = diff_operators.gradient(model_output['model_out'], model_output['model_in'])
    pred_grads_img = dataio.grads2img(dataio.lin2img(pred_gradients))

    add_grid('comp_gt_vs_pred_gradients', torch.cat((gt_grads_img, pred_grads_img), dim=-1))

    # Gradients of the two source images.
    gt_grads1_img = dataio.grads2img(dataio.lin2img(gt['grads1']))
    gt_grads2_img = dataio.grads2img(dataio.lin2img(gt['grads2']))
    add_grid('gt_grads1', gt_grads1_img)
    add_grid('gt_grads2', gt_grads2_img)

    add_grid('gt_gradcomp', gt_grads_img)
    add_grid('pred_gradcomp', pred_grads_img)

    # The two source images themselves.
    add_grid('gt_img1', dataio.lin2img(gt['img1']))
    add_grid('gt_img2', dataio.lin2img(gt['img2']))

    # Predicted composite image.
    pred_img = dataio.rescale_img(dataio.lin2img(model_output['model_out']))
    add_grid('pred_comp_img', pred_img)

    min_max_summary(prefix + 'coords', model_input['coords'], writer, total_steps)
    # The values logged here are the ground-truth gradients, so the tag says
    # so (it used to be the copy-pasted 'gt_laplace').
    min_max_summary(prefix + 'gt_gradients', gt_gradients, writer, total_steps)
    min_max_summary(prefix + 'pred_img', pred_img, writer, total_steps)