def get_camera_params(uv, pose, intrinsics):
    if pose.shape[1] == 7:  # In case of quaternion vector representation
        cam_loc = pose[:, 4:]
        R = quat_to_rot(pose[:, :4])
        p = utils.to_cuda(torch.eye(4).repeat(pose.shape[0], 1, 1)).float()
        p[:, :3, :3] = R
        p[:, :3, 3] = cam_loc
    else:  # In case of pose matrix representation
        cam_loc = pose[:, :3, 3]
        p = pose

    batch_size, num_samples, _ = uv.shape

    depth = utils.to_cuda(torch.ones((batch_size, num_samples)))
    x_cam = uv[:, :, 0].view(batch_size, -1)
    y_cam = uv[:, :, 1].view(batch_size, -1)
    z_cam = depth.view(batch_size, -1)

    pixel_points_cam = lift(x_cam, y_cam, z_cam, intrinsics=intrinsics)

    # permute for batch matrix product
    pixel_points_cam = pixel_points_cam.permute(0, 2, 1)

    world_coords = torch.bmm(p, pixel_points_cam).permute(0, 2, 1)[:, :, :3]
    ray_dirs = world_coords - cam_loc[:, None, :]
    ray_dirs = F.normalize(ray_dirs, dim=2)

    return ray_dirs, cam_loc
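# --- Usage sketch (added for illustration; not part of the original source).
# Shows the expected shapes and a simple sanity check: with an identity pose
# and a pixel at the principal point, the returned ray direction is the
# camera's +z axis. Assumes utils.to_cuda is a no-op when CUDA is unavailable;
# the _sketch_* name is illustrative only.
def _sketch_get_camera_params():
    pose = utils.to_cuda(torch.eye(4).unsqueeze(0))        # (1, 4, 4) identity pose
    intrinsics = utils.to_cuda(torch.eye(4).unsqueeze(0))  # fx = fy = 1, cx = cy = 0, no skew
    uv = utils.to_cuda(torch.zeros(1, 1, 2))               # one pixel at the principal point
    ray_dirs, cam_loc = get_camera_params(uv, pose, intrinsics)
    assert ray_dirs.shape == (1, 1, 3) and cam_loc.shape == (1, 3)
    assert torch.allclose(ray_dirs[0, 0], utils.to_cuda(torch.tensor([0., 0., 1.])))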
def interpolate(network, interval, experiment_directory, checkpoint, split_file, epoch, resolution, uniform_grid):
    # NOTE: `conf` is a free variable here and is expected to be defined at module scope.
    with open(split_file, "r") as f:
        split = json.load(f)

    ds = utils.get_class(conf.get_string('train.dataset'))(
        split=split,
        dataset_path=conf.get_string('train.dataset_path'),
        with_normals=True)

    points_1, normals_1, index_1 = ds[0]
    points_2, normals_2, index_2 = ds[1]

    pnts = utils.to_cuda(torch.cat([points_1, points_2], dim=0))

    name_1 = str.join('_', ds.get_info(0))
    name_2 = str.join('_', ds.get_info(1))
    name = name_1 + '_and_' + name_2

    utils.mkdir_ifnotexists(os.path.join(experiment_directory, 'interpolate'))
    utils.mkdir_ifnotexists(os.path.join(experiment_directory, 'interpolate', str(checkpoint)))
    utils.mkdir_ifnotexists(os.path.join(experiment_directory, 'interpolate', str(checkpoint), name))
    my_path = os.path.join(experiment_directory, 'interpolate', str(checkpoint), name)

    latent_1 = optimize_latent(utils.to_cuda(points_1), utils.to_cuda(normals_1), conf, 800, network, 5e-3)
    latent_2 = optimize_latent(utils.to_cuda(points_2), utils.to_cuda(normals_2), conf, 800, network, 5e-3)

    pnts = torch.cat([latent_1.repeat(pnts.shape[0], 1), pnts], dim=-1)

    with torch.no_grad():
        network.eval()
        for alpha in np.linspace(0, 1, interval):
            latent = (latent_1 * (1 - alpha)) + (latent_2 * alpha)
            plt.plot_surface(with_points=False,
                             points=pnts,
                             decoder=network,
                             latent=latent,
                             path=my_path,
                             epoch=epoch,
                             shapename=str(alpha),
                             resolution=resolution,
                             mc_value=0,
                             is_uniform_grid=uniform_grid,
                             verbose=False,
                             save_html=False,
                             save_ply=True,
                             overwrite=True,
                             connected=True)
def lift(x, y, z, intrinsics):
    # parse intrinsics
    intrinsics = utils.to_cuda(intrinsics)
    fx = intrinsics[:, 0, 0]
    fy = intrinsics[:, 1, 1]
    cx = intrinsics[:, 0, 2]
    cy = intrinsics[:, 1, 2]
    sk = intrinsics[:, 0, 1]

    x_lift = (x - cx.unsqueeze(-1) + cy.unsqueeze(-1) * sk.unsqueeze(-1) / fy.unsqueeze(-1) -
              sk.unsqueeze(-1) * y / fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z
    y_lift = (y - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z

    # homogeneous coordinates
    return torch.stack((x_lift, y_lift, z, utils.to_cuda(torch.ones_like(z))), dim=-1)
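# --- Worked check (added for illustration; not part of the original source).
# For a skew-free camera, lift() reduces to the usual pinhole back-projection
# K^{-1} [u, v, 1]^T scaled by depth: x = (u - cx) / fx * z, y = (v - cy) / fy * z.
# Assumes utils.to_cuda is a no-op when CUDA is unavailable.
def _sketch_lift():
    K = torch.eye(4).unsqueeze(0)
    K[0, 0, 0], K[0, 1, 1] = 500.0, 500.0   # fx, fy
    K[0, 0, 2], K[0, 1, 2] = 320.0, 240.0   # cx, cy
    u = utils.to_cuda(torch.tensor([[400.0]]))
    v = utils.to_cuda(torch.tensor([[300.0]]))
    z = utils.to_cuda(torch.tensor([[2.0]]))
    p = lift(u, v, z, intrinsics=K)[0, 0]   # (x, y, z, 1) in camera coordinates
    expected = torch.tensor([(400.0 - 320.0) / 500.0 * 2.0,
                             (300.0 - 240.0) / 500.0 * 2.0, 2.0, 1.0])
    assert torch.allclose(p.cpu(), expected)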
def get_pose_init(self):
    # get noisy initializations obtained with the linear method
    cam_file = '{0}/cameras_linear_init.npz'.format(self.instance_dir)
    camera_dict = np.load(cam_file)
    scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in self.views]
    world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in self.views]

    init_pose = []
    for scale_mat, world_mat in zip(scale_mats, world_mats):
        P = world_mat @ scale_mat
        P = P[:3, :4]
        _, pose = rend_util.load_K_Rt_from_P(None, P)
        init_pose.append(pose)
    init_pose = utils.to_cuda(
        torch.cat([torch.Tensor(pose).float().unsqueeze(0) for pose in init_pose], 0))
    init_quat = rend_util.rot_to_quat(init_pose[:, :3, :3])
    init_quat = torch.cat([init_quat, init_pose[:, :3, 3]], 1)

    return init_quat
def plot_validation_shapes(self, epoch, with_cuts=False):
    # plot network validation shapes
    with torch.no_grad():
        print('plot validation epoch: ', epoch)
        self.network.eval()
        pnts, normals, idx = next(iter(self.eval_dataloader))
        pnts = utils.to_cuda(pnts)
        pnts = self.add_latent(pnts, idx)
        latent = self.lat_vecs[idx[0]]
        shapename = str.join('_', self.ds.get_info(idx))
        plot_surface(with_points=True,
                     points=pnts,
                     decoder=self.network,
                     latent=latent,
                     path=self.plots_dir,
                     epoch=epoch,
                     shapename=shapename,
                     **self.conf.get_config('plot'))
        if with_cuts:
            plot_cuts(points=pnts,
                      decoder=self.network,
                      latent=latent,
                      path=self.plots_dir,
                      epoch=epoch,
                      near_zero=False)
def minimal_sdf_points(self, num_pixels, sdf, cam_loc, ray_directions, mask, min_dis, max_dis):
    ''' Find points with minimal SDF value on rays for P_out pixels '''

    n_mask_points = mask.sum()

    n = self.n_steps
    # steps = torch.linspace(0.0, 1.0, n).cuda()
    steps = utils.to_cuda(torch.empty(n).uniform_(0.0, 1.0))
    mask_max_dis = max_dis[mask].unsqueeze(-1)
    mask_min_dis = min_dis[mask].unsqueeze(-1)
    steps = steps.unsqueeze(0).repeat(n_mask_points, 1) * (mask_max_dis - mask_min_dis) + mask_min_dis

    mask_points = cam_loc.unsqueeze(1).repeat(1, num_pixels, 1).reshape(-1, 3)[mask]
    mask_rays = ray_directions[mask, :]

    mask_points_all = mask_points.unsqueeze(1).repeat(1, n, 1) + \
        steps.unsqueeze(-1) * mask_rays.unsqueeze(1).repeat(1, n, 1)

    points = mask_points_all.reshape(-1, 3)

    mask_sdf_all = []
    for pnts in torch.split(points, 100000, dim=0):
        mask_sdf_all.append(sdf(pnts))

    mask_sdf_all = torch.cat(mask_sdf_all).reshape(-1, n)
    min_vals, min_idx = mask_sdf_all.min(-1)
    min_mask_points = mask_points_all.reshape(-1, n, 3)[torch.arange(0, n_mask_points), min_idx]
    min_mask_dist = steps.reshape(-1, n)[torch.arange(0, n_mask_points), min_idx]

    return min_mask_points, min_mask_dist
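# --- Illustration (added; not part of the original source). The routine above
# picks, per ray, the sample with the smallest SDF value; for rays that miss
# the surface (the P_out case), that is the closest approach to the object.
# A standalone toy version with a unit-sphere SDF:
def _sketch_min_sdf_along_ray():
    sdf = lambda p: p.norm(dim=-1) - 1.0            # unit sphere
    origin = torch.tensor([0.0, 1.5, 2.0])          # ray that misses the sphere
    direction = torch.tensor([0.0, 0.0, -1.0])
    t = torch.linspace(0.0, 4.0, 200)               # uniform samples along the ray
    samples = origin + t.unsqueeze(-1) * direction
    min_idx = sdf(samples).argmin()
    # the minimum sits at the closest approach to the sphere, near z = 0
    assert torch.allclose(samples[min_idx], torch.tensor([0.0, 1.5, 0.0]), atol=0.05)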
def get_depth(points, pose):
    ''' Returns depth from 3D points according to camera pose '''
    batch_size, num_samples, _ = points.shape
    if pose.shape[1] == 7:  # In case of quaternion vector representation
        cam_loc = pose[:, 4:]
        R = quat_to_rot(pose[:, :4])
        pose = utils.to_cuda(torch.eye(4).unsqueeze(0).repeat(batch_size, 1, 1)).float()
        pose[:, :3, 3] = cam_loc
        pose[:, :3, :3] = R

    points_hom = torch.cat((points, utils.to_cuda(torch.ones((batch_size, num_samples, 1)))), dim=2)

    # permute for batch matrix product
    points_hom = points_hom.permute(0, 2, 1)

    points_cam = torch.inverse(pose).bmm(points_hom)
    depth = points_cam[:, 2, :][:, :, None]
    return depth
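# --- Usage sketch (added; not part of the original source). With an identity
# camera-to-world pose, depth is simply the z coordinate of the point.
# Assumes utils.to_cuda is a no-op when CUDA is unavailable.
def _sketch_get_depth():
    pose = utils.to_cuda(torch.eye(4).unsqueeze(0))             # (1, 4, 4)
    points = utils.to_cuda(torch.tensor([[[0.0, 0.0, 5.0]]]))   # (1, 1, 3)
    depth = get_depth(points, pose)
    assert torch.allclose(depth.cpu(), torch.tensor([[[5.0]]]))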
def get_rgb_loss(self, rgb_values, rgb_gt, network_object_mask, object_mask):
    if (network_object_mask & object_mask).sum() == 0:
        return utils.to_cuda(torch.tensor(0.0)).float()

    rgb_values = rgb_values[network_object_mask & object_mask]
    rgb_gt = rgb_gt.reshape(-1, 3)[network_object_mask & object_mask]
    rgb_loss = self.l1_loss(rgb_values, rgb_gt) / float(object_mask.shape[0])
    return rgb_loss
def get_mask_loss(self, sdf_output, network_object_mask, object_mask):
    mask = ~(network_object_mask & object_mask)
    if mask.sum() == 0:
        return utils.to_cuda(torch.tensor(0.0)).float()

    sdf_pred = -self.alpha * sdf_output[mask]
    gt = object_mask[mask].float()
    mask_loss = (1 / self.alpha) * F.binary_cross_entropy_with_logits(
        sdf_pred.squeeze(), gt, reduction='sum') / float(object_mask.shape[0])
    return mask_loss
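# --- Numeric illustration (added; not part of the original source). The mask
# loss treats sigmoid(-alpha * sdf) as a soft occupancy indicator; for a pixel
# labeled inside (gt = 1) whose ray stops exactly on the surface (sdf = 0),
# the per-pixel term is (1 / alpha) * BCE(sigmoid(0), 1) = log(2) / alpha.
def _sketch_mask_loss_term():
    alpha = 50.0
    sdf_pred = -alpha * torch.tensor([0.0])   # logit for a surface point
    gt = torch.tensor([1.0])
    term = (1 / alpha) * F.binary_cross_entropy_with_logits(sdf_pred, gt, reduction='sum')
    expected = torch.log(torch.tensor(2.0)) / alpha
    assert torch.allclose(term, expected)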
def add_latent(self, points, indices):
    batch_size, num_of_points, dim = points.shape
    points = points.reshape(batch_size * num_of_points, dim)
    latent_inputs = utils.to_cuda(torch.zeros(0))
    for ind in indices.numpy():
        latent_ind = self.lat_vecs[ind]
        latent_repeat = latent_ind.expand(num_of_points, -1)
        latent_inputs = torch.cat([latent_inputs, latent_repeat], 0)
    points = torch.cat([latent_inputs, points], 1)
    return points
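# --- Shape sketch (added; not part of the original source). add_latent
# flattens a (B, N, d) point batch to (B * N, d) and prepends each shape's
# latent code, giving (B * N, latent_size + d). A vectorized functional
# stand-in with the same row ordering, for illustration only:
def _sketch_add_latent(points, lat_vecs, indices):
    batch_size, num_of_points, dim = points.shape
    latents = lat_vecs[indices]                                   # (B, latent_size)
    latents = latents.unsqueeze(1).expand(-1, num_of_points, -1)  # (B, N, latent_size)
    out = torch.cat([latents, points], dim=-1).reshape(batch_size * num_of_points, -1)
    return out  # (B * N, latent_size + d)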
def get_grid_uniform(resolution):
    x = np.linspace(-1.2, 1.2, resolution)
    y = x
    z = x

    xx, yy, zz = np.meshgrid(x, y, z)
    grid_points = utils.to_cuda(
        torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T, dtype=torch.float))

    return {"grid_points": grid_points,
            "shortest_axis_length": 2.4,
            "xyz": [x, y, z],
            "shortest_axis_index": 0}
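# --- Usage sketch (added; not part of the original source). The returned
# dictionary carries a dense resolution^3 x 3 point grid spanning [-1.2, 1.2]^3
# plus the axis metadata that the marching-cubes code reads back.
def _sketch_get_grid_uniform():
    grid = get_grid_uniform(resolution=8)
    assert grid["grid_points"].shape == (8 ** 3, 3)
    assert grid["shortest_axis_length"] == 2.4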
def get_grid(points, resolution):
    eps = 0.2
    input_min = torch.min(points, dim=0)[0].squeeze().numpy()
    input_max = torch.max(points, dim=0)[0].squeeze().numpy()
    bounding_box = input_max - input_min
    shortest_axis = np.argmin(bounding_box)
    if shortest_axis == 0:
        x = np.linspace(input_min[shortest_axis] - eps, input_max[shortest_axis] + eps, resolution)
        length = np.max(x) - np.min(x)
        y = np.arange(input_min[1] - eps, input_max[1] + length / (x.shape[0] - 1) + eps, length / (x.shape[0] - 1))
        z = np.arange(input_min[2] - eps, input_max[2] + length / (x.shape[0] - 1) + eps, length / (x.shape[0] - 1))
    elif shortest_axis == 1:
        y = np.linspace(input_min[shortest_axis] - eps, input_max[shortest_axis] + eps, resolution)
        length = np.max(y) - np.min(y)
        x = np.arange(input_min[0] - eps, input_max[0] + length / (y.shape[0] - 1) + eps, length / (y.shape[0] - 1))
        z = np.arange(input_min[2] - eps, input_max[2] + length / (y.shape[0] - 1) + eps, length / (y.shape[0] - 1))
    elif shortest_axis == 2:
        z = np.linspace(input_min[shortest_axis] - eps, input_max[shortest_axis] + eps, resolution)
        length = np.max(z) - np.min(z)
        x = np.arange(input_min[0] - eps, input_max[0] + length / (z.shape[0] - 1) + eps, length / (z.shape[0] - 1))
        y = np.arange(input_min[1] - eps, input_max[1] + length / (z.shape[0] - 1) + eps, length / (z.shape[0] - 1))

    xx, yy, zz = np.meshgrid(x, y, z)
    grid_points = utils.to_cuda(
        torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T, dtype=torch.float))

    return {"grid_points": grid_points,
            "shortest_axis_length": length,
            "xyz": [x, y, z],
            "shortest_axis_index": shortest_axis}
def plot_images(rgb_points, ground_true, path, epoch, plot_nrow, img_res):
    ground_true = (utils.to_cuda(ground_true) + 1.) / 2.
    rgb_points = (rgb_points + 1.) / 2.

    output_vs_gt = torch.cat((rgb_points, ground_true), dim=0)
    output_vs_gt_plot = lin2img(output_vs_gt, img_res)

    tensor = torchvision.utils.make_grid(output_vs_gt_plot,
                                         scale_each=False,
                                         normalize=False,
                                         nrow=plot_nrow).cpu().detach().numpy()
    tensor = tensor.transpose(1, 2, 0)
    scale_factor = 255
    tensor = (tensor * scale_factor).astype(np.uint8)

    img = Image.fromarray(tensor)
    img.save('{0}/rendering_{1}.png'.format(path, epoch))
def quat_to_rot(q):
    batch_size, _ = q.shape
    q = F.normalize(q, dim=1)
    R = utils.to_cuda(torch.ones((batch_size, 3, 3)))
    qr = q[:, 0]
    qi = q[:, 1]
    qj = q[:, 2]
    qk = q[:, 3]
    R[:, 0, 0] = 1 - 2 * (qj ** 2 + qk ** 2)
    R[:, 0, 1] = 2 * (qj * qi - qk * qr)
    R[:, 0, 2] = 2 * (qi * qk + qr * qj)
    R[:, 1, 0] = 2 * (qj * qi + qk * qr)
    R[:, 1, 1] = 1 - 2 * (qi ** 2 + qk ** 2)
    R[:, 1, 2] = 2 * (qj * qk - qi * qr)
    R[:, 2, 0] = 2 * (qk * qi - qj * qr)
    R[:, 2, 1] = 2 * (qj * qk + qi * qr)
    R[:, 2, 2] = 1 - 2 * (qi ** 2 + qj ** 2)
    return R
def rot_to_quat(R):
    batch_size, _, _ = R.shape
    q = utils.to_cuda(torch.ones((batch_size, 4)))

    R00 = R[:, 0, 0]
    R01 = R[:, 0, 1]
    R02 = R[:, 0, 2]
    R10 = R[:, 1, 0]
    R11 = R[:, 1, 1]
    R12 = R[:, 1, 2]
    R20 = R[:, 2, 0]
    R21 = R[:, 2, 1]
    R22 = R[:, 2, 2]

    q[:, 0] = torch.sqrt(1.0 + R00 + R11 + R22) / 2
    q[:, 1] = (R21 - R12) / (4 * q[:, 0])
    q[:, 2] = (R02 - R20) / (4 * q[:, 0])
    q[:, 3] = (R10 - R01) / (4 * q[:, 0])
    return q
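# --- Round-trip sanity sketch (added; not part of the original source).
# rot_to_quat uses the trace-based formula, which is only stable when the
# real part q[:, 0] is positive, so the check starts from such a quaternion.
# Assumes utils.to_cuda is a no-op when CUDA is unavailable.
def _sketch_quat_round_trip():
    q = F.normalize(torch.tensor([[0.9, 0.1, 0.2, 0.1]]), dim=1)
    R = quat_to_rot(utils.to_cuda(q))
    q_back = rot_to_quat(R)
    assert torch.allclose(q, q_back.cpu(), atol=1e-4)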
def forward(self, model_outputs, ground_truth):
    rgb_gt = utils.to_cuda(ground_truth['rgb'])
    network_object_mask = model_outputs['network_object_mask']
    object_mask = model_outputs['object_mask']

    rgb_loss = self.get_rgb_loss(model_outputs['rgb_values'], rgb_gt, network_object_mask, object_mask)
    mask_loss = self.get_mask_loss(model_outputs['sdf_output'], network_object_mask, object_mask)
    eikonal_loss = self.get_eikonal_loss(model_outputs['grad_theta'])

    loss = rgb_loss + \
           self.eikonal_weight * eikonal_loss + \
           self.mask_weight * mask_loss

    return {
        'loss': loss,
        'rgb_loss': rgb_loss,
        'eikonal_loss': eikonal_loss,
        'mask_loss': mask_loss,
    }
def get_sphere_intersection(cam_loc, ray_directions, r=1.0):
    # Input: cam_loc: n_images x 3 ; ray_directions: n_images x n_rays x 3
    # Output: n_images x n_rays x 2 (near and far distances) ; n_images x n_rays
    n_imgs, n_pix, _ = ray_directions.shape

    cam_loc = cam_loc.unsqueeze(-1)
    ray_cam_dot = torch.bmm(ray_directions, cam_loc).squeeze()
    under_sqrt = ray_cam_dot ** 2 - (cam_loc.norm(2, 1) ** 2 - r ** 2)

    under_sqrt = under_sqrt.reshape(-1)
    mask_intersect = under_sqrt > 0

    sphere_intersections = utils.to_cuda(torch.zeros(n_imgs * n_pix, 2)).float()
    sphere_intersections[mask_intersect] = \
        torch.sqrt(under_sqrt[mask_intersect]).unsqueeze(-1) * utils.to_cuda(torch.Tensor([-1, 1])).float()
    sphere_intersections[mask_intersect] -= ray_cam_dot.reshape(-1)[mask_intersect].unsqueeze(-1)

    sphere_intersections = sphere_intersections.reshape(n_imgs, n_pix, 2)
    sphere_intersections = sphere_intersections.clamp_min(0.0)
    mask_intersect = mask_intersect.reshape(n_imgs, n_pix)

    return sphere_intersections, mask_intersect
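# --- Worked example (added; not part of the original source). The function
# solves the ray-sphere quadratic t^2 + 2 t (d . o) + |o|^2 - r^2 = 0. A camera
# at (0, 0, 2) looking through the unit sphere enters at distance 1 and exits
# at distance 3; a sideways ray misses. Assumes utils.to_cuda is a no-op on CPU.
def _sketch_sphere_intersection():
    cam_loc = utils.to_cuda(torch.tensor([[0.0, 0.0, 2.0]]))
    ray_dirs = utils.to_cuda(torch.tensor([[[0.0, 0.0, -1.0],
                                            [1.0, 0.0, 0.0]]]))
    dists, mask = get_sphere_intersection(cam_loc, ray_dirs, r=1.0)
    assert torch.allclose(dists[0, 0].cpu(), torch.tensor([1.0, 3.0]))
    assert mask[0, 0].item() and not mask[0, 1].item()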
def plot_cuts_axis(points, decoder, latent, path, epoch, near_zero, axis, file_name_sep='/'):
    onedim_cut = np.linspace(-1.0, 1.0, 200)
    xx, yy = np.meshgrid(onedim_cut, onedim_cut)
    xx = xx.ravel()
    yy = yy.ravel()
    min_axis = points[:, axis].min(dim=0)[0].item()
    max_axis = points[:, axis].max(dim=0)[0].item()
    mask = np.zeros(3)
    mask[axis] = 1.0
    if axis == 0:
        position_cut = np.vstack(([np.zeros(xx.shape[0]), xx, yy]))
    elif axis == 1:
        position_cut = np.vstack(([xx, np.zeros(xx.shape[0]), yy]))
    elif axis == 2:
        position_cut = np.vstack(([xx, yy, np.zeros(xx.shape[0])]))
    position_cut = [position_cut + i * mask.reshape(-1, 1)
                    for i in np.linspace(min_axis - 0.1, max_axis + 0.1, 50)]
    for index, pos in enumerate(position_cut):
        field_input = utils.to_cuda(torch.tensor(pos.T, dtype=torch.float))
        z = []
        for i, pnts in enumerate(torch.split(field_input, 10000, dim=0)):
            if latent is not None:
                pnts = torch.cat([latent.expand(pnts.shape[0], -1), pnts], dim=1)
            z.append(decoder(pnts).detach().cpu().numpy())
        z = np.concatenate(z, axis=0)

        if near_zero:
            if np.min(z) < -1.0e-5:
                start = -0.1
            else:
                start = 0.0
            trace1 = go.Contour(x=onedim_cut,
                                y=onedim_cut,
                                z=z.reshape(onedim_cut.shape[0], onedim_cut.shape[0]),
                                name='axis {0} = {1}'.format(axis, pos[axis, 0]),
                                autocontour=False,
                                contours=dict(start=start, end=0.1, size=0.01))
        else:
            trace1 = go.Contour(x=onedim_cut,
                                y=onedim_cut,
                                z=z.reshape(onedim_cut.shape[0], onedim_cut.shape[0]),
                                name='axis {0} = {1}'.format(axis, pos[axis, 0]),
                                autocontour=True,
                                ncontours=70)

        layout = go.Layout(width=1200,
                           height=1200,
                           scene=dict(xaxis=dict(range=[-1, 1], autorange=False),
                                      yaxis=dict(range=[-1, 1], autorange=False),
                                      aspectratio=dict(x=1, y=1)),
                           title=dict(text='axis {0} = {1}'.format(axis, pos[axis, 0])))

        filename = '{0}{1}cutsaxis_{2}_{3}_{4}.html'.format(path, file_name_sep, axis, epoch, index)
        fig1 = go.Figure(data=[trace1], layout=layout)
        offline.plot(fig1, filename=filename, auto_open=False)
def run(self):
    print("running")

    for epoch in range(self.startepoch, self.nepochs + 1):
        if epoch % self.conf.get_int('train.checkpoint_frequency') == 0:
            self.save_checkpoints(epoch)
            self.plot_validation_shapes(epoch)
            # change back to train mode
            self.network.train()

        self.adjust_learning_rate(epoch)

        # start epoch
        before_epoch = time()
        for data_index, (mnfld_pnts, normals, indices) in enumerate(self.train_dataloader):
            mnfld_pnts = utils.to_cuda(mnfld_pnts)
            if self.with_normals:
                normals = utils.to_cuda(normals)

            nonmnfld_pnts = self.sampler.get_points(mnfld_pnts)

            mnfld_pnts = self.add_latent(mnfld_pnts, indices)
            nonmnfld_pnts = self.add_latent(nonmnfld_pnts, indices)

            # forward pass
            mnfld_pnts.requires_grad_()
            nonmnfld_pnts.requires_grad_()

            mnfld_pred = self.network(mnfld_pnts)
            nonmnfld_pred = self.network(nonmnfld_pnts)

            mnfld_grad = gradient(mnfld_pnts, mnfld_pred)
            nonmnfld_grad = gradient(nonmnfld_pnts, nonmnfld_pred)

            # manifold loss
            mnfld_loss = (mnfld_pred.abs()).mean()

            # eikonal loss
            grad_loss = ((nonmnfld_grad.norm(2, dim=-1) - 1) ** 2).mean()
            loss = mnfld_loss + self.grad_lambda * grad_loss

            # normals loss
            if self.with_normals:
                normals = normals.view(-1, 3)
                normals_loss = ((mnfld_grad - normals).abs()).norm(2, dim=1).mean()
                loss = loss + self.normals_lambda * normals_loss
            else:
                normals_loss = torch.zeros(1)

            # latent loss
            latent_loss = self.latent_size_reg(utils.to_cuda(indices))
            loss = loss + self.latent_lambda * latent_loss

            # back propagation
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # print status
            if data_index % self.conf.get_int('train.status_frequency') == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tTrain Loss: {:.6f}\tManifold loss: {:.6f}'
                      '\tGrad loss: {:.6f}\tLatent loss: {:.6f}\tNormals Loss: {:.6f}'.format(
                          epoch, data_index * self.batch_size, len(self.ds),
                          100. * data_index / len(self.train_dataloader),
                          loss.item(), mnfld_loss.item(), grad_loss.item(),
                          latent_loss.item(), normals_loss.item()))

        after_epoch = time()
        print('epoch time {0}'.format(str(after_epoch - before_epoch)))
def plot_cuts(points, decoder, path, epoch, near_zero, latent=None):
    onedim_cut = np.linspace(-1, 1, 200)
    xx, yy = np.meshgrid(onedim_cut, onedim_cut)
    xx = xx.ravel()
    yy = yy.ravel()
    min_y = points[:, -2].min(dim=0)[0].item()
    max_y = points[:, -2].max(dim=0)[0].item()
    position_cut = np.vstack(([xx, np.zeros(xx.shape[0]), yy]))
    position_cut = [position_cut + np.array([0., i, 0.]).reshape(-1, 1)
                    for i in np.linspace(min_y - 0.1, max_y + 0.1, 10)]
    for index, pos in enumerate(position_cut):
        field_input = utils.to_cuda(torch.tensor(pos.T, dtype=torch.float))
        z = []
        # split along the batch dimension (dim=0, not dim=-1) so each chunk is a subset of points
        for i, pnts in enumerate(torch.split(field_input, 1000, dim=0)):
            input_ = pnts
            if latent is not None:
                input_ = torch.cat([latent.expand(pnts.shape[0], -1), pnts], dim=1)
            z.append(decoder(input_).detach().cpu().numpy())
        z = np.concatenate(z, axis=0)

        if near_zero:
            trace1 = go.Contour(x=onedim_cut,
                                y=onedim_cut,
                                z=z.reshape(onedim_cut.shape[0], onedim_cut.shape[0]),
                                name='y = {0}'.format(pos[1, 0]),
                                autocontour=False,
                                contours=dict(start=-0.001, end=0.001, size=0.00001))
        else:
            trace1 = go.Contour(x=onedim_cut,
                                y=onedim_cut,
                                z=z.reshape(onedim_cut.shape[0], onedim_cut.shape[0]),
                                name='y = {0}'.format(pos[1, 0]),
                                autocontour=True)

        layout = go.Layout(width=1200,
                           height=1200,
                           scene=dict(xaxis=dict(range=[-1, 1], autorange=False),
                                      yaxis=dict(range=[-1, 1], autorange=False),
                                      aspectratio=dict(x=1, y=1)),
                           title=dict(text='y = {0}'.format(pos[1, 0])))

        filename = '{0}/cuts{1}_{2}.html'.format(path, epoch, index)
        fig1 = go.Figure(data=[trace1], layout=layout)
        offline.plot(fig1, filename=filename, auto_open=False)
def __init__(self, **kwargs):
    # config setting
    self.home_dir = os.path.abspath(os.pardir)
    if type(kwargs['conf']) == str:
        self.conf_filename = os.path.abspath(kwargs['conf'])
        self.conf = ConfigFactory.parse_file(self.conf_filename)
    else:
        self.conf = kwargs['conf']
    self.expname = kwargs['expname']

    # GPU settings
    self.GPU_INDEX = kwargs['gpu_index']
    if not self.GPU_INDEX == 'ignore':
        os.environ["CUDA_VISIBLE_DEVICES"] = '{0}'.format(self.GPU_INDEX)
    self.num_of_gpus = torch.cuda.device_count()

    # settings for loading an existing experiment
    if kwargs['is_continue'] and kwargs['timestamp'] == 'latest':
        if os.path.exists(os.path.join(self.home_dir, 'exps', self.expname)):
            timestamps = os.listdir(os.path.join(self.home_dir, 'exps', self.expname))
            if len(timestamps) == 0:
                is_continue = False
                timestamp = None
            else:
                timestamp = sorted(timestamps)[-1]
                is_continue = True
        else:
            is_continue = False
            timestamp = None
    else:
        timestamp = kwargs['timestamp']
        is_continue = kwargs['is_continue']

    self.exps_folder_name = 'exps'
    utils.mkdir_ifnotexists(utils.concat_home_dir(os.path.join(self.home_dir, self.exps_folder_name)))

    self.expdir = utils.concat_home_dir(os.path.join(self.home_dir, self.exps_folder_name, self.expname))
    utils.mkdir_ifnotexists(self.expdir)

    if is_continue:
        self.timestamp = timestamp
    else:
        self.timestamp = '{:%Y_%m_%d_%H_%M_%S}'.format(datetime.now())

    self.cur_exp_dir = self.timestamp
    utils.mkdir_ifnotexists(os.path.join(self.expdir, self.cur_exp_dir))

    self.plots_dir = os.path.join(self.expdir, self.cur_exp_dir, 'plots')
    utils.mkdir_ifnotexists(self.plots_dir)

    self.checkpoints_path = os.path.join(self.expdir, self.cur_exp_dir, 'checkpoints')
    utils.mkdir_ifnotexists(self.checkpoints_path)

    self.model_params_subdir = "ModelParameters"
    self.optimizer_params_subdir = "OptimizerParameters"
    self.latent_codes_subdir = "LatentCodes"

    utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.model_params_subdir))
    utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.optimizer_params_subdir))
    utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.latent_codes_subdir))

    self.nepochs = kwargs['nepochs']
    self.batch_size = kwargs['batch_size']
    if self.num_of_gpus > 0:
        self.batch_size *= self.num_of_gpus
    self.parallel = self.num_of_gpus > 1

    self.global_sigma = self.conf.get_float('network.sampler.properties.global_sigma')
    self.local_sigma = self.conf.get_float('network.sampler.properties.local_sigma')
    self.sampler = Sampler.get_sampler(self.conf.get_string('network.sampler.sampler_type'))(
        self.global_sigma, self.local_sigma)

    train_split_file = os.path.abspath(kwargs['split_file'])
    print(f'Loading split file {train_split_file}')
    with open(train_split_file, "r") as f:
        train_split = json.load(f)
    print(f'Size of the split: {len(train_split)} samples')

    self.d_in = self.conf.get_int('train.d_in')

    # latent preprocessing
    self.latent_size = self.conf.get_int('train.latent_size')
    self.latent_lambda = self.conf.get_float('network.loss.latent_lambda')
    self.grad_lambda = self.conf.get_float('network.loss.lambda')
    self.normals_lambda = self.conf.get_float('network.loss.normals_lambda')
    self.with_normals = self.normals_lambda > 0

    self.ds = utils.get_class(self.conf.get_string('train.dataset'))(
        split=train_split,
        with_normals=self.with_normals,
        dataset_path=self.conf.get_string('train.dataset_path'),
        points_batch=kwargs['points_batch'],
    )
    self.num_scenes = len(self.ds)

    self.train_dataloader = torch.utils.data.DataLoader(self.ds,
                                                        batch_size=self.batch_size,
                                                        shuffle=True,
                                                        num_workers=kwargs['threads'],
                                                        drop_last=True,
                                                        pin_memory=True)
    self.eval_dataloader = torch.utils.data.DataLoader(self.ds,
                                                       batch_size=1,
                                                       shuffle=False,
                                                       num_workers=0,
                                                       drop_last=True)

    self.network = utils.get_class(self.conf.get_string('train.network_class'))(
        d_in=(self.d_in + self.latent_size),
        **self.conf.get_config('network.inputs'))
    if self.parallel:
        self.network = torch.nn.DataParallel(self.network)
    if torch.cuda.is_available():
        self.network.cuda()

    self.lr_schedules = self.get_learning_rate_schedules(self.conf.get_list('train.learning_rate_schedule'))
    self.weight_decay = self.conf.get_float('train.weight_decay')

    # optimizer and latent settings
    self.startepoch = 0

    self.lat_vecs = utils.to_cuda(torch.zeros(self.num_scenes, self.latent_size))
    self.lat_vecs.requires_grad_()

    self.optimizer = torch.optim.Adam([
        {
            "params": self.network.parameters(),
            "lr": self.lr_schedules[0].get_learning_rate(0),
            "weight_decay": self.weight_decay
        },
        {
            "params": self.lat_vecs,
            "lr": self.lr_schedules[1].get_learning_rate(0)
        },
    ])

    # if continue load checkpoints
    if is_continue:
        old_checkpnts_dir = os.path.join(self.expdir, timestamp, 'checkpoints')

        data = torch.load(os.path.join(old_checkpnts_dir, self.latent_codes_subdir,
                                       str(kwargs['checkpoint']) + '.pth'))
        self.lat_vecs = utils.to_cuda(data["latent_codes"])

        saved_model_state = torch.load(os.path.join(old_checkpnts_dir, 'ModelParameters',
                                                    str(kwargs['checkpoint']) + ".pth"))
        self.network.load_state_dict(saved_model_state["model_state_dict"])

        data = torch.load(os.path.join(old_checkpnts_dir, 'OptimizerParameters',
                                       str(kwargs['checkpoint']) + ".pth"))
        self.optimizer.load_state_dict(data["optimizer_state_dict"])
        self.startepoch = saved_model_state['epoch']
def __init__(self, **kwargs):
    torch.set_default_dtype(torch.float32)
    torch.set_num_threads(1)

    self.conf = ConfigFactory.parse_file(kwargs['conf'])
    self.batch_size = kwargs['batch_size']
    self.nepochs = kwargs['nepochs']
    self.exps_folder_name = kwargs['exps_folder_name']
    self.GPU_INDEX = kwargs['gpu_index']
    self.train_cameras = kwargs['train_cameras']

    self.expname = self.conf.get_string('train.expname') + kwargs['expname']
    scene_id = kwargs['scene_id'] if kwargs['scene_id'] else self.conf.get_string('dataset.scene_id', default=None)
    if scene_id:
        self.expname = self.expname + '_{0}'.format(scene_id)

    if kwargs['is_continue'] and kwargs['timestamp'] == 'latest':
        if os.path.exists(os.path.join('../', kwargs['exps_folder_name'], self.expname)):
            timestamps = os.listdir(os.path.join('../', kwargs['exps_folder_name'], self.expname))
            if len(timestamps) == 0:
                is_continue = False
                timestamp = None
            else:
                timestamp = sorted(timestamps)[-1]
                is_continue = True
        else:
            is_continue = False
            timestamp = None
    else:
        timestamp = kwargs['timestamp']
        is_continue = kwargs['is_continue']

    utils.mkdir_ifnotexists(os.path.join('../', self.exps_folder_name))
    self.expdir = os.path.join('../', self.exps_folder_name, self.expname)
    utils.mkdir_ifnotexists(self.expdir)
    self.timestamp = '{:%Y_%m_%d_%H_%M_%S}'.format(datetime.now())
    utils.mkdir_ifnotexists(os.path.join(self.expdir, self.timestamp))

    self.plots_dir = os.path.join(self.expdir, self.timestamp, 'plots')
    utils.mkdir_ifnotexists(self.plots_dir)

    # create checkpoints dirs
    self.checkpoints_path = os.path.join(self.expdir, self.timestamp, 'checkpoints')
    utils.mkdir_ifnotexists(self.checkpoints_path)
    self.model_params_subdir = "ModelParameters"
    self.optimizer_params_subdir = "OptimizerParameters"
    self.scheduler_params_subdir = "SchedulerParameters"

    utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.model_params_subdir))
    utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.optimizer_params_subdir))
    utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.scheduler_params_subdir))

    if self.train_cameras:
        self.optimizer_cam_params_subdir = "OptimizerCamParameters"
        self.cam_params_subdir = "CamParameters"

        utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.optimizer_cam_params_subdir))
        utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.cam_params_subdir))

    os.system("""cp -r {0} "{1}" """.format(kwargs['conf'], os.path.join(self.expdir, self.timestamp, 'runconf.conf')))

    if not self.GPU_INDEX == 'ignore':
        os.environ["CUDA_VISIBLE_DEVICES"] = '{0}'.format(self.GPU_INDEX)

    print('shell command : {0}'.format(' '.join(sys.argv)))

    print('Loading data ...')

    dataset_conf = self.conf.get_config('dataset')
    if kwargs['scene_id']:
        dataset_conf['scene_id'] = kwargs['scene_id']

    self.train_dataset = utils.get_class(self.conf.get_string('train.dataset_class'))(
        self.train_cameras, **dataset_conf)

    print('Finish loading data ...')

    self.train_dataloader = torch.utils.data.DataLoader(self.train_dataset,
                                                        batch_size=self.batch_size,
                                                        shuffle=True,
                                                        collate_fn=self.train_dataset.collate_fn)
    self.plot_dataloader = torch.utils.data.DataLoader(self.train_dataset,
                                                       batch_size=self.conf.get_int('plot.plot_nimgs'),
                                                       shuffle=True,
                                                       collate_fn=self.train_dataset.collate_fn)

    self.model = utils.get_class(self.conf.get_string('train.model_class'))(conf=self.conf.get_config('model'))
    if torch.cuda.is_available():
        self.model.cuda()

    self.loss = utils.get_class(self.conf.get_string('train.loss_class'))(**self.conf.get_config('loss'))

    self.lr = self.conf.get_float('train.learning_rate')
    self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
    self.sched_milestones = self.conf.get_list('train.sched_milestones', default=[])
    self.sched_factor = self.conf.get_float('train.sched_factor', default=0.0)
    self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer,
                                                          self.sched_milestones,
                                                          gamma=self.sched_factor)

    # settings for camera optimization
    if self.train_cameras:
        num_images = len(self.train_dataset)
        self.pose_vecs = utils.to_cuda(torch.nn.Embedding(num_images, 7, sparse=True))
        self.pose_vecs.weight.data.copy_(self.train_dataset.get_pose_init())

        self.optimizer_cam = torch.optim.SparseAdam(self.pose_vecs.parameters(),
                                                    self.conf.get_float('train.learning_rate_cam'))

    self.start_epoch = 0
    if is_continue:
        old_checkpnts_dir = os.path.join(self.expdir, timestamp, 'checkpoints')

        saved_model_state = torch.load(
            os.path.join(old_checkpnts_dir, 'ModelParameters', str(kwargs['checkpoint']) + ".pth"))
        self.model.load_state_dict(saved_model_state["model_state_dict"])
        self.start_epoch = saved_model_state['epoch']

        data = torch.load(
            os.path.join(old_checkpnts_dir, 'OptimizerParameters', str(kwargs['checkpoint']) + ".pth"))
        self.optimizer.load_state_dict(data["optimizer_state_dict"])

        data = torch.load(
            os.path.join(old_checkpnts_dir, self.scheduler_params_subdir, str(kwargs['checkpoint']) + ".pth"))
        self.scheduler.load_state_dict(data["scheduler_state_dict"])

        if self.train_cameras:
            data = torch.load(
                os.path.join(old_checkpnts_dir, self.optimizer_cam_params_subdir, str(kwargs['checkpoint']) + ".pth"))
            self.optimizer_cam.load_state_dict(data["optimizer_cam_state_dict"])

            data = torch.load(
                os.path.join(old_checkpnts_dir, self.cam_params_subdir, str(kwargs['checkpoint']) + ".pth"))
            self.pose_vecs.load_state_dict(data["pose_vecs_state_dict"])

    self.num_pixels = self.conf.get_int('train.num_pixels')
    self.total_pixels = self.train_dataset.total_pixels
    self.img_res = self.train_dataset.img_res
    self.n_batches = len(self.train_dataloader)
    self.plot_freq = self.conf.get_int('train.plot_freq')
    self.plot_conf = self.conf.get_config('plot')

    self.alpha_milestones = self.conf.get_list('train.alpha_milestones', default=[])
    self.alpha_factor = self.conf.get_float('train.alpha_factor', default=0.0)
    for acc in self.alpha_milestones:
        if self.start_epoch > acc:
            self.loss.alpha = self.loss.alpha * self.alpha_factor
def run(self):
    print("training...")
    pbar = tqdm(range(self.start_epoch, self.nepochs + 1))
    pbar.set_description('Training IDR')
    for epoch in pbar:
        if epoch in self.alpha_milestones:
            self.loss.alpha = self.loss.alpha * self.alpha_factor

        if epoch % 100 == 0:
            self.save_checkpoints(epoch)

        if epoch % self.plot_freq == 0:
            self.model.eval()
            if self.train_cameras:
                self.pose_vecs.eval()
            self.train_dataset.change_sampling_idx(-1)

            indices, model_input, ground_truth = next(iter(self.plot_dataloader))
            model_input["intrinsics"] = utils.to_cuda(model_input["intrinsics"])
            model_input["uv"] = utils.to_cuda(model_input["uv"])
            model_input["object_mask"] = utils.to_cuda(model_input["object_mask"])

            if self.train_cameras:
                pose_input = self.pose_vecs(utils.to_cuda(indices))
                model_input['pose'] = pose_input
            else:
                model_input['pose'] = utils.to_cuda(model_input['pose'])

            split = utils.split_input(model_input, self.total_pixels)
            res = []
            for s in split:
                out = self.model(s)
                res.append({
                    'points': out['points'].detach(),
                    'rgb_values': out['rgb_values'].detach(),
                    'network_object_mask': out['network_object_mask'].detach(),
                    'object_mask': out['object_mask'].detach()
                })

            batch_size = ground_truth['rgb'].shape[0]
            model_outputs = utils.merge_output(res, self.total_pixels, batch_size)

            plt.plot(self.model,
                     indices,
                     model_outputs,
                     model_input['pose'],
                     ground_truth['rgb'],
                     self.plots_dir,
                     epoch,
                     self.img_res,
                     **self.plot_conf)

            self.model.train()
            if self.train_cameras:
                self.pose_vecs.train()

        self.train_dataset.change_sampling_idx(self.num_pixels)

        for data_index, (indices, model_input, ground_truth) in enumerate(self.train_dataloader):
            model_input["intrinsics"] = utils.to_cuda(model_input["intrinsics"])
            model_input["uv"] = utils.to_cuda(model_input["uv"])
            model_input["object_mask"] = utils.to_cuda(model_input["object_mask"])

            if self.train_cameras:
                pose_input = self.pose_vecs(utils.to_cuda(indices))
                model_input['pose'] = pose_input
            else:
                model_input['pose'] = utils.to_cuda(model_input['pose'])

            model_outputs = self.model(model_input)
            loss_output = self.loss(model_outputs, ground_truth)

            loss = loss_output['loss']

            self.optimizer.zero_grad()
            if self.train_cameras:
                self.optimizer_cam.zero_grad()

            loss.backward()

            self.optimizer.step()
            if self.train_cameras:
                self.optimizer_cam.step()

            pbar.set_postfix({
                'loss': loss.item(),
                'rgb_loss': loss_output['rgb_loss'].item(),
                'eikonal_loss': loss_output['eikonal_loss'].item(),
                'mask_loss': loss_output['mask_loss'].item(),
                'alpha': self.loss.alpha,
                'lr': self.scheduler.get_lr()[0]
            })

        self.scheduler.step()
def evaluate(**kwargs):
    torch.set_default_dtype(torch.float32)

    conf = ConfigFactory.parse_file(kwargs['conf'])
    exps_folder_name = kwargs['exps_folder_name']
    evals_folder_name = kwargs['evals_folder_name']
    eval_cameras = kwargs['eval_cameras']
    eval_rendering = kwargs['eval_rendering']

    expname = conf.get_string('train.expname') + kwargs['expname']
    scene_id = kwargs['scene_id'] if kwargs['scene_id'] else conf.get_string('dataset.scene_id', default=None)
    if scene_id:
        expname = expname + '_{0}'.format(scene_id)

    if kwargs['timestamp'] == 'latest':
        if os.path.exists(os.path.join('../', kwargs['exps_folder_name'], expname)):
            timestamps = os.listdir(os.path.join('../', kwargs['exps_folder_name'], expname))
            if len(timestamps) == 0:
                print('WRONG EXP FOLDER')
                exit()
            else:
                timestamp = sorted(timestamps)[-1]
        else:
            print('WRONG EXP FOLDER')
            exit()
    else:
        timestamp = kwargs['timestamp']

    utils.mkdir_ifnotexists(os.path.join('../', evals_folder_name))
    expdir = os.path.join('../', exps_folder_name, expname)
    evaldir = os.path.join('../', evals_folder_name, expname)
    utils.mkdir_ifnotexists(evaldir)

    model = utils.get_class(conf.get_string('train.model_class'))(conf=conf.get_config('model'))
    if torch.cuda.is_available():
        model.cuda()

    dataset_conf = conf.get_config('dataset')
    if kwargs['scene_id']:
        dataset_conf['scene_id'] = kwargs['scene_id']
    eval_dataset = utils.get_class(conf.get_string('train.dataset_class'))(eval_cameras, **dataset_conf)

    # settings for camera optimization
    scale_mat = eval_dataset.get_scale_mat()
    if eval_cameras:
        num_images = len(eval_dataset)
        pose_vecs = utils.to_cuda(torch.nn.Embedding(num_images, 7, sparse=True))
        pose_vecs.weight.data.copy_(eval_dataset.get_pose_init())

        gt_pose = eval_dataset.get_gt_pose()

    if eval_rendering:
        eval_dataloader = torch.utils.data.DataLoader(eval_dataset,
                                                      batch_size=1,
                                                      shuffle=False,
                                                      collate_fn=eval_dataset.collate_fn)
        total_pixels = eval_dataset.total_pixels
        img_res = eval_dataset.img_res

    old_checkpnts_dir = os.path.join(expdir, timestamp, 'checkpoints')

    saved_model_state = torch.load(
        os.path.join(old_checkpnts_dir, 'ModelParameters', str(kwargs['checkpoint']) + ".pth"))
    model.load_state_dict(saved_model_state["model_state_dict"])
    epoch = saved_model_state['epoch']

    if eval_cameras:
        data = torch.load(
            os.path.join(old_checkpnts_dir, 'CamParameters', str(kwargs['checkpoint']) + ".pth"))
        pose_vecs.load_state_dict(data["pose_vecs_state_dict"])

    ####################################################################################################################
    print("evaluating...")

    model.eval()
    if eval_cameras:
        pose_vecs.eval()

    with torch.no_grad():
        if eval_cameras:
            gt_Rs = gt_pose[:, :3, :3].double()
            gt_ts = gt_pose[:, :3, 3].double()

            pred_Rs = rend_util.quat_to_rot(pose_vecs.weight.data[:, :4]).cpu().double()
            pred_ts = pose_vecs.weight.data[:, 4:].cpu().double()

            R_opt, t_opt, c_opt, R_fixed, t_fixed = get_cameras_accuracy(pred_Rs, gt_Rs, pred_ts, gt_ts)

            cams_transformation = np.eye(4, dtype=np.double)
            cams_transformation[:3, :3] = c_opt * R_opt
            cams_transformation[:3, 3] = t_opt

        mesh = plt.get_surface_mesh(sdf=lambda x: model.geometry_network(x)[:, 0],
                                    resolution=kwargs['resolution'])

        # Transform to world coordinates
        if eval_cameras:
            mesh.apply_transform(cams_transformation)
        else:
            mesh.apply_transform(scale_mat)

        # Taking the biggest connected component
        components = mesh.split(only_watertight=False)
        areas = np.array([c.area for c in components], dtype=np.float64)  # np.float is removed in recent NumPy
        mesh_clean = components[areas.argmax()]
        mesh_clean.export('{0}/surface_world_coordinates_{1}.ply'.format(evaldir, epoch), 'ply')

    if eval_rendering:
        images_dir = '{0}/rendering'.format(evaldir)
        utils.mkdir_ifnotexists(images_dir)

        psnrs = []
        for data_index, (indices, model_input, ground_truth) in enumerate(eval_dataloader):
            model_input["intrinsics"] = utils.to_cuda(model_input["intrinsics"])
            model_input["uv"] = utils.to_cuda(model_input["uv"])
            model_input["object_mask"] = utils.to_cuda(model_input["object_mask"])

            if eval_cameras:
                pose_input = pose_vecs(utils.to_cuda(indices))
                model_input['pose'] = pose_input
            else:
                model_input['pose'] = utils.to_cuda(model_input['pose'])

            split = utils.split_input(model_input, total_pixels)
            res = []
            for s in split:
                out = model(s)
                res.append({
                    'rgb_values': out['rgb_values'].detach(),
                })

            batch_size = ground_truth['rgb'].shape[0]
            model_outputs = utils.merge_output(res, total_pixels, batch_size)
            rgb_eval = model_outputs['rgb_values']
            rgb_eval = rgb_eval.reshape(batch_size, total_pixels, 3)
            rgb_eval = (rgb_eval + 1.) / 2.
            rgb_eval = plt.lin2img(rgb_eval, img_res).detach().cpu().numpy()[0]
            rgb_eval = rgb_eval.transpose(1, 2, 0)
            img = Image.fromarray((rgb_eval * 255).astype(np.uint8))
            img.save('{0}/eval_{1}.png'.format(images_dir, '%03d' % indices[0]))

            rgb_gt = ground_truth['rgb']
            rgb_gt = (rgb_gt + 1.) / 2.
            rgb_gt = plt.lin2img(rgb_gt, img_res).numpy()[0]
            rgb_gt = rgb_gt.transpose(1, 2, 0)

            mask = model_input['object_mask']
            mask = plt.lin2img(mask.unsqueeze(-1), img_res).cpu().numpy()[0]
            mask = mask.transpose(1, 2, 0)

            rgb_eval_masked = rgb_eval * mask
            rgb_gt_masked = rgb_gt * mask

            psnr = calculate_psnr(rgb_eval_masked, rgb_gt_masked, mask)
            psnrs.append(psnr)

        psnrs = np.array(psnrs).astype(np.float64)
        print("RENDERING EVALUATION {2}: psnr mean = {0} ; psnr std = {1}".format(
            "%.2f" % psnrs.mean(), "%.2f" % psnrs.std(), scene_id))
def evaluate(**kwargs):
    torch.set_default_dtype(torch.float32)

    conf = ConfigFactory.parse_file(kwargs['conf'])
    exps_folder_name = kwargs['exps_folder_name']
    evals_folder_name = kwargs['evals_folder_name']

    timestamp = '2020'
    checkpoint = '2000'

    expname = conf.get_string('train.expname')
    geometry_id = kwargs['geometry_id']
    appearance_id = kwargs['appearance_id']

    utils.mkdir_ifnotexists(os.path.join('../', evals_folder_name))
    expdir_geometry = os.path.join('../', exps_folder_name, expname + '_{0}'.format(geometry_id))
    expdir_appearance = os.path.join('../', exps_folder_name, expname + '_{0}'.format(appearance_id))
    evaldir = os.path.join('../', evals_folder_name, expname + '_{0}_{1}'.format(geometry_id, appearance_id))
    utils.mkdir_ifnotexists(evaldir)

    model = utils.get_class(conf.get_string('train.model_class'))(conf=conf.get_config('model'))
    if torch.cuda.is_available():
        model.cuda()

    # Load geometry network model
    old_checkpnts_dir = os.path.join(expdir_geometry, timestamp, 'checkpoints')
    saved_model_state = torch.load(os.path.join(old_checkpnts_dir, 'ModelParameters', checkpoint + ".pth"))
    model.load_state_dict(saved_model_state["model_state_dict"])

    # Load rendering network model
    model_fake = utils.get_class(conf.get_string('train.model_class'))(conf=conf.get_config('model'))
    if torch.cuda.is_available():
        model_fake.cuda()

    old_checkpnts_dir = os.path.join(expdir_appearance, timestamp, 'checkpoints')
    saved_model_state = torch.load(os.path.join(old_checkpnts_dir, 'ModelParameters', checkpoint + ".pth"))
    model_fake.load_state_dict(saved_model_state["model_state_dict"])

    model.rendering_network = model_fake.rendering_network

    dataset_conf = conf.get_config('dataset')
    dataset_conf['scene_id'] = geometry_id
    eval_dataset = utils.get_class(conf.get_string('train.dataset_class'))(False, **dataset_conf)
    eval_dataloader = torch.utils.data.DataLoader(eval_dataset,
                                                  batch_size=1,
                                                  shuffle=True,
                                                  collate_fn=eval_dataset.collate_fn)
    total_pixels = eval_dataset.total_pixels
    img_res = eval_dataset.img_res

    ####################################################################################################################
    print("evaluating...")

    model.eval()

    gt_pose = utils.to_cuda(eval_dataset.get_gt_pose(scaled=True))
    gt_quat = rend_util.rot_to_quat(gt_pose[:, :3, :3])
    gt_pose_vec = torch.cat([gt_quat, gt_pose[:, :3, 3]], 1)

    indices_all = [11, 16, 34, 28, 11]
    pose = gt_pose_vec[indices_all, :]
    t_in = np.array([0, 2, 3, 5, 6]).astype(np.float32)

    n_inter = 5
    # np.linspace requires an integer sample count
    t_out = np.linspace(t_in[0], t_in[-1], int(n_inter * t_in[-1])).astype(np.float32)

    scales = np.array([4.2, 4.2, 3.8, 3.8, 4.2]).astype(np.float32)

    s_new = CubicSpline(t_in, scales, bc_type='periodic')
    s_new = s_new(t_out)

    q_new = CubicSpline(t_in, pose[:, :4].detach().cpu().numpy(), bc_type='periodic')
    q_new = q_new(t_out)
    q_new = q_new / np.linalg.norm(q_new, 2, 1)[:, None]
    q_new = utils.to_cuda(torch.from_numpy(q_new)).float()

    images_dir = '{0}/novel_views_rendering'.format(evaldir)
    utils.mkdir_ifnotexists(images_dir)

    indices, model_input, ground_truth = next(iter(eval_dataloader))

    for i, (new_q, scale) in enumerate(zip(q_new, s_new)):
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        new_q = new_q.unsqueeze(0)
        new_t = -rend_util.quat_to_rot(new_q)[:, :, 2] * scale

        new_p = utils.to_cuda(torch.eye(4).float()).unsqueeze(0)
        new_p[:, :3, :3] = rend_util.quat_to_rot(new_q)
        new_p[:, :3, 3] = new_t

        sample = {
            "object_mask": utils.to_cuda(torch.zeros_like(model_input['object_mask'])).bool(),
            "uv": utils.to_cuda(model_input['uv']),
            "intrinsics": utils.to_cuda(model_input['intrinsics']),
            "pose": new_p
        }

        split = utils.split_input(sample, total_pixels)
        res = []
        for s in split:
            out = model(s)
            res.append({
                'rgb_values': out['rgb_values'].detach(),
            })

        batch_size = 1
        model_outputs = utils.merge_output(res, total_pixels, batch_size)
        rgb_eval = model_outputs['rgb_values']
        rgb_eval = rgb_eval.reshape(batch_size, total_pixels, 3)
        rgb_eval = (rgb_eval + 1.) / 2.
        rgb_eval = plt.lin2img(rgb_eval, img_res).detach().cpu().numpy()[0]
        rgb_eval = rgb_eval.transpose(1, 2, 0)
        img = Image.fromarray((rgb_eval * 255).astype(np.uint8))
        img.save('{0}/eval_{1}.png'.format(images_dir, '%03d' % i))
def get_surface_high_res_mesh(sdf, resolution=100):
    # get low res mesh to sample point cloud
    mesh_low_res = get_surface_mesh(sdf, 100)

    recon_pc = trimesh.sample.sample_surface(mesh_low_res, 10000)[0]
    recon_pc = utils.to_cuda(torch.from_numpy(recon_pc).float())

    # Center and align the recon pc
    s_mean = recon_pc.mean(dim=0)
    s_cov = recon_pc - s_mean
    s_cov = torch.mm(s_cov.transpose(0, 1), s_cov)
    vecs = torch.eig(s_cov, True)[1].transpose(0, 1)  # rows are the principal axes
    if torch.det(vecs) < 0:
        # swap two axes so the basis is a proper rotation (det = +1)
        vecs = torch.mm(utils.to_cuda(torch.tensor([[1, 0, 0], [0, 0, 1], [0, 1, 0]])).float(), vecs)
    helper = torch.bmm(vecs.unsqueeze(0).repeat(recon_pc.shape[0], 1, 1),
                       (recon_pc - s_mean).unsqueeze(-1)).squeeze()

    grid_aligned = get_grid(helper.cpu(), resolution)

    grid_points = grid_aligned['grid_points'].cpu()
    s_mean = s_mean.cpu()
    vecs = vecs.cpu()

    # map the axis-aligned grid back to the original frame
    g = []
    for i, pnts in enumerate(torch.split(grid_points, 10000, dim=0)):
        g.append(torch.bmm(vecs.unsqueeze(0).repeat(pnts.shape[0], 1, 1).transpose(1, 2),
                           pnts.unsqueeze(-1)).squeeze() + s_mean)
    grid_points = torch.cat(g, dim=0)

    # MC to new grid
    points = grid_points
    z = []
    for i, pnts in enumerate(torch.split(points, 10000, dim=0)):
        z.append(sdf(utils.to_cuda(pnts)).detach().cpu().numpy())
    z = np.concatenate(z, axis=0)

    meshexport = None
    if not (np.min(z) > 0 or np.max(z) < 0):
        z = z.astype(np.float32)

        verts, faces, normals, values = measure.marching_cubes_lewiner(
            volume=z.reshape(grid_aligned['xyz'][1].shape[0],
                             grid_aligned['xyz'][0].shape[0],
                             grid_aligned['xyz'][2].shape[0]).transpose([1, 0, 2]),
            level=0,
            spacing=(grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1],
                     grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1],
                     grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1]))

        verts = utils.to_cuda(torch.from_numpy(verts)).float()
        vecs = utils.to_cuda(vecs)
        # rotate vertices back to world coordinates and translate by the grid origin
        verts = torch.bmm(vecs.unsqueeze(0).repeat(verts.shape[0], 1, 1).transpose(1, 2),
                          verts.unsqueeze(-1)).squeeze()
        verts = verts.cpu()
        verts = (verts + grid_points[0]).numpy()

        meshexport = trimesh.Trimesh(verts, faces)

    return meshexport
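# --- Illustration (added; not part of the original source). The alignment
# step above rotates the sampled point cloud into its principal axes before
# building the marching-cubes grid. A standalone sketch of that idea using
# torch.linalg.eigh, a modern replacement for the deprecated torch.eig
# (note: here the eigenvector matrix has the axes as columns, not rows):
def _sketch_pca_align(points):
    mean = points.mean(dim=0)
    centered = points - mean
    cov = centered.t() @ centered
    _, vecs = torch.linalg.eigh(cov)   # columns are the principal axes
    aligned = centered @ vecs          # coordinates in the principal frame
    return aligned, vecs, mean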
def plot(model, indices, model_outputs, pose, rgb_gt, path, epoch, img_res, plot_nimgs, max_depth, resolution):
    # arrange data to plot
    batch_size, num_samples, _ = rgb_gt.shape

    network_object_mask = model_outputs['network_object_mask']
    points = model_outputs['points'].reshape(batch_size, num_samples, 3)

    rgb_eval = model_outputs['rgb_values']
    rgb_eval = rgb_eval.reshape(batch_size, num_samples, 3)

    depth = utils.to_cuda(torch.ones(batch_size * num_samples)).float() * max_depth
    depth[network_object_mask] = rend_util.get_depth(points, pose).reshape(-1)[network_object_mask]
    depth = depth.reshape(batch_size, num_samples, 1)
    network_object_mask = network_object_mask.reshape(batch_size, -1)

    cam_loc, cam_dir = rend_util.get_camera_for_plot(pose)

    # plot rendered images
    plot_images(rgb_eval, rgb_gt, path, epoch, plot_nimgs, img_res)

    # plot depth maps
    plot_depth_maps(depth, path, epoch, plot_nimgs, img_res)

    data = []

    # plot surface
    surface_traces = get_surface_trace(path=path,
                                       epoch=epoch,
                                       sdf=lambda x: model.geometry_network(x)[:, 0],
                                       resolution=resolution)
    data.append(surface_traces[0])

    # plot cameras locations
    for i, loc, dir in zip(indices, cam_loc, cam_dir):
        data.append(get_3D_quiver_trace(loc.unsqueeze(0), dir.unsqueeze(0), name='camera_{0}'.format(i)))

    for i, p, m in zip(indices, points, network_object_mask):
        p = p[m]
        sampling_idx = torch.randperm(p.shape[0])[:2048]
        p = p[sampling_idx, :]

        val = model.geometry_network(p)
        caption = ["sdf: {0} ".format(v[0].item()) for v in val]

        data.append(get_3D_scatter_trace(p, name='intersection_points_{0}'.format(i), caption=caption))

    fig = go.Figure(data=data)
    scene_dict = dict(xaxis=dict(range=[-3, 3], autorange=False),
                      yaxis=dict(range=[-3, 3], autorange=False),
                      zaxis=dict(range=[-3, 3], autorange=False),
                      aspectratio=dict(x=1, y=1, z=1))
    fig.update_layout(scene=scene_dict, width=1400, height=1400, showlegend=True)
    filename = '{0}/surface_{1}.html'.format(path, epoch)
    offline.plot(fig, filename=filename, auto_open=False)
def get_eikonal_loss(self, grad_theta):
    if grad_theta.shape[0] == 0:
        return utils.to_cuda(torch.tensor(0.0)).float()

    eikonal_loss = ((grad_theta.norm(2, dim=1) - 1) ** 2).mean()
    return eikonal_loss
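# --- Numeric illustration (added; not part of the original source). The
# eikonal term penalizes deviation of the gradient norm from 1: unit-norm
# gradients give zero loss, all-zero gradients give loss 1.
def _sketch_eikonal_loss():
    unit = F.normalize(torch.randn(16, 3), dim=1)
    assert torch.allclose(((unit.norm(2, dim=1) - 1) ** 2).mean(), torch.tensor(0.0), atol=1e-6)
    zero = torch.zeros(16, 3)
    assert torch.allclose(((zero.norm(2, dim=1) - 1) ** 2).mean(), torch.tensor(1.0))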
def forward(self, input):
    # Parse model input
    intrinsics = input["intrinsics"]
    uv = input["uv"]
    pose = input["pose"]
    object_mask = input["object_mask"].reshape(-1)

    ray_dirs, cam_loc = rend_util.get_camera_params(uv, pose, intrinsics)

    batch_size, num_pixels, _ = ray_dirs.shape

    self.geometry_network.eval()
    with torch.no_grad():
        points, network_object_mask, dists = self.ray_tracer(
            sdf=lambda x: self.geometry_network(x)[:, 0],
            cam_loc=cam_loc,
            object_mask=object_mask,
            ray_directions=ray_dirs)
    self.geometry_network.train()

    points = (cam_loc.unsqueeze(1) + dists.reshape(batch_size, num_pixels, 1) * ray_dirs).reshape(-1, 3)

    sdf_output = self.geometry_network(points)[:, 0:1]
    ray_dirs = ray_dirs.reshape(-1, 3)

    if self.training:
        surface_mask = network_object_mask & object_mask
        surface_points = points[surface_mask]
        surface_dists = dists[surface_mask].unsqueeze(-1)
        surface_ray_dirs = ray_dirs[surface_mask]
        surface_cam_loc = cam_loc.unsqueeze(1).repeat(1, num_pixels, 1).reshape(-1, 3)[surface_mask]
        surface_output = sdf_output[surface_mask]
        N = surface_points.shape[0]

        # Sample points for the eikonal loss
        eik_bounding_box = self.object_bounding_sphere
        n_eik_points = batch_size * num_pixels // 2
        eikonal_points = utils.to_cuda(
            torch.empty(n_eik_points, 3).uniform_(-eik_bounding_box, eik_bounding_box))
        eikonal_pixel_points = points.clone()
        eikonal_pixel_points = eikonal_pixel_points.detach()
        eikonal_points = torch.cat([eikonal_points, eikonal_pixel_points], 0)

        points_all = torch.cat([surface_points, eikonal_points], dim=0)

        output = self.geometry_network(surface_points)
        surface_sdf_values = output[:N, 0:1].detach()

        g = self.geometry_network.gradient(points_all)
        surface_points_grad = g[:N, 0, :].clone().detach()
        grad_theta = g[N:, 0, :]

        differentiable_surface_points = self.sample_network(surface_output,
                                                            surface_sdf_values,
                                                            surface_points_grad,
                                                            surface_dists,
                                                            surface_cam_loc,
                                                            surface_ray_dirs)
    else:
        surface_mask = network_object_mask
        differentiable_surface_points = points[surface_mask]
        grad_theta = None

    view = -ray_dirs[surface_mask]

    rgb_values = utils.to_cuda(torch.ones_like(points).float())
    if differentiable_surface_points.shape[0] > 0:
        rgb_values[surface_mask] = self.get_rbg_value(differentiable_surface_points, view)

    output = {
        'points': points,
        'rgb_values': rgb_values,
        'sdf_output': sdf_output,
        'network_object_mask': network_object_mask,
        'object_mask': object_mask,
        'grad_theta': grad_theta
    }

    return output
def forward(self, sdf, cam_loc, object_mask, ray_directions):
    batch_size, num_pixels, _ = ray_directions.shape

    sphere_intersections, mask_intersect = rend_util.get_sphere_intersection(
        cam_loc, ray_directions, r=self.object_bounding_sphere)

    curr_start_points, unfinished_mask_start, acc_start_dis, acc_end_dis, min_dis, max_dis = \
        self.sphere_tracing(batch_size, num_pixels, sdf, cam_loc, ray_directions,
                            mask_intersect, sphere_intersections)

    network_object_mask = (acc_start_dis < acc_end_dis)

    # The non-convergent rays should be handled by the sampler
    sampler_mask = unfinished_mask_start
    sampler_net_obj_mask = utils.to_cuda(torch.zeros_like(sampler_mask).bool())
    if sampler_mask.sum() > 0:
        sampler_min_max = utils.to_cuda(torch.zeros((batch_size, num_pixels, 2)))
        sampler_min_max.reshape(-1, 2)[sampler_mask, 0] = acc_start_dis[sampler_mask]
        sampler_min_max.reshape(-1, 2)[sampler_mask, 1] = acc_end_dis[sampler_mask]

        sampler_pts, sampler_net_obj_mask, sampler_dists = self.ray_sampler(
            sdf, cam_loc, object_mask, ray_directions, sampler_min_max, sampler_mask)

        curr_start_points[sampler_mask] = sampler_pts[sampler_mask]
        acc_start_dis[sampler_mask] = sampler_dists[sampler_mask]
        network_object_mask[sampler_mask] = sampler_net_obj_mask[sampler_mask]

    if self.verbose:
        print('----------------------------------------------------------------')
        print('RayTracing: object = {0}/{1}, secant on {2}/{3}.'.format(
            network_object_mask.sum(), len(network_object_mask),
            sampler_net_obj_mask.sum(), sampler_mask.sum()))
        print('----------------------------------------------------------------')

    if not self.training:
        return curr_start_points, network_object_mask, acc_start_dis

    ray_directions = ray_directions.reshape(-1, 3)
    mask_intersect = mask_intersect.reshape(-1)

    in_mask = ~network_object_mask & object_mask & ~sampler_mask
    out_mask = ~object_mask & ~sampler_mask

    mask_left_out = (in_mask | out_mask) & ~mask_intersect
    if mask_left_out.sum() > 0:  # for rays that miss the sphere, project the origin onto the ray
        cam_left_out = cam_loc.unsqueeze(1).repeat(1, num_pixels, 1).reshape(-1, 3)[mask_left_out]
        rays_left_out = ray_directions[mask_left_out]
        acc_start_dis[mask_left_out] = -torch.bmm(rays_left_out.view(-1, 1, 3),
                                                  cam_left_out.view(-1, 3, 1)).squeeze()
        curr_start_points[mask_left_out] = cam_left_out + \
            acc_start_dis[mask_left_out].unsqueeze(1) * rays_left_out

    mask = (in_mask | out_mask) & mask_intersect
    if mask.sum() > 0:
        min_dis[network_object_mask & out_mask] = acc_start_dis[network_object_mask & out_mask]

        min_mask_points, min_mask_dist = self.minimal_sdf_points(
            num_pixels, sdf, cam_loc, ray_directions, mask, min_dis, max_dis)

        curr_start_points[mask] = min_mask_points
        acc_start_dis[mask] = min_mask_dist

    return curr_start_points, network_object_mask, acc_start_dis