def val(self, epoch):
    self.model.eval()
    batch_l1 = 0.0
    batch_mse = 0.0
    batch_psnr = 0.0
    batch_output_vis = []
    batch_diff_vis = []
    batch_target_vis = []
    with torch.no_grad():
        for idx, sample in enumerate(self.dataloader_val):
            pts = sample['points'][0].to(device)
            viewdirs = sample['viewdirs'][0].to(device)
            target_rgb = sample['target'][0].to(device)
            # mask = sample['mask'][0].to(device)
            assert pts.shape[0] == target_rgb.shape[0]

            # Split the rays of one image into chunks that fit into memory.
            batch_pts = [
                pts[i:i + self.cfg.n_points_in_batch]
                for i in range(0, pts.shape[0], self.cfg.n_points_in_batch)
            ]
            batch_viewdirs = [
                viewdirs[i:i + self.cfg.n_points_in_batch]
                for i in range(0, viewdirs.shape[0], self.cfg.n_points_in_batch)
            ]
            batch_target = [
                target_rgb[i:i + self.cfg.n_points_in_batch]
                for i in range(0, target_rgb.shape[0], self.cfg.n_points_in_batch)
            ]
            # batch_mask = [mask[i:i + self.cfg.n_points_in_batch] for i in range(0, mask.shape[0], self.cfg.n_points_in_batch)]

            # Per-image accumulators; reset for every validation image so that
            # images do not double-count earlier results.
            batch_img_l1 = 0.0
            batch_img_mse = 0.0
            batch_img_psnr = 0.0
            pred_rgb = []
            for batch_id in range(len(batch_pts)):
                target = batch_target[batch_id]
                # mask = batch_mask[batch_id].unsqueeze(1).expand(target.shape)
                mask = torch.ones(size=target.shape, dtype=torch.bool, device=target.device)
                input_pts = batch_pts[batch_id]
                if self.cfg.use_viewdirs:
                    input_viewdir = batch_viewdirs[batch_id].unsqueeze(1).expand(input_pts.shape)
                    if self.cfg.model_name == 'NeRF':
                        input_pts = torch.cat([input_pts, input_viewdir], dim=-1)
                    else:
                        input_viewdir = torch.reshape(input_viewdir, [-1, 3])
                pts_shape = input_pts.shape
                input_pts = torch.reshape(input_pts, [-1, pts_shape[-1]])
                # input_pts = data_utils.input_mapping(input_pts, self.input_map, self.cfg.map_points, self.cfg.map_viewdirs, self.cfg.points_type, self.cfg.model_name)

                # Run network
                if self.cfg.model_name == 'NeRF':
                    raw = self.model(input_pts)
                else:
                    raw = self.model(input_pts, input_viewdir)
                raw = torch.reshape(raw, list(pts_shape[:-1]) + [4])

                # Compute opacities and colors
                rgb, sigma_a = raw[..., :3], raw[..., 3]
                sigma_a = torch.nn.functional.relu(sigma_a)
                rgb = torch.sigmoid(rgb)

                # Volume rendering: convert densities to alphas, then composite
                # the per-sample colors along each ray.
                if self.cfg.n_samples != 1:
                    z_vals = sample['z_vals'][0].to(device)
                    one_e_10 = torch.tensor([1e10], dtype=torch.float32, device=device)
                    dists = torch.cat(
                        (
                            z_vals[..., 1:] - z_vals[..., :-1],
                            one_e_10.expand(z_vals[..., :1].shape),
                        ),
                        dim=-1,
                    )
                    alpha = 1.0 - torch.exp(-sigma_a * dists)
                else:
                    alpha = 1.0 - torch.exp(-sigma_a)
                weights = alpha * data_utils.cumprod_exclusive(1.0 - alpha + 1e-10)
                rgb = (weights[..., None] * rgb).sum(dim=-2)

                # Compute losses
                loss_mse = metrics.mse(target, rgb, mask)
                loss_l1 = metrics.l1(target, rgb, mask)
                loss_psnr = metrics.psnr(target, rgb, mask)

                # Accumulate per-image metrics, weighting each chunk by its
                # share of the image's samples.
                chunk_weight = input_pts.shape[0] / (pts.shape[0] * pts.shape[1])
                batch_img_l1 += loss_l1.item() * chunk_weight
                batch_img_mse += loss_mse.item() * chunk_weight
                batch_img_psnr += loss_psnr.item() * chunk_weight

                pred_rgb.append(rgb.detach())

            batch_l1 += batch_img_l1
            batch_mse += batch_img_mse
            batch_psnr += batch_img_psnr

            # Visualize images on tensorboard
            if idx in [0, 1, 2, 3, 4] and (epoch + 1) % self.cfg.save_every == 0:
                pred_rgb = torch.cat(pred_rgb, dim=0)
                res = int(np.sqrt(pred_rgb.shape[0]))  # assumes square images
                output = torch.reshape(pred_rgb, [res, res, 3]).permute(2, 0, 1)
                output = torch.clamp(output, min=0.0, max=1.0)
                target_rgb = torch.reshape(target_rgb, [res, res, 3]).permute(2, 0, 1)
                batch_target_vis.append(target_rgb)
                batch_diff_vis.append(torch.abs(target_rgb - output))
                batch_output_vis.append(output)

    # Log losses, averaged over the validation images
    self.writer.add_scalar('val_mse', batch_mse / (idx + 1), epoch + 1)
    self.writer.add_scalar('val_l1', batch_l1 / (idx + 1), epoch + 1)
    self.writer.add_scalar('val_psnr', batch_psnr / (idx + 1), epoch + 1)
    if (epoch + 1) % self.cfg.save_every == 0:
        self.writer.add_images('val_target', torch.stack(batch_target_vis, dim=0), epoch + 1)
        self.writer.add_images('val_output', torch.stack(batch_output_vis, dim=0), epoch + 1)
        self.writer.add_images('val_diff', torch.stack(batch_diff_vis, dim=0), epoch + 1)
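
# For reference: `data_utils.cumprod_exclusive` is assumed to compute an
# exclusive cumulative product along the last dimension, as in the original
# NeRF weight computation. A minimal sketch of that assumed behavior
# (hypothetical; not necessarily the project's actual implementation):
#
#     def cumprod_exclusive(t: torch.Tensor) -> torch.Tensor:
#         # out[..., i] = t[..., 0] * ... * t[..., i-1], with out[..., 0] = 1
#         cp = torch.cumprod(t, dim=-1)
#         cp = torch.roll(cp, shifts=1, dims=-1)
#         cp[..., 0] = 1.0
#         return cp
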
def train(self, epoch):
    self.model.train()
    batch_loss = 0.0
    batch_clamped_output = []
    batch_diff = []
    batch_target = []
    batch_input_points = []
    batch_input_viewdirs = []
    for idx, sample in enumerate(self.dataloader_train):
        input = sample['input'].to(device)
        mask = sample['input_mask'].to(device)
        target = sample['target'].to(device)
        target = target[:, :3, ...]
        mask = mask.unsqueeze(1).expand(target.shape)

        self.optimizer.zero_grad()

        # Forward pass
        output = self.model(input)

        # Compute losses; metrics.msssim is assumed to return a dissimilarity
        # term (e.g. 1 - MS-SSIM) so that it can be minimized directly.
        loss_l1 = metrics.l1(target, output, mask)
        loss_ssim = metrics.msssim(target, output, mask)
        loss = loss_l1 + loss_ssim
        # loss = self.loss_fun(target, output, mask)

        # Backward pass and optimize
        loss.backward()
        self.optimizer.step()

        # Log
        batch_loss += loss.item()

        # Visualize images on tensorboard
        if idx in [0, 1] and (epoch + 1) % self.cfg.save_every == 0:
            clamped_output = torch.clamp(output.detach(), min=0.0, max=1.0) * mask
            target = target * mask
            batch_target.append(target.cpu())
            batch_diff.append(torch.abs(target - clamped_output).cpu())
            batch_clamped_output.append(clamped_output.cpu())
        # if idx in [0, 1] and epoch == 0:
        #     raw_data = sample['raw_data'].cpu()
        #     if self.cfg.use_viewdirs:
        #         batch_input_viewdirs.append(raw_data[:, 1, ...])
        #         raw_data = raw_data[:, 0, ...]
        #     batch_input_points.append(raw_data)

        # Visualize alpha and rgb distribution for first image on tensorboard
        # if idx == 0:
        #     self.writer.add_histogram("red", output[0, 0, ...], epoch + 1)
        #     self.writer.add_histogram("green", rgb[0, 1, ...], epoch + 1)
        #     self.writer.add_histogram("blue", rgb[0, 2, ...], epoch + 1)

    # Log losses
    self.writer.add_scalar('rgb_loss', batch_loss / (idx + 1), epoch + 1)

    # Log input and target images only once
    # if epoch == 0:
    #     batch_input_points = torch.cat(batch_input_points)
    #     if batch_input_points.shape[1] == 3:
    #         batch_input_points = visualize.vis_cartesian_as_matplotfig(batch_input_points)
    #     else:
    #         batch_input_points = visualize.vis_spherical_as_matplotfig(batch_input_points)
    #     if self.cfg.use_viewdirs:
    #         batch_input_viewdirs = torch.cat(batch_input_viewdirs)
    #         if batch_input_viewdirs.shape[1] == 3:
    #             batch_input_viewdirs = visualize.vis_cartesian_as_matplotfig(batch_input_viewdirs)
    #         else:
    #             batch_input_viewdirs = visualize.vis_spherical_as_matplotfig(batch_input_viewdirs)
    #     self.writer.add_figure('input_points', batch_input_points, epoch + 1)
    #     if self.cfg.use_viewdirs:
    #         self.writer.add_figure('input_viewdirs', batch_input_viewdirs, epoch + 1)

    if (epoch + 1) % self.cfg.save_every == 0:
        self.writer.add_images('rgb_target', torch.cat(batch_target), epoch + 1)
        self.writer.add_images('rgb_clamped', torch.cat(batch_clamped_output), epoch + 1)
        self.writer.add_images('rgb_diff', torch.cat(batch_diff), epoch + 1)
    return batch_loss / (idx + 1)
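
# For reference: the masked metrics used above are assumed to average only
# over the pixels selected by the mask. A minimal sketch of that assumed
# behavior for `metrics.l1` (hypothetical; not necessarily the project's
# actual implementation):
#
#     def l1(target: torch.Tensor, pred: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
#         # Zero out unmasked pixels, then normalize by the number of valid ones.
#         diff = torch.abs(target - pred) * mask
#         return diff.sum() / mask.sum().clamp(min=1)
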
def val(self, epoch):
    self.model.eval()
    batch_l1 = 0.0
    batch_lpips = 0.0
    batch_psnr = 0.0
    batch_ssim = 0.0
    batch_mse = 0.0
    batch_fft = 0.0
    batch_clamped_output = []
    batch_diff = []
    batch_target = []
    batch_input_points = []
    batch_input_viewdirs = []
    with torch.no_grad():
        for idx, sample in enumerate(self.dataloader_val):
            input = sample['input'].to(device)
            mask = sample['input_mask'].to(device)
            target = sample['target'].to(device)
            target = target[:, :3, ...]
            mask = mask.unsqueeze(1).expand(target.shape)

            # Forward pass
            output = self.model(input)

            # Compute losses
            # lpips = metrics.lpips(target, output, mask)
            l1 = metrics.l1(target, output, mask)
            # mse = metrics.mse(target, output, mask)
            psnr = metrics.psnr(target, output, mask)
            ssim = metrics.msssim(target, output, mask)
            # fft = metrics.loss_fft(target, output, mask)

            # Log
            batch_l1 += l1.item()
            # batch_mse += mse.item()
            # batch_lpips += lpips.item()
            batch_psnr += psnr.item()
            batch_ssim += ssim.item()
            # batch_fft += fft.item()

            # Visualize images on tensorboard
            if idx in [0, 1] and (epoch + 1) % self.cfg.save_every == 0:
                clamped_output = torch.clamp(output, min=0.0, max=1.0) * mask
                target = target * mask
                batch_diff.append(torch.abs(target - clamped_output).cpu())
                batch_clamped_output.append(clamped_output.cpu())
            if epoch == 0 and idx in [0, 1]:
                batch_target.append(target.cpu())
                # raw_data = sample['raw_data'].cpu()
                # if self.cfg.use_viewdirs:
                #     batch_input_viewdirs.append(raw_data[:, 1, ...])
                #     raw_data = raw_data[:, 0, ...]
                # batch_input_points.append(raw_data)

    # Log losses
    self.writer.add_scalar('rgb_val_loss', batch_l1 / (idx + 1), epoch + 1)
    # self.writer.add_scalar('rgb_val_mse', batch_mse / (idx + 1), epoch + 1)
    # self.writer.add_scalar('rgb_val_lpips', batch_lpips / (idx + 1), epoch + 1)
    self.writer.add_scalar('rgb_val_psnr', batch_psnr / (idx + 1), epoch + 1)
    self.writer.add_scalar('rgb_val_ssim', batch_ssim / (idx + 1), epoch + 1)
    # self.writer.add_scalar('val_fft', batch_fft / (idx + 1), epoch + 1)

    # Log input and target images only once
    if epoch == 0:
        # batch_input_points = torch.cat(batch_input_points)
        # if batch_input_points.shape[1] == 3:
        #     batch_input_points = visualize.vis_cartesian_as_matplotfig(batch_input_points)
        # else:
        #     batch_input_points = visualize.vis_spherical_as_matplotfig(batch_input_points)
        # if self.cfg.use_viewdirs:
        #     batch_input_viewdirs = torch.cat(batch_input_viewdirs)
        #     if batch_input_viewdirs.shape[1] == 3:
        #         batch_input_viewdirs = visualize.vis_cartesian_as_matplotfig(batch_input_viewdirs)
        #     else:
        #         batch_input_viewdirs = visualize.vis_spherical_as_matplotfig(batch_input_viewdirs)
        # self.writer.add_figure('test_input_points', batch_input_points, epoch + 1)
        # if self.cfg.use_viewdirs:
        #     self.writer.add_figure('test_input_viewdirs', batch_input_viewdirs, epoch + 1)
        self.writer.add_images('rgb_val_target', torch.cat(batch_target), epoch + 1)
    if (epoch + 1) % self.cfg.save_every == 0:
        self.writer.add_images('rgb_val_clamped', torch.cat(batch_clamped_output), epoch + 1)
        self.writer.add_images('rgb_val_diff', torch.cat(batch_diff), epoch + 1)
    # batch_mse is never accumulated above (the mse computation is commented
    # out), so return the averaged L1 validation loss instead of a constant 0.
    return batch_l1 / (idx + 1)
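
# For reference: `metrics.psnr` is assumed to derive PSNR from the masked MSE
# over a [0, 1] value range. A minimal sketch of that assumed behavior
# (hypothetical; not necessarily the project's actual implementation):
#
#     def psnr(target: torch.Tensor, pred: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
#         mse = (((target - pred) ** 2) * mask).sum() / mask.sum().clamp(min=1)
#         return -10.0 * torch.log10(mse)
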
def val(self, epoch):
    self.model.eval()
    batch_alpha_loss = 0.0
    batch_rgb_loss = 0.0
    batch_alpha_psnr = 0.0
    batch_rgb_psnr = 0.0
    batch_clamped_alpha = []
    batch_blended_rgb = []
    batch_target_rgb = []
    batch_target_alpha = []
    with torch.no_grad():
        for idx, sample in enumerate(self.dataloader_val):
            input = sample['input'].to(device)
            mask = sample['input_mask'].to(device)
            target = sample['target'].to(device)
            target_alpha = target[:, 3, ...].unsqueeze(1)
            target_rgb = target[:, :3, ...]
            rgb_mask = mask.unsqueeze(1).expand(target_rgb.shape)
            alpha_mask = mask.unsqueeze(1)

            # Forward pass
            alpha, rgb = self.model(input)

            # Compute losses; the RGB prediction is blended with the clamped
            # alpha before it is compared against the target.
            alpha_loss = metrics.l1(target_alpha, alpha, alpha_mask)
            alpha_psnr = metrics.psnr(target_alpha, alpha, alpha_mask)
            clamped_alpha = torch.clamp(alpha, min=0.0, max=1.0)
            blended_rgb = clamped_alpha * rgb
            rgb_loss = metrics.l1(target_rgb, blended_rgb, rgb_mask)
            rgb_psnr = metrics.psnr(target_rgb, blended_rgb, rgb_mask)
            cb_rgb = torch.clamp(blended_rgb, min=0.0, max=1.0)

            # Log
            batch_alpha_loss += alpha_loss.item()
            batch_rgb_loss += rgb_loss.item()
            batch_alpha_psnr += alpha_psnr.item()
            batch_rgb_psnr += rgb_psnr.item()

            # Visualize images on tensorboard
            if idx in [0, 1, 2, 3] and (epoch + 1) % self.cfg.save_every == 0:
                batch_clamped_alpha.append(clamped_alpha)
                batch_blended_rgb.append(cb_rgb)
            if idx in [0, 1, 2, 3] and epoch == 0:
                batch_target_alpha.append(target_alpha)
                batch_target_rgb.append(target_rgb)

    # Log losses
    self.writer.add_scalar('val_alpha_loss', batch_alpha_loss / (idx + 1), epoch + 1)
    self.writer.add_scalar('val_rgb_loss', batch_rgb_loss / (idx + 1), epoch + 1)
    self.writer.add_scalar('val_alpha_psnr', batch_alpha_psnr / (idx + 1), epoch + 1)
    self.writer.add_scalar('val_rgb_psnr', batch_rgb_psnr / (idx + 1), epoch + 1)

    # Log input and target images only once
    if epoch == 0:
        self.writer.add_images('val_alpha_target', torch.cat(batch_target_alpha), epoch + 1)
        self.writer.add_images('val_rgb_target', torch.cat(batch_target_rgb), epoch + 1)
    if (epoch + 1) % self.cfg.save_every == 0:
        self.writer.add_images('val_alpha_clamped', torch.cat(batch_clamped_alpha), epoch + 1)
        self.writer.add_images('val_rgb_blended', torch.cat(batch_blended_rgb), epoch + 1)
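
# Note on the blending above: comparing `blended_rgb = clamped_alpha * rgb`
# directly against `target[:, :3, ...]` assumes the target RGB channels are
# stored premultiplied by alpha (i.e. composited over a black background);
# otherwise the target would first need to be multiplied by `target_alpha`,
# e.g. `target_rgb = target[:, :3, ...] * target_alpha` (hypothetical variant).
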