    def val(self, epoch):
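        """Validate a NeRF-style model: render each image's rays in chunks,
        composite colors via volume rendering, and log L1/MSE/PSNR scalars
        plus example renders to TensorBoard."""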
        self.model.eval()
        batch_l1 = 0.0
        batch_mse = 0.0
        batch_psnr = 0.0
        batch_output_vis = []
        batch_diff_vis = []
        batch_target_vis = []

        with torch.no_grad():
            for idx, sample in enumerate(self.dataloader_val):
                pts = sample['points'][0].to(device)
                viewdirs = sample['viewdirs'][0].to(device)
                target_rgb = sample['target'][0].to(device)
                # mask = sample['mask'][0].to(device)

                # per-image accumulators must be reset for every sample;
                # otherwise each image's metrics would include all previous ones
                batch_img_l1 = 0.0
                batch_img_mse = 0.0
                batch_img_psnr = 0.0

                assert pts.shape[0] == target_rgb.shape[0]

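                # split this image's rays into chunks of n_points_in_batch so a
                # full image never has to pass through the network at once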
                batch_pts = [
                    pts[i:i + self.cfg.n_points_in_batch]
                    for i in range(0, pts.shape[0], self.cfg.n_points_in_batch)
                ]
                batch_viewdirs = [
                    viewdirs[i:i + self.cfg.n_points_in_batch]
                    for i in range(0, viewdirs.shape[0],
                                   self.cfg.n_points_in_batch)
                ]
                batch_target = [
                    target_rgb[i:i + self.cfg.n_points_in_batch]
                    for i in range(0, target_rgb.shape[0],
                                   self.cfg.n_points_in_batch)
                ]
                # batch_mask = [mask[i : i + self.cfg.n_points_in_batch] for i in range(0, mask.shape[0], self.cfg.n_points_in_batch)]

                pred_rgb = []

                for batch_id in range(len(batch_pts)):
                    target = batch_target[batch_id]
                    # mask = batch_mask[batch_id].unsqueeze(1).expand(target.shape)
                    mask = torch.ones(size=target.shape,
                                      dtype=torch.bool,
                                      device=target.device)
                    input_pts = batch_pts[batch_id]
                    if self.cfg.use_viewdirs:
                        input_viewdir = batch_viewdirs[batch_id].unsqueeze(
                            1).expand(input_pts.shape)
                        if self.cfg.model_name == 'NeRF':
                            input_pts = torch.cat([input_pts, input_viewdir],
                                                  dim=-1)
                        else:
                            input_viewdir = torch.reshape(
                                input_viewdir, [-1, 3])

                    pts_shape = input_pts.shape
                    input_pts = torch.reshape(input_pts, [-1, pts_shape[-1]])
                    # input_pts = data_utils.input_mapping(input_pts, self.input_map, self.cfg.map_points, self.cfg.map_viewdirs, self.cfg.points_type, self.cfg.model_name)

                    # Run network

                    if self.cfg.model_name == 'NeRF':
                        raw = self.model(input_pts)
                    else:
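                        # non-NeRF models take view directions as a second
                        # argument; input_viewdir is only built above when
                        # cfg.use_viewdirs is set, so that flag is assumed here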
                        raw = self.model(input_pts, input_viewdir)
                    raw = torch.reshape(raw, list(pts_shape[:-1]) + [4])

                    # Compute opacities and colors
                    rgb, sigma_a = raw[..., :3], raw[..., 3]
                    sigma_a = torch.nn.functional.relu(sigma_a)
                    rgb = torch.sigmoid(rgb)

                    if self.cfg.n_samples != 1:
                        z_vals = sample['z_vals'][0].to(device)
                        one_e_10 = torch.tensor([1e10],
                                                dtype=torch.float32,
                                                device=device)
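                        # distances between consecutive depth samples, with the
                        # last interval padded by 1e10 (effectively infinity)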
                        dists = torch.cat(
                            (
                                z_vals[..., 1:] - z_vals[..., :-1],
                                one_e_10.expand(z_vals[..., :1].shape),
                            ),
                            dim=-1,
                        )
                        alpha = 1.0 - torch.exp(-sigma_a * dists)
                    else:
                        alpha = 1.0 - torch.exp(-sigma_a)
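                    # volume rendering weights w_i = alpha_i * prod_{j<i}(1 - alpha_j);
                    # the 1e-10 keeps the exclusive cumulative product stable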
                    weights = alpha * data_utils.cumprod_exclusive(
                        1.0 - alpha + 1e-10)

                    rgb = (weights[..., None] * rgb).sum(dim=-2)

                    # compute losses
                    loss_mse = metrics.mse(target, rgb, mask)
                    loss_l1 = metrics.l1(target, rgb, mask)
                    loss_psnr = metrics.psnr(target, rgb, mask)

                    # accumulate chunk metrics weighted by the chunk's share of
                    # the image's rays, so the sum is an exact per-image mean
                    batch_img_l1 += loss_l1.item() * (
                        input_pts.shape[0] / (pts.shape[0] * pts.shape[1]))
                    batch_img_mse += loss_mse.item() * (
                        input_pts.shape[0] / (pts.shape[0] * pts.shape[1]))
                    batch_img_psnr += loss_psnr.item() * (
                        input_pts.shape[0] / (pts.shape[0] * pts.shape[1]))
                    pred_rgb.append(rgb.detach())

                batch_l1 += batch_img_l1
                batch_mse += batch_img_mse
                batch_psnr += batch_img_psnr

                # visualize images on tensorboard
                if idx < 5 and (epoch + 1) % self.cfg.save_every == 0:
                    pred_rgb = torch.cat(pred_rgb, dim=0)
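                    # assumes square renders: recover the side length from the
                    # total ray count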
                    res = int(np.sqrt(pred_rgb.shape[0]))
                    output = torch.reshape(pred_rgb,
                                           [res, res, 3]).permute(2, 0, 1)
                    output = torch.clamp(output, min=0.0, max=1.0)
                    target_rgb = torch.reshape(target_rgb,
                                               [res, res, 3]).permute(2, 0, 1)
                    batch_target_vis.append(target_rgb)
                    batch_diff_vis.append(torch.abs(target_rgb - output))
                    batch_output_vis.append(output)

        # log losses
        self.writer.add_scalar('val_mse', batch_mse / (idx + 1), epoch + 1)
        self.writer.add_scalar('val_l1', batch_l1 / (idx + 1), epoch + 1)
        self.writer.add_scalar('val_psnr', batch_psnr / (idx + 1), epoch + 1)

        if (epoch + 1) % self.cfg.save_every == 0:
            self.writer.add_images('val_target',
                                   torch.stack(batch_target_vis, dim=0),
                                   epoch + 1)
            self.writer.add_images('val_output',
                                   torch.stack(batch_output_vis, dim=0),
                                   epoch + 1)
            self.writer.add_images('val_diff',
                                   torch.stack(batch_diff_vis, dim=0),
                                   epoch + 1)

    def train(self, epoch):
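        """Train for one epoch with a combined L1 + MS-SSIM loss on masked
        RGB outputs; returns the mean loss over the epoch's batches."""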
        self.model.train()
        batch_loss = 0.0
        batch_clamped_output = []
        batch_diff = []
        batch_target = []

        for idx, sample in enumerate(self.dataloader_train):
            net_input = sample['input'].to(device)
            mask = sample['input_mask'].to(device)
            target = sample['target'].to(device)

            target = target[:, :3, ...]
            mask = mask.unsqueeze(1).expand(target.shape)

            self.optimizer.zero_grad()
            # forward pass
            output = self.model(net_input)

            # compute losses
            loss_l1 = metrics.l1(target, output, mask)
            loss_ssim = metrics.msssim(target, output, mask)
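            # combined objective; metrics.msssim is assumed to return a loss
            # term (lower is better) so the two can be summed directly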
            loss = loss_l1 + loss_ssim
            # loss = self.loss_fun(target, output, mask)

            # backward pass and optimize
            loss.backward()
            self.optimizer.step()

            # log
            batch_loss += loss.item()

            # visualize images on tensorboard
            if idx in [0, 1] and (epoch + 1) % self.cfg.save_every == 0:
                clamped_output = torch.clamp(output.detach(), min=0.0,
                                             max=1.0) * mask
                target = target * mask
                batch_target.append(target.cpu())
                batch_diff.append(torch.abs(target - clamped_output).cpu())
                batch_clamped_output.append(clamped_output.cpu())

        # log losses
        self.writer.add_scalar('rgb_loss', batch_loss / (idx + 1), epoch + 1)

        if (epoch + 1) % self.cfg.save_every == 0:
            self.writer.add_images('rgb_target', torch.cat(batch_target),
                                   epoch + 1)
            self.writer.add_images('rgb_clamped',
                                   torch.cat(batch_clamped_output), epoch + 1)
            self.writer.add_images('rgb_diff', torch.cat(batch_diff),
                                   epoch + 1)

        return batch_loss / (idx + 1)

    def val(self, epoch):
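        """Validate the image-to-image model: compute masked L1/PSNR/MS-SSIM
        over the validation set and return the mean L1 loss."""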
        self.model.eval()
        batch_l1 = 0.0
        batch_lpips = 0.0
        batch_psnr = 0.0
        batch_ssim = 0.0
        batch_mse = 0.0
        batch_fft = 0.0
        batch_clamped_output = []
        batch_diff = []
        batch_target = []

        with torch.no_grad():
            for idx, sample in enumerate(self.dataloader_val):
                net_input = sample['input'].to(device)
                mask = sample['input_mask'].to(device)
                target = sample['target'].to(device)

                target = target[:, :3, ...]
                mask = mask.unsqueeze(1).expand(target.shape)

                # forward pass
                output = self.model(net_input)

                # compute losses
                # lpips = metrics.lpips(target, output, mask)
                l1 = metrics.l1(target, output, mask)
                # mse = metrics.mse(target, output, mask)
                psnr = metrics.psnr(target, output, mask)
                ssim = metrics.msssim(target, output, mask)
                # fft = metrics.loss_fft(target, output, mask)

                # log
                batch_l1 += l1.item()
                # batch_mse += mse.item()
                # batch_lpips += lpips.item()
                batch_psnr += psnr.item()
                batch_ssim += ssim.item()
                # batch_fft += fft.item()

                # visualize images on tensorboard
                if idx in [0, 1] and (epoch + 1) % self.cfg.save_every == 0:
                    clamped_output = torch.clamp(output, min=0.0,
                                                 max=1.0) * mask
                    target = target * mask
                    batch_diff.append(torch.abs(target - clamped_output).cpu())
                    batch_clamped_output.append(clamped_output.cpu())

                if epoch == 0 and idx in [0, 1]:
                    batch_target.append(target.cpu())

        # log losses
        self.writer.add_scalar('rgb_val_loss', batch_l1 / (idx + 1), epoch + 1)
        # self.writer.add_scalar('rgb_val_mse',batch_mse/(idx+1),epoch+1)
        # self.writer.add_scalar('rgb_val_lpips',batch_lpips/(idx+1),epoch+1)
        self.writer.add_scalar('rgb_val_psnr', batch_psnr / (idx + 1),
                               epoch + 1)
        self.writer.add_scalar('rgb_val_ssim', batch_ssim / (idx + 1),
                               epoch + 1)
        # self.writer.add_scalar('val_fft',batch_fft/(idx+1),epoch+1)

        # log input and target images only once
        if epoch == 0:
            self.writer.add_images('rgb_val_target', torch.cat(batch_target),
                                   epoch + 1)

        if (epoch + 1) % self.cfg.save_every == 0:
            self.writer.add_images('rgb_val_clamped',
                                   torch.cat(batch_clamped_output), epoch + 1)
            self.writer.add_images('rgb_val_diff', torch.cat(batch_diff),
                                   epoch + 1)

        # batch_mse is never accumulated (the mse metric is commented out
        # above), so return the mean L1 validation loss instead
        return batch_l1 / (idx + 1)

    def val(self, epoch):
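        """Validate a model with separate alpha and RGB heads: score the
        predicted opacity and the alpha-blended color against the RGBA
        target."""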
        self.model.eval()
        batch_alpha_loss = 0.0
        batch_rgb_loss = 0.0
        batch_alpha_psnr = 0.0
        batch_rgb_psnr = 0.0
        batch_clamped_alpha = []
        batch_blended_rgb = []
        batch_target_rgb = []
        batch_target_alpha = []

        with torch.no_grad():
            for idx, sample in enumerate(self.dataloader_val):
                net_input = sample['input'].to(device)
                mask = sample['input_mask'].to(device)
                target = sample['target'].to(device)

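                # split the RGBA target into opacity and color supervision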
                target_alpha = target[:, 3:4, ...]
                target_rgb = target[:, :3, ...]
                rgb_mask = mask.unsqueeze(1).expand(target_rgb.shape)
                alpha_mask = mask.unsqueeze(1)

                # forward pass
                alpha, rgb = self.model(net_input)

                # compute losses
                alpha_loss = metrics.l1(target_alpha, alpha, alpha_mask)
                alpha_psnr = metrics.psnr(target_alpha, alpha, alpha_mask)

                clamped_alpha = torch.clamp(alpha, min=0.0, max=1.0)

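                # composite the predicted color against the predicted opacity
                # before comparing to the RGB target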
                blended_rgb = clamped_alpha * rgb

                rgb_loss = metrics.l1(target_rgb, blended_rgb, rgb_mask)
                rgb_psnr = metrics.psnr(target_rgb, blended_rgb, rgb_mask)

                cb_rgb = torch.clamp(blended_rgb, min=0.0, max=1.0)

                # log
                batch_alpha_loss += alpha_loss.item()
                batch_rgb_loss += rgb_loss.item()
                batch_alpha_psnr += alpha_psnr.item()
                batch_rgb_psnr += rgb_psnr.item()

                # visualize images on tensorboard
                if idx < 4 and (epoch + 1) % self.cfg.save_every == 0:
                    batch_clamped_alpha.append(clamped_alpha.cpu())
                    batch_blended_rgb.append(cb_rgb.cpu())

                if idx in [0, 1, 2, 3] and epoch == 0:
                    batch_target_alpha.append(target_alpha.cpu())
                    batch_target_rgb.append(target_rgb.cpu())

        # log losses
        self.writer.add_scalar('val_alpha_loss', batch_alpha_loss / (idx + 1),
                               epoch + 1)
        self.writer.add_scalar('val_rgb_loss', batch_rgb_loss / (idx + 1),
                               epoch + 1)
        self.writer.add_scalar('val_alpha_psnr', batch_alpha_psnr / (idx + 1),
                               epoch + 1)
        self.writer.add_scalar('val_rgb_psnr', batch_rgb_psnr / (idx + 1),
                               epoch + 1)

        # log input and target images only once
        if epoch == 0:
            self.writer.add_images('val_alpha_target',
                                   torch.cat(batch_target_alpha), epoch + 1)
            self.writer.add_images('val_rgb_target',
                                   torch.cat(batch_target_rgb), epoch + 1)

        if (epoch + 1) % self.cfg.save_every == 0:
            self.writer.add_images('val_alpha_clamped',
                                   torch.cat(batch_clamped_alpha), epoch + 1)
            self.writer.add_images('val_rgb_blended',
                                   torch.cat(batch_blended_rgb), epoch + 1)