Example #1
def get_camera_params(uv, pose, intrinsics):
    if pose.shape[1] == 7:  # In case of quaternion vector representation
        cam_loc = pose[:, 4:]
        R = quat_to_rot(pose[:,:4])
        p = utils.to_cuda(torch.eye(4).repeat(pose.shape[0],1,1)).float()
        p[:, :3, :3] = R
        p[:, :3, 3] = cam_loc
    else: # In case of pose matrix representation
        cam_loc = pose[:, :3, 3]
        p = pose

    batch_size, num_samples, _ = uv.shape

    depth = utils.to_cuda(torch.ones((batch_size, num_samples)))
    x_cam = uv[:, :, 0].view(batch_size, -1)
    y_cam = uv[:, :, 1].view(batch_size, -1)
    z_cam = depth.view(batch_size, -1)

    pixel_points_cam = lift(x_cam, y_cam, z_cam, intrinsics=intrinsics)

    # permute for batch matrix product
    pixel_points_cam = pixel_points_cam.permute(0, 2, 1)

    world_coords = torch.bmm(p, pixel_points_cam).permute(0, 2, 1)[:, :, :3]
    ray_dirs = world_coords - cam_loc[:, None, :]
    ray_dirs = F.normalize(ray_dirs, dim=2)

    return ray_dirs, cam_loc
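A minimal usage sketch for get_camera_params (hypothetical values; assumes utils.to_cuda falls back to a no-op on a CPU-only machine):

import torch

batch, n_rays = 1, 4
uv = torch.rand(batch, n_rays, 2) * 100.0           # pixel coordinates
pose = torch.eye(4).unsqueeze(0)                    # 4x4 matrix representation
intrinsics = torch.eye(4).unsqueeze(0)
intrinsics[:, 0, 0] = 500.0                         # fx
intrinsics[:, 1, 1] = 500.0                         # fy
intrinsics[:, 0, 2] = 50.0                          # cx
intrinsics[:, 1, 2] = 50.0                          # cy

ray_dirs, cam_loc = get_camera_params(uv, pose, intrinsics)
assert ray_dirs.shape == (batch, n_rays, 3)         # one unit ray per pixel
assert cam_loc.shape == (batch, 3)                  # camera center from the pose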
Example #2
def interpolate(network, interval, experiment_directory, checkpoint, split_file, epoch, resolution, uniform_grid):

    with open(split_file, "r") as f:
        split = json.load(f)

    # `conf` (a config object) and `plt` (the project's plotting module) are
    # assumed to be module-level names set up by the caller
    ds = utils.get_class(conf.get_string('train.dataset'))(split=split, dataset_path=conf.get_string('train.dataset_path'), with_normals=True)

    points_1, normals_1, index_1 = ds[0]
    points_2, normals_2, index_2 = ds[1]

    pnts = utils.to_cuda(torch.cat([points_1, points_2], dim=0))

    name_1 = str.join('_', ds.get_info(0))
    name_2 = str.join('_', ds.get_info(1))

    name = name_1 + '_and_' + name_2

    utils.mkdir_ifnotexists(os.path.join(experiment_directory, 'interpolate'))
    utils.mkdir_ifnotexists(os.path.join(experiment_directory, 'interpolate', str(checkpoint)))
    utils.mkdir_ifnotexists(os.path.join(experiment_directory, 'interpolate', str(checkpoint), name))

    my_path = os.path.join(experiment_directory, 'interpolate', str(checkpoint), name)

    latent_1 = optimize_latent(utils.to_cuda(points_1), utils.to_cuda(normals_1), conf, 800, network, 5e-3)
    latent_2 = optimize_latent(utils.to_cuda(points_2), utils.to_cuda(normals_2), conf, 800, network, 5e-3)

    pnts = torch.cat([latent_1.repeat(pnts.shape[0], 1), pnts], dim=-1)

    with torch.no_grad():
        network.eval()

        for alpha in np.linspace(0, 1, interval):

            latent = (latent_1 * (1-alpha)) + (latent_2 * alpha)

            plt.plot_surface(with_points=False,
                             points=pnts,
                             decoder=network,
                             latent=latent,
                             path=my_path,
                             epoch=epoch,
                             shapename=str(alpha),
                             resolution=resolution,
                             mc_value=0,
                             is_uniform_grid=uniform_grid,
                             verbose=False,
                             save_html=False,
                             save_ply=True,
                             overwrite=True,
                             connected=True)
Example #3
def lift(x, y, z, intrinsics):
    # parse intrinsics
    intrinsics = utils.to_cuda(intrinsics)
    fx = intrinsics[:, 0, 0]
    fy = intrinsics[:, 1, 1]
    cx = intrinsics[:, 0, 2]
    cy = intrinsics[:, 1, 2]
    sk = intrinsics[:, 0, 1]

    x_lift = (x - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z
    y_lift = (y - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z

    # homogeneous
    return torch.stack((x_lift, y_lift, z, utils.to_cuda(torch.ones_like(z))), dim=-1)
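lift applies the closed-form inverse of a pinhole intrinsics matrix K = [[fx, sk, cx], [0, fy, cy], [0, 0, 1]]: each pixel (x, y) maps to z * K^-1 * (x, y, 1)^T. A sanity-check sketch against torch.inverse, with made-up intrinsics (again assuming utils.to_cuda is a no-op on CPU):

import torch

K = torch.tensor([[[500.0,   2.0, 320.0],
                   [  0.0, 480.0, 240.0],
                   [  0.0,   0.0,   1.0]]])
x = torch.tensor([[100.0]])
y = torch.tensor([[80.0]])
z = torch.tensor([[2.0]])

lifted = lift(x, y, z, K)                           # (1, 1, 4) homogeneous point
expected = torch.inverse(K[0]) @ torch.tensor([100.0, 80.0, 1.0]) * 2.0
assert torch.allclose(lifted[0, 0, :3], expected, atol=1e-4)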
Example #4
    def get_pose_init(self):
        # get noisy initializations obtained with the linear method
        cam_file = '{0}/cameras_linear_init.npz'.format(self.instance_dir)
        camera_dict = np.load(cam_file)
        scale_mats = [
            camera_dict['scale_mat_%d' % idx].astype(np.float32)
            for idx in self.views
        ]
        world_mats = [
            camera_dict['world_mat_%d' % idx].astype(np.float32)
            for idx in self.views
        ]

        init_pose = []
        for scale_mat, world_mat in zip(scale_mats, world_mats):
            P = world_mat @ scale_mat
            P = P[:3, :4]
            _, pose = rend_util.load_K_Rt_from_P(None, P)
            init_pose.append(pose)
        init_pose = utils.to_cuda(
            torch.cat([
                torch.Tensor(pose).float().unsqueeze(0) for pose in init_pose
            ], 0))
        init_quat = rend_util.rot_to_quat(init_pose[:, :3, :3])
        init_quat = torch.cat([init_quat, init_pose[:, :3, 3]], 1)

        return init_quat
Example #5
    def plot_validation_shapes(self, epoch, with_cuts=False):
        # plot network validation shapes
        with torch.no_grad():

            print('plot validation epoch: ', epoch)

            self.network.eval()
            pnts, normals, idx = next(iter(self.eval_dataloader))
            pnts = utils.to_cuda(pnts)

            pnts = self.add_latent(pnts, idx)
            latent = self.lat_vecs[idx[0]]

            shapename = str.join('_', self.ds.get_info(idx))

            plot_surface(with_points=True,
                         points=pnts,
                         decoder=self.network,
                         latent=latent,
                         path=self.plots_dir,
                         epoch=epoch,
                         shapename=shapename,
                         **self.conf.get_config('plot'))

            if with_cuts:
                plot_cuts(points=pnts,
                          decoder=self.network,
                          latent=latent,
                          path=self.plots_dir,
                          epoch=epoch,
                          near_zero=False)
Example #6
    def minimal_sdf_points(self, num_pixels, sdf, cam_loc, ray_directions,
                           mask, min_dis, max_dis):
        ''' Find points with minimal SDF value on rays for P_out pixels '''

        n_mask_points = mask.sum()

        n = self.n_steps
        # steps = torch.linspace(0.0, 1.0,n).cuda()
        steps = utils.to_cuda(torch.empty(n).uniform_(0.0, 1.0))
        mask_max_dis = max_dis[mask].unsqueeze(-1)
        mask_min_dis = min_dis[mask].unsqueeze(-1)
        steps = steps.unsqueeze(0).repeat(
            n_mask_points, 1) * (mask_max_dis - mask_min_dis) + mask_min_dis

        mask_points = cam_loc.unsqueeze(1).repeat(1, num_pixels,
                                                  1).reshape(-1, 3)[mask]
        mask_rays = ray_directions[mask, :]

        mask_points_all = mask_points.unsqueeze(1).repeat(
            1, n,
            1) + steps.unsqueeze(-1) * mask_rays.unsqueeze(1).repeat(1, n, 1)
        points = mask_points_all.reshape(-1, 3)

        mask_sdf_all = []
        for pnts in torch.split(points, 100000, dim=0):
            mask_sdf_all.append(sdf(pnts))

        mask_sdf_all = torch.cat(mask_sdf_all).reshape(-1, n)
        min_vals, min_idx = mask_sdf_all.min(-1)
        min_mask_points = mask_points_all.reshape(
            -1, n, 3)[torch.arange(0, n_mask_points), min_idx]
        min_mask_dist = steps.reshape(-1, n)[torch.arange(0, n_mask_points),
                                             min_idx]

        return min_mask_points, min_mask_dist
Example #7
def get_depth(points, pose):
    ''' Returns depth from 3D points according to camera pose '''
    batch_size, num_samples, _ = points.shape
    if pose.shape[1] == 7:  # In case of quaternion vector representation
        cam_loc = pose[:, 4:]
        R = quat_to_rot(pose[:, :4])
        pose = utils.to_cuda(torch.eye(4).unsqueeze(0).repeat(batch_size, 1, 1)).float()
        pose[:, :3, 3] = cam_loc
        pose[:, :3, :3] = R

    points_hom = torch.cat((points, utils.to_cuda(torch.ones((batch_size, num_samples, 1)))), dim=2)

    # permute for batch matrix product
    points_hom = points_hom.permute(0, 2, 1)

    points_cam = torch.inverse(pose).bmm(points_hom)
    depth = points_cam[:, 2, :][:, :, None]
    return depth
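A quick check of the convention: with an identity camera pose, the returned depth is just the z coordinate of each point (sketch assumes utils.to_cuda is an identity on CPU):

import torch

points = torch.tensor([[[0.0, 0.0, 3.0]]])          # (batch=1, num_samples=1, 3)
pose = torch.eye(4).unsqueeze(0)                    # matrix representation
assert torch.allclose(get_depth(points, pose), torch.tensor([[[3.0]]]))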
Example #8
    def get_rgb_loss(self, rgb_values, rgb_gt, network_object_mask,
                     object_mask):
        if (network_object_mask & object_mask).sum() == 0:
            return utils.to_cuda(torch.tensor(0.0)).float()

        rgb_values = rgb_values[network_object_mask & object_mask]
        rgb_gt = rgb_gt.reshape(-1, 3)[network_object_mask & object_mask]
        rgb_loss = self.l1_loss(rgb_values, rgb_gt) / float(
            object_mask.shape[0])
        return rgb_loss
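The method expects flattened per-pixel tensors and normalizes by the total pixel count, not only the masked pixels. A sketch with a minimal stub object (the _Stub class and all sizes are hypothetical; it treats the method above as a plain function and uses the sum-reduced L1 loss this loss module is built with):

import torch

class _Stub:
    l1_loss = torch.nn.L1Loss(reduction='sum')
    get_rgb_loss = get_rgb_loss                     # reuse the method above

rgb_values = torch.zeros(4, 3)                      # predictions for 4 pixels
rgb_gt = torch.ones(1, 4, 3)                        # ground truth: (batch, pixels, 3)
m = torch.tensor([True, True, False, False])
loss = _Stub().get_rgb_loss(rgb_values, rgb_gt, m, m)
assert torch.isclose(loss, torch.tensor(1.5))       # 2 pixels * 3 channels * 1 / 4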
Example #9
    def get_mask_loss(self, sdf_output, network_object_mask, object_mask):
        mask = ~(network_object_mask & object_mask)
        if mask.sum() == 0:
            return utils.to_cuda(torch.tensor(0.0)).float()
        sdf_pred = -self.alpha * sdf_output[mask]
        gt = object_mask[mask].float()
        mask_loss = (1 / self.alpha) * F.binary_cross_entropy_with_logits(
            sdf_pred.squeeze(), gt, reduction='sum') / float(
                object_mask.shape[0])
        return mask_loss
Example #10
    def add_latent(self, points, indices):
        batch_size, num_of_points, dim = points.shape
        points = points.reshape(batch_size * num_of_points, dim)
        latent_inputs = utils.to_cuda(torch.zeros(0))

        for ind in indices.numpy():
            latent_ind = self.lat_vecs[ind]
            latent_repeat = latent_ind.expand(num_of_points, -1)
            latent_inputs = torch.cat([latent_inputs, latent_repeat], 0)
        points = torch.cat([latent_inputs, points], 1)
        return points
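add_latent prepends each shape's latent code to every point's xyz along the feature dimension. The sketch below replays that layout with hypothetical sizes (5 shapes, latent size 2, 10 points):

import torch

lat_vecs = torch.randn(5, 2)                        # one latent code per shape
points = torch.randn(1, 10, 3).reshape(10, 3)
latent = lat_vecs[3].expand(10, -1)                 # broadcast shape 3's code
out = torch.cat([latent, points], dim=1)
assert out.shape == (10, 5)                         # latent (2) + xyz (3)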
Example #11
def get_grid_uniform(resolution):
    x = np.linspace(-1.2, 1.2, resolution)
    y = x
    z = x

    xx, yy, zz = np.meshgrid(x, y, z)
    grid_points = utils.to_cuda(torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T, dtype=torch.float))

    return {"grid_points": grid_points,
            "shortest_axis_length": 2.4,
            "xyz": [x, y, z],
            "shortest_axis_index": 0}
Example #12
def get_grid(points, resolution):
    eps = 0.2
    input_min = torch.min(points, dim=0)[0].squeeze().numpy()
    input_max = torch.max(points, dim=0)[0].squeeze().numpy()

    bounding_box = input_max - input_min
    shortest_axis = np.argmin(bounding_box)
    if (shortest_axis == 0):
        x = np.linspace(input_min[shortest_axis] - eps,
                        input_max[shortest_axis] + eps, resolution)
        length = np.max(x) - np.min(x)
        y = np.arange(input_min[1] - eps,
                      input_max[1] + length / (x.shape[0] - 1) + eps,
                      length / (x.shape[0] - 1))
        z = np.arange(input_min[2] - eps,
                      input_max[2] + length / (x.shape[0] - 1) + eps,
                      length / (x.shape[0] - 1))
    elif (shortest_axis == 1):
        y = np.linspace(input_min[shortest_axis] - eps,
                        input_max[shortest_axis] + eps, resolution)
        length = np.max(y) - np.min(y)
        x = np.arange(input_min[0] - eps,
                      input_max[0] + length / (y.shape[0] - 1) + eps,
                      length / (y.shape[0] - 1))
        z = np.arange(input_min[2] - eps,
                      input_max[2] + length / (y.shape[0] - 1) + eps,
                      length / (y.shape[0] - 1))
    elif (shortest_axis == 2):
        z = np.linspace(input_min[shortest_axis] - eps,
                        input_max[shortest_axis] + eps, resolution)
        length = np.max(z) - np.min(z)
        x = np.arange(input_min[0] - eps,
                      input_max[0] + length / (z.shape[0] - 1) + eps,
                      length / (z.shape[0] - 1))
        y = np.arange(input_min[1] - eps,
                      input_max[1] + length / (z.shape[0] - 1) + eps,
                      length / (z.shape[0] - 1))

    xx, yy, zz = np.meshgrid(x, y, z)
    grid_points = utils.to_cuda(
        torch.tensor(np.vstack([xx.ravel(), yy.ravel(),
                                zz.ravel()]).T,
                     dtype=torch.float))
    return {
        "grid_points": grid_points,
        "shortest_axis_length": length,
        "xyz": [x, y, z],
        "shortest_axis_index": shortest_axis
    }
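The adaptive variant keeps near-isotropic spacing: the shortest bounding-box axis gets exactly resolution samples, and its step is reused for the other two axes. A small sketch (utils.to_cuda assumed available) with a synthetic cloud whose x extent is almost surely the shortest:

import torch

pts = torch.rand(100, 3) * torch.tensor([1.0, 2.0, 3.0])
grid = get_grid(pts, resolution=16)
assert grid["shortest_axis_index"] == 0
assert grid["grid_points"].shape[1] == 3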
Example #13
def plot_images(rgb_points, ground_truth, path, epoch, plot_nrow, img_res):
    ground_truth = (utils.to_cuda(ground_truth) + 1.) / 2.
    rgb_points = (rgb_points + 1.) / 2.

    output_vs_gt = torch.cat((rgb_points, ground_truth), dim=0)
    output_vs_gt_plot = lin2img(output_vs_gt, img_res)

    tensor = torchvision.utils.make_grid(
        output_vs_gt_plot, scale_each=False, normalize=False,
        nrow=plot_nrow).cpu().detach().numpy()

    tensor = tensor.transpose(1, 2, 0)
    scale_factor = 255
    tensor = (tensor * scale_factor).astype(np.uint8)

    img = Image.fromarray(tensor)
    img.save('{0}/rendering_{1}.png'.format(path, epoch))
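Inputs are expected in [-1, 1] and flattened to (batch, pixels, 3); lin2img (from the same plotting module) reshapes them back into images. A hypothetical smoke test on CPU, writing a grid of two 2x2 "images":

import torch

rgb = torch.rand(2, 4, 3) * 2.0 - 1.0
plot_images(rgb, rgb.clone(), path='.', epoch=0, plot_nrow=2, img_res=(2, 2))
# writes ./rendering_0.png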
Example #14
def quat_to_rot(q):
    batch_size, _ = q.shape
    q = F.normalize(q, dim=1)
    R = utils.to_cuda(torch.ones((batch_size, 3, 3)))
    qr = q[:, 0]
    qi = q[:, 1]
    qj = q[:, 2]
    qk = q[:, 3]
    R[:, 0, 0] = 1 - 2 * (qj**2 + qk**2)
    R[:, 0, 1] = 2 * (qi * qj - qk * qr)
    R[:, 0, 2] = 2 * (qi * qk + qj * qr)
    R[:, 1, 0] = 2 * (qi * qj + qk * qr)
    R[:, 1, 1] = 1 - 2 * (qi**2 + qk**2)
    R[:, 1, 2] = 2 * (qj * qk - qi * qr)
    R[:, 2, 0] = 2 * (qi * qk - qj * qr)
    R[:, 2, 1] = 2 * (qj * qk + qi * qr)
    R[:, 2, 2] = 1 - 2 * (qi**2 + qj**2)
    return R
Example #15
def rot_to_quat(R):
    batch_size, _, _ = R.shape
    q = utils.to_cuda(torch.ones((batch_size, 4)))

    R00 = R[:, 0, 0]
    R01 = R[:, 0, 1]
    R02 = R[:, 0, 2]
    R10 = R[:, 1, 0]
    R11 = R[:, 1, 1]
    R12 = R[:, 1, 2]
    R20 = R[:, 2, 0]
    R21 = R[:, 2, 1]
    R22 = R[:, 2, 2]

    q[:, 0] = torch.sqrt(1.0 + R00 + R11 + R22) / 2
    q[:, 1] = (R21 - R12) / (4 * q[:, 0])
    q[:, 2] = (R02 - R20) / (4 * q[:, 0])
    q[:, 3] = (R10 - R01) / (4 * q[:, 0])
    return q
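quat_to_rot and rot_to_quat are mutual inverses whenever 1 + trace(R) > 0 (i.e. the real part qr is nonzero; the sqrt branch recovers qr > 0). A round-trip sketch, assuming utils.to_cuda is a no-op on CPU:

import torch
import torch.nn.functional as F

q = F.normalize(torch.tensor([[0.9, 0.1, 0.2, 0.1]]), dim=1)
q_back = rot_to_quat(quat_to_rot(q))
assert torch.allclose(q, q_back, atol=1e-5)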
Example #16
    def forward(self, model_outputs, ground_truth):
        rgb_gt = utils.to_cuda(ground_truth['rgb'])
        network_object_mask = model_outputs['network_object_mask']
        object_mask = model_outputs['object_mask']

        rgb_loss = self.get_rgb_loss(model_outputs['rgb_values'], rgb_gt,
                                     network_object_mask, object_mask)
        mask_loss = self.get_mask_loss(model_outputs['sdf_output'],
                                       network_object_mask, object_mask)
        eikonal_loss = self.get_eikonal_loss(model_outputs['grad_theta'])

        loss = rgb_loss + \
               self.eikonal_weight * eikonal_loss + \
               self.mask_weight * mask_loss

        return {
            'loss': loss,
            'rgb_loss': rgb_loss,
            'eikonal_loss': eikonal_loss,
            'mask_loss': mask_loss,
        }
Example #17
def get_sphere_intersection(cam_loc, ray_directions, r = 1.0):
    # Input: cam_loc: n_images x 3 ; ray_directions: n_images x n_rays x 3
    # Output: n_images x n_rays x 2 (near and far distances) ; n_images x n_rays (intersection mask)

    n_imgs, n_pix, _ = ray_directions.shape

    cam_loc = cam_loc.unsqueeze(-1)
    ray_cam_dot = torch.bmm(ray_directions, cam_loc).squeeze()
    under_sqrt = ray_cam_dot ** 2 - (cam_loc.norm(2,1) ** 2 - r ** 2)

    under_sqrt = under_sqrt.reshape(-1)
    mask_intersect = under_sqrt > 0

    sphere_intersections = utils.to_cuda(torch.zeros(n_imgs * n_pix, 2)).float()
    sphere_intersections[mask_intersect] = torch.sqrt(under_sqrt[mask_intersect]).unsqueeze(-1) * utils.to_cuda(torch.Tensor([-1, 1])).float()
    sphere_intersections[mask_intersect] -= ray_cam_dot.reshape(-1)[mask_intersect].unsqueeze(-1)

    sphere_intersections = sphere_intersections.reshape(n_imgs, n_pix, 2)
    sphere_intersections = sphere_intersections.clamp_min(0.0)
    mask_intersect = mask_intersect.reshape(n_imgs, n_pix)

    return sphere_intersections, mask_intersect
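With a unit ray direction d and camera origin o, points o + t*d on a sphere of radius r satisfy t^2 + 2*t*(d.o) + (|o|^2 - r^2) = 0, so t = -(d.o) +/- sqrt((d.o)^2 - (|o|^2 - r^2)); ray_cam_dot and under_sqrt above are exactly those terms. A sketch with a camera 2 units from the origin looking at the unit sphere (CPU fallback for utils.to_cuda assumed):

import torch

cam_loc = torch.tensor([[0.0, 0.0, 2.0]])
ray_directions = torch.tensor([[[0.0, 0.0, -1.0]]])
dists, hit = get_sphere_intersection(cam_loc, ray_directions, r=1.0)
assert hit.all()
assert torch.allclose(dists, torch.tensor([[[1.0, 3.0]]]))   # near/far hits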
Example #18
def plot_cuts_axis(points, decoder, latent, path, epoch, near_zero, axis, file_name_sep='/'):
    onedim_cut = np.linspace(-1.0, 1.0, 200)
    xx, yy = np.meshgrid(onedim_cut, onedim_cut)
    xx = xx.ravel()
    yy = yy.ravel()
    min_axis = points[:, axis].min(dim=0)[0].item()
    max_axis = points[:, axis].max(dim=0)[0].item()
    mask = np.zeros(3)
    mask[axis] = 1.0
    if axis == 0:
        position_cut = np.vstack(([np.zeros(xx.shape[0]), xx, yy]))
    elif axis == 1:
        position_cut = np.vstack(([xx, np.zeros(xx.shape[0]), yy]))
    elif axis == 2:
        position_cut = np.vstack(([xx, yy, np.zeros(xx.shape[0])]))
    position_cut = [position_cut + i * mask.reshape(-1, 1) for i in np.linspace(min_axis - 0.1, max_axis + 0.1, 50)]
    for index, pos in enumerate(position_cut):
        # fig = tools.make_subplots(rows=1, cols=1)

        field_input = utils.to_cuda(torch.tensor(pos.T, dtype=torch.float))
        z = []
        for i, pnts in enumerate(torch.split(field_input, 10000, dim=0)):
            if latent is not None:
                pnts = torch.cat([latent.expand(pnts.shape[0], -1), pnts], dim=1)
            z.append(decoder(pnts).detach().cpu().numpy())
        z = np.concatenate(z, axis=0)

        if near_zero:
            if np.min(z) < -1.0e-5:
                start = -0.1
            else:
                start = 0.0
            trace1 = go.Contour(x=onedim_cut,
                                y=onedim_cut,
                                z=z.reshape(onedim_cut.shape[0], onedim_cut.shape[0]),
                                name='axis {0} = {1}'.format(axis, pos[axis, 0]),  # colorbar=dict(len=0.4, y=0.8),
                                autocontour=False,
                                contours=dict(
                                    start=start,
                                    end=0.1,
                                    size=0.01
                                )
                                # ),colorbar = {'dtick': 0.05}
                                )
        else:
            trace1 = go.Contour(x=onedim_cut,
                                y=onedim_cut,
                                z=z.reshape(onedim_cut.shape[0], onedim_cut.shape[0]),
                                name='axis {0} = {1}'.format(axis, pos[axis, 0]),  # colorbar=dict(len=0.4, y=0.8),
                                autocontour=True,
                                ncontours=70
                                # contours=dict(
                                #      start=-0.001,
                                #      end=0.001,
                                #      size=0.00001
                                #      )
                                # ),colorbar = {'dtick': 0.05}
                                )

        layout = go.Layout(width=1200, height=1200, scene=dict(xaxis=dict(range=[-1, 1], autorange=False),
                                                               yaxis=dict(range=[-1, 1], autorange=False),
                                                               aspectratio=dict(x=1, y=1)),
                           title=dict(text='axis {0} = {1}'.format(axis, pos[axis, 0])))
        # fig['layout']['xaxis2'].update(range=[-1, 1])
        # fig['layout']['yaxis2'].update(range=[-1, 1], scaleanchor="x2", scaleratio=1)

        filename = '{0}{1}cutsaxis_{2}_{3}_{4}.html'.format(path, file_name_sep, axis, epoch, index)
        fig1 = go.Figure(data=[trace1], layout=layout)
        offline.plot(fig1, filename=filename, auto_open=False)
Example #19
    def run(self):

        print("running")

        for epoch in range(self.startepoch, self.nepochs + 1):

            if epoch % self.conf.get_int('train.checkpoint_frequency') == 0:
                self.save_checkpoints(epoch)
                self.plot_validation_shapes(epoch)

            # change back to train mode
            self.network.train()
            self.adjust_learning_rate(epoch)

            # start epoch
            before_epoch = time()
            for data_index, (mnfld_pnts, normals,
                             indices) in enumerate(self.train_dataloader):

                mnfld_pnts = utils.to_cuda(mnfld_pnts)

                if self.with_normals:
                    normals = utils.to_cuda(normals)

                nonmnfld_pnts = self.sampler.get_points(mnfld_pnts)

                mnfld_pnts = self.add_latent(mnfld_pnts, indices)
                nonmnfld_pnts = self.add_latent(nonmnfld_pnts, indices)

                # forward pass

                mnfld_pnts.requires_grad_()
                nonmnfld_pnts.requires_grad_()

                mnfld_pred = self.network(mnfld_pnts)
                nonmnfld_pred = self.network(nonmnfld_pnts)

                mnfld_grad = gradient(mnfld_pnts, mnfld_pred)
                nonmnfld_grad = gradient(nonmnfld_pnts, nonmnfld_pred)

                # manifold loss

                mnfld_loss = (mnfld_pred.abs()).mean()

                # eikonal loss

                grad_loss = ((nonmnfld_grad.norm(2, dim=-1) - 1)**2).mean()

                loss = mnfld_loss + self.grad_lambda * grad_loss

                # normals loss
                if self.with_normals:
                    normals = normals.view(-1, 3)
                    normals_loss = ((mnfld_grad - normals).abs()).norm(
                        2, dim=1).mean()
                    loss = loss + self.normals_lambda * normals_loss
                else:
                    normals_loss = torch.zeros(1)

                # latent loss

                latent_loss = self.latent_size_reg(utils.to_cuda(indices))

                loss = loss + self.latent_lambda * latent_loss

                # back propagation

                self.optimizer.zero_grad()

                loss.backward()

                self.optimizer.step()

                # print status
                if data_index % self.conf.get_int(
                        'train.status_frequency') == 0:
                    print(
                        'Train Epoch: {} [{}/{} ({:.0f}%)]\tTrain Loss: {:.6f}\tManifold loss: {:.6f}'
                        '\tGrad loss: {:.6f}\tLatent loss: {:.6f}\tNormals Loss: {:.6f}'
                        .format(epoch, data_index * self.batch_size,
                                len(self.ds),
                                100. * data_index / len(self.train_dataloader),
                                loss.item(), mnfld_loss.item(),
                                grad_loss.item(), latent_loss.item(),
                                normals_loss.item()))

            after_epoch = time()
            print('epoch time {0}'.format(str(after_epoch - before_epoch)))
Example #20
def plot_cuts(points, decoder, path, epoch, near_zero, latent=None):
    onedim_cut = np.linspace(-1, 1, 200)
    xx, yy = np.meshgrid(onedim_cut, onedim_cut)
    xx = xx.ravel()
    yy = yy.ravel()
    min_y = points[:, -2].min(dim=0)[0].item()
    max_y = points[:, -2].max(dim=0)[0].item()
    position_cut = np.vstack(([xx, np.zeros(xx.shape[0]), yy]))
    position_cut = [
        position_cut + np.array([0., i, 0.]).reshape(-1, 1)
        for i in np.linspace(min_y - 0.1, max_y + 0.1, 10)
    ]
    for index, pos in enumerate(position_cut):
        #fig = tools.make_subplots(rows=1, cols=1)

        field_input = utils.to_cuda(torch.tensor(pos.T, dtype=torch.float))
        z = []
        # split along the point dimension (dim=0), not the coordinate dimension
        for i, pnts in enumerate(torch.split(field_input, 1000, dim=0)):
            input_ = pnts
            if latent is not None:
                input_ = torch.cat([latent.expand(pnts.shape[0], -1), pnts],
                                   dim=1)
            z.append(decoder(input_).detach().cpu().numpy())
        z = np.concatenate(z, axis=0)

        if near_zero:
            trace1 = go.Contour(
                x=onedim_cut,
                y=onedim_cut,
                z=z.reshape(onedim_cut.shape[0], onedim_cut.shape[0]),
                name='y = {0}'.format(
                    pos[1, 0]),  # colorbar=dict(len=0.4, y=0.8),
                autocontour=False,
                contours=dict(start=-0.001, end=0.001, size=0.00001)
                # ),colorbar = {'dtick': 0.05}
            )
        else:
            trace1 = go.Contour(
                x=onedim_cut,
                y=onedim_cut,
                z=z.reshape(onedim_cut.shape[0], onedim_cut.shape[0]),
                name='y = {0}'.format(
                    pos[1, 0]),  # colorbar=dict(len=0.4, y=0.8),
                autocontour=True,
                # contours=dict(
                #      start=-0.001,
                #      end=0.001,
                #      size=0.00001
                #      )
                # ),colorbar = {'dtick': 0.05}
            )

        layout = go.Layout(width=1200,
                           height=1200,
                           scene=dict(xaxis=dict(range=[-1, 1],
                                                 autorange=False),
                                      yaxis=dict(range=[-1, 1],
                                                 autorange=False),
                                      aspectratio=dict(x=1, y=1)),
                           title=dict(text='y = {0}'.format(pos[1, 0])))
        # fig['layout']['xaxis2'].update(range=[-1, 1])
        # fig['layout']['yaxis2'].update(range=[-1, 1], scaleanchor="x2", scaleratio=1)

        filename = '{0}/cuts{1}_{2}.html'.format(path, epoch, index)
        fig1 = go.Figure(data=[trace1], layout=layout)
        offline.plot(fig1, filename=filename, auto_open=False)
Example #21
    def __init__(self, **kwargs):

        # config setting

        self.home_dir = os.path.abspath(os.pardir)

        if type(kwargs['conf']) == str:
            self.conf_filename = os.path.abspath(kwargs['conf'])
            self.conf = ConfigFactory.parse_file(self.conf_filename)
        else:
            self.conf = kwargs['conf']

        self.expname = kwargs['expname']

        # GPU settings

        self.GPU_INDEX = kwargs['gpu_index']

        if self.GPU_INDEX != 'ignore':
            os.environ["CUDA_VISIBLE_DEVICES"] = '{0}'.format(self.GPU_INDEX)

        self.num_of_gpus = torch.cuda.device_count()

        # settings for loading an existing experiment

        if kwargs['is_continue'] and kwargs['timestamp'] == 'latest':
            if os.path.exists(os.path.join(self.home_dir, 'exps',
                                           self.expname)):
                timestamps = os.listdir(
                    os.path.join(self.home_dir, 'exps', self.expname))
                if len(timestamps) == 0:
                    is_continue = False
                    timestamp = None
                else:
                    timestamp = sorted(timestamps)[-1]
                    is_continue = True
            else:
                is_continue = False
                timestamp = None
        else:
            timestamp = kwargs['timestamp']
            is_continue = kwargs['is_continue']

        self.exps_folder_name = 'exps'

        utils.mkdir_ifnotexists(
            utils.concat_home_dir(
                os.path.join(self.home_dir, self.exps_folder_name)))

        self.expdir = utils.concat_home_dir(
            os.path.join(self.home_dir, self.exps_folder_name, self.expname))
        utils.mkdir_ifnotexists(self.expdir)

        if is_continue:
            self.timestamp = timestamp
        else:
            self.timestamp = '{:%Y_%m_%d_%H_%M_%S}'.format(datetime.now())

        self.cur_exp_dir = self.timestamp
        utils.mkdir_ifnotexists(os.path.join(self.expdir, self.cur_exp_dir))

        self.plots_dir = os.path.join(self.expdir, self.cur_exp_dir, 'plots')
        utils.mkdir_ifnotexists(self.plots_dir)

        self.checkpoints_path = os.path.join(self.expdir, self.cur_exp_dir,
                                             'checkpoints')
        utils.mkdir_ifnotexists(self.checkpoints_path)

        self.model_params_subdir = "ModelParameters"
        self.optimizer_params_subdir = "OptimizerParameters"
        self.latent_codes_subdir = "LatentCodes"

        utils.mkdir_ifnotexists(
            os.path.join(self.checkpoints_path, self.model_params_subdir))
        utils.mkdir_ifnotexists(
            os.path.join(self.checkpoints_path, self.optimizer_params_subdir))
        utils.mkdir_ifnotexists(
            os.path.join(self.checkpoints_path, self.latent_codes_subdir))

        self.nepochs = kwargs['nepochs']

        self.batch_size = kwargs['batch_size']

        if self.num_of_gpus > 0:
            self.batch_size *= self.num_of_gpus

        self.parallel = self.num_of_gpus > 1

        self.global_sigma = self.conf.get_float(
            'network.sampler.properties.global_sigma')
        self.local_sigma = self.conf.get_float(
            'network.sampler.properties.local_sigma')
        self.sampler = Sampler.get_sampler(
            self.conf.get_string('network.sampler.sampler_type'))(
                self.global_sigma, self.local_sigma)

        train_split_file = os.path.abspath(kwargs['split_file'])
        print(f'Loading split file {train_split_file}')
        with open(train_split_file, "r") as f:
            train_split = json.load(f)
        print(f'Size of the split: {len(train_split)} samples')

        self.d_in = self.conf.get_int('train.d_in')

        # latent preprocessing

        self.latent_size = self.conf.get_int('train.latent_size')

        self.latent_lambda = self.conf.get_float('network.loss.latent_lambda')
        self.grad_lambda = self.conf.get_float('network.loss.lambda')
        self.normals_lambda = self.conf.get_float(
            'network.loss.normals_lambda')

        self.with_normals = self.normals_lambda > 0

        self.ds = utils.get_class(self.conf.get_string('train.dataset'))(
            split=train_split,
            with_normals=self.with_normals,
            dataset_path=self.conf.get_string('train.dataset_path'),
            points_batch=kwargs['points_batch'],
        )

        self.num_scenes = len(self.ds)

        self.train_dataloader = torch.utils.data.DataLoader(
            self.ds,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=kwargs['threads'],
            drop_last=True,
            pin_memory=True)
        self.eval_dataloader = torch.utils.data.DataLoader(self.ds,
                                                           batch_size=1,
                                                           shuffle=False,
                                                           num_workers=0,
                                                           drop_last=True)

        self.network = utils.get_class(
            self.conf.get_string('train.network_class'))(
                d_in=(self.d_in + self.latent_size),
                **self.conf.get_config('network.inputs'))

        if self.parallel:
            self.network = torch.nn.DataParallel(self.network)

        if torch.cuda.is_available():
            self.network.cuda()

        self.lr_schedules = self.get_learning_rate_schedules(
            self.conf.get_list('train.learning_rate_schedule'))
        self.weight_decay = self.conf.get_float('train.weight_decay')

        # optimizer and latent settings

        self.startepoch = 0

        self.lat_vecs = utils.to_cuda(
            torch.zeros(self.num_scenes, self.latent_size))
        self.lat_vecs.requires_grad_()

        self.optimizer = torch.optim.Adam([
            {
                "params": self.network.parameters(),
                "lr": self.lr_schedules[0].get_learning_rate(0),
                "weight_decay": self.weight_decay
            },
            {
                "params": self.lat_vecs,
                "lr": self.lr_schedules[1].get_learning_rate(0)
            },
        ])

        # if continue load checkpoints

        if is_continue:
            old_checkpnts_dir = os.path.join(self.expdir, timestamp,
                                             'checkpoints')

            data = torch.load(
                os.path.join(old_checkpnts_dir, self.latent_codes_subdir,
                             str(kwargs['checkpoint']) + '.pth'))

            self.lat_vecs = utils.to_cuda(data["latent_codes"])

            saved_model_state = torch.load(
                os.path.join(old_checkpnts_dir, 'ModelParameters',
                             str(kwargs['checkpoint']) + ".pth"))
            self.network.load_state_dict(saved_model_state["model_state_dict"])

            data = torch.load(
                os.path.join(old_checkpnts_dir, 'OptimizerParameters',
                             str(kwargs['checkpoint']) + ".pth"))
            self.optimizer.load_state_dict(data["optimizer_state_dict"])
            self.startepoch = saved_model_state['epoch']
Example #22
    def __init__(self,**kwargs):
        torch.set_default_dtype(torch.float32)
        torch.set_num_threads(1)

        self.conf = ConfigFactory.parse_file(kwargs['conf'])
        self.batch_size = kwargs['batch_size']
        self.nepochs = kwargs['nepochs']
        self.exps_folder_name = kwargs['exps_folder_name']
        self.GPU_INDEX = kwargs['gpu_index']
        self.train_cameras = kwargs['train_cameras']

        self.expname = self.conf.get_string('train.expname') + kwargs['expname']
        scene_id = kwargs['scene_id'] if kwargs['scene_id'] else self.conf.get_string('dataset.scene_id', default=None)
        if scene_id:
            self.expname = self.expname + '_{0}'.format(scene_id)

        if kwargs['is_continue'] and kwargs['timestamp'] == 'latest':
            if os.path.exists(os.path.join('../',kwargs['exps_folder_name'],self.expname)):
                timestamps = os.listdir(os.path.join('../',kwargs['exps_folder_name'],self.expname))
                if len(timestamps) == 0:
                    is_continue = False
                    timestamp = None
                else:
                    timestamp = sorted(timestamps)[-1]
                    is_continue = True
            else:
                is_continue = False
                timestamp = None
        else:
            timestamp = kwargs['timestamp']
            is_continue = kwargs['is_continue']

        utils.mkdir_ifnotexists(os.path.join('../',self.exps_folder_name))
        self.expdir = os.path.join('../', self.exps_folder_name, self.expname)
        utils.mkdir_ifnotexists(self.expdir)
        self.timestamp = '{:%Y_%m_%d_%H_%M_%S}'.format(datetime.now())
        utils.mkdir_ifnotexists(os.path.join(self.expdir, self.timestamp))

        self.plots_dir = os.path.join(self.expdir, self.timestamp, 'plots')
        utils.mkdir_ifnotexists(self.plots_dir)

        # create checkpoints dirs
        self.checkpoints_path = os.path.join(self.expdir, self.timestamp, 'checkpoints')
        utils.mkdir_ifnotexists(self.checkpoints_path)
        self.model_params_subdir = "ModelParameters"
        self.optimizer_params_subdir = "OptimizerParameters"
        self.scheduler_params_subdir = "SchedulerParameters"

        utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.model_params_subdir))
        utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.optimizer_params_subdir))
        utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.scheduler_params_subdir))

        if self.train_cameras:
            self.optimizer_cam_params_subdir = "OptimizerCamParameters"
            self.cam_params_subdir = "CamParameters"

            utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.optimizer_cam_params_subdir))
            utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.cam_params_subdir))

        os.system("""cp -r {0} "{1}" """.format(kwargs['conf'], os.path.join(self.expdir, self.timestamp, 'runconf.conf')))

        if self.GPU_INDEX != 'ignore':
            os.environ["CUDA_VISIBLE_DEVICES"] = '{0}'.format(self.GPU_INDEX)

        print('shell command : {0}'.format(' '.join(sys.argv)))

        print('Loading data ...')

        dataset_conf = self.conf.get_config('dataset')
        if kwargs['scene_id']:
            dataset_conf['scene_id'] = kwargs['scene_id']

        self.train_dataset = utils.get_class(self.conf.get_string('train.dataset_class'))(self.train_cameras,
                                                                                          **dataset_conf)

        print('Finished loading data.')

        self.train_dataloader = torch.utils.data.DataLoader(self.train_dataset,
                                                            batch_size=self.batch_size,
                                                            shuffle=True,
                                                            collate_fn=self.train_dataset.collate_fn
                                                            )
        self.plot_dataloader = torch.utils.data.DataLoader(self.train_dataset,
                                                           batch_size=self.conf.get_int('plot.plot_nimgs'),
                                                           shuffle=True,
                                                           collate_fn=self.train_dataset.collate_fn
                                                           )

        self.model = utils.get_class(self.conf.get_string('train.model_class'))(conf=self.conf.get_config('model'))
        if torch.cuda.is_available():
            self.model.cuda()

        self.loss = utils.get_class(self.conf.get_string('train.loss_class'))(**self.conf.get_config('loss'))

        self.lr = self.conf.get_float('train.learning_rate')
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        self.sched_milestones = self.conf.get_list('train.sched_milestones', default=[])
        self.sched_factor = self.conf.get_float('train.sched_factor', default=0.0)
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, self.sched_milestones, gamma=self.sched_factor)

        # settings for camera optimization
        if self.train_cameras:
            num_images = len(self.train_dataset)
            self.pose_vecs = utils.to_cuda(torch.nn.Embedding(num_images, 7, sparse=True))
            self.pose_vecs.weight.data.copy_(self.train_dataset.get_pose_init())

            self.optimizer_cam = torch.optim.SparseAdam(self.pose_vecs.parameters(), self.conf.get_float('train.learning_rate_cam'))

        self.start_epoch = 0
        if is_continue:
            old_checkpnts_dir = os.path.join(self.expdir, timestamp, 'checkpoints')

            saved_model_state = torch.load(
                os.path.join(old_checkpnts_dir, 'ModelParameters', str(kwargs['checkpoint']) + ".pth"))
            self.model.load_state_dict(saved_model_state["model_state_dict"])
            self.start_epoch = saved_model_state['epoch']

            data = torch.load(
                os.path.join(old_checkpnts_dir, 'OptimizerParameters', str(kwargs['checkpoint']) + ".pth"))
            self.optimizer.load_state_dict(data["optimizer_state_dict"])

            data = torch.load(
                os.path.join(old_checkpnts_dir, self.scheduler_params_subdir, str(kwargs['checkpoint']) + ".pth"))
            self.scheduler.load_state_dict(data["scheduler_state_dict"])

            if self.train_cameras:
                data = torch.load(
                    os.path.join(old_checkpnts_dir, self.optimizer_cam_params_subdir, str(kwargs['checkpoint']) + ".pth"))
                self.optimizer_cam.load_state_dict(data["optimizer_cam_state_dict"])

                data = torch.load(
                    os.path.join(old_checkpnts_dir, self.cam_params_subdir, str(kwargs['checkpoint']) + ".pth"))
                self.pose_vecs.load_state_dict(data["pose_vecs_state_dict"])

        self.num_pixels = self.conf.get_int('train.num_pixels')
        self.total_pixels = self.train_dataset.total_pixels
        self.img_res = self.train_dataset.img_res
        self.n_batches = len(self.train_dataloader)
        self.plot_freq = self.conf.get_int('train.plot_freq')
        self.plot_conf = self.conf.get_config('plot')

        self.alpha_milestones = self.conf.get_list('train.alpha_milestones', default=[])
        self.alpha_factor = self.conf.get_float('train.alpha_factor', default=0.0)
        for acc in self.alpha_milestones:
            if self.start_epoch > acc:
                self.loss.alpha = self.loss.alpha * self.alpha_factor
Example #23
    def run(self):
        print("training...")

        pbar = tqdm(range(self.start_epoch, self.nepochs + 1))
        pbar.set_description('Training IDR')
        for epoch in pbar:

            if epoch in self.alpha_milestones:
                self.loss.alpha = self.loss.alpha * self.alpha_factor

            if epoch % 100 == 0:
                self.save_checkpoints(epoch)

            if epoch % self.plot_freq == 0:
                self.model.eval()
                if self.train_cameras:
                    self.pose_vecs.eval()
                self.train_dataset.change_sampling_idx(-1)
                indices, model_input, ground_truth = next(iter(self.plot_dataloader))

                model_input["intrinsics"] = utils.to_cuda(model_input["intrinsics"])
                model_input["uv"] = utils.to_cuda(model_input["uv"])
                model_input["object_mask"] = utils.to_cuda(model_input["object_mask"])

                if self.train_cameras:
                    pose_input = self.pose_vecs(utils.to_cuda(indices))
                    model_input['pose'] = pose_input
                else:
                    model_input['pose'] = utils.to_cuda(model_input['pose'])

                split = utils.split_input(model_input, self.total_pixels)
                res = []
                for s in split:
                    out = self.model(s)
                    res.append({
                        'points': out['points'].detach(),
                        'rgb_values': out['rgb_values'].detach(),
                        'network_object_mask': out['network_object_mask'].detach(),
                        'object_mask': out['object_mask'].detach()
                    })

                batch_size = ground_truth['rgb'].shape[0]
                model_outputs = utils.merge_output(res, self.total_pixels, batch_size)

                plt.plot(self.model,
                         indices,
                         model_outputs,
                         model_input['pose'],
                         ground_truth['rgb'],
                         self.plots_dir,
                         epoch,
                         self.img_res,
                         **self.plot_conf
                         )

                self.model.train()
                if self.train_cameras:
                    self.pose_vecs.train()

            self.train_dataset.change_sampling_idx(self.num_pixels)

            for data_index, (indices, model_input, ground_truth) in enumerate(self.train_dataloader):

                model_input["intrinsics"] = utils.to_cuda(model_input["intrinsics"])
                model_input["uv"] = utils.to_cuda(model_input["uv"])
                model_input["object_mask"] = utils.to_cuda(model_input["object_mask"])

                if self.train_cameras:
                    pose_input = self.pose_vecs(utils.to_cuda(indices))
                    model_input['pose'] = pose_input
                else:
                    model_input['pose'] = utils.to_cuda(model_input['pose'])

                model_outputs = self.model(model_input)
                loss_output = self.loss(model_outputs, ground_truth)

                loss = loss_output['loss']

                self.optimizer.zero_grad()
                if self.train_cameras:
                    self.optimizer_cam.zero_grad()

                loss.backward()

                self.optimizer.step()
                if self.train_cameras:
                    self.optimizer_cam.step()

            pbar.set_postfix({
                'loss':  loss.item(),
                'rgb_loss': loss_output['rgb_loss'].item(),
                'eikonal_loss': loss_output['eikonal_loss'].item(),
                'mask_loss': loss_output['mask_loss'].item(),
                'alpha': self.loss.alpha,
                'lr': self.scheduler.get_last_lr()[0]
                })

            self.scheduler.step()
Example #24
def evaluate(**kwargs):
    torch.set_default_dtype(torch.float32)

    conf = ConfigFactory.parse_file(kwargs['conf'])
    exps_folder_name = kwargs['exps_folder_name']
    evals_folder_name = kwargs['evals_folder_name']
    eval_cameras = kwargs['eval_cameras']
    eval_rendering = kwargs['eval_rendering']

    expname = conf.get_string('train.expname') + kwargs['expname']
    scene_id = kwargs['scene_id'] if kwargs['scene_id'] else conf.get_string(
        'dataset.scene_id', default=None)
    if scene_id:
        expname = expname + '_{0}'.format(scene_id)

    if kwargs['timestamp'] == 'latest':
        if os.path.exists(
                os.path.join('../', kwargs['exps_folder_name'], expname)):
            timestamps = os.listdir(
                os.path.join('../', kwargs['exps_folder_name'], expname))
            if len(timestamps) == 0:
                print('WRONG EXP FOLDER')
                exit()
            else:
                timestamp = sorted(timestamps)[-1]
        else:
            print('WRONG EXP FOLDER')
            exit()
    else:
        timestamp = kwargs['timestamp']

    utils.mkdir_ifnotexists(os.path.join('../', evals_folder_name))
    expdir = os.path.join('../', exps_folder_name, expname)
    evaldir = os.path.join('../', evals_folder_name, expname)
    utils.mkdir_ifnotexists(evaldir)

    model = utils.get_class(
        conf.get_string('train.model_class'))(conf=conf.get_config('model'))
    if torch.cuda.is_available():
        model.cuda()

    dataset_conf = conf.get_config('dataset')
    if kwargs['scene_id']:
        dataset_conf['scene_id'] = kwargs['scene_id']
    eval_dataset = utils.get_class(conf.get_string('train.dataset_class'))(
        eval_cameras, **dataset_conf)

    # settings for camera optimization
    scale_mat = eval_dataset.get_scale_mat()
    if eval_cameras:
        num_images = len(eval_dataset)
        pose_vecs = utils.to_cuda(
            torch.nn.Embedding(num_images, 7, sparse=True))
        pose_vecs.weight.data.copy_(eval_dataset.get_pose_init())

        gt_pose = eval_dataset.get_gt_pose()

    if eval_rendering:
        eval_dataloader = torch.utils.data.DataLoader(
            eval_dataset,
            batch_size=1,
            shuffle=False,
            collate_fn=eval_dataset.collate_fn)
        total_pixels = eval_dataset.total_pixels
        img_res = eval_dataset.img_res

    old_checkpnts_dir = os.path.join(expdir, timestamp, 'checkpoints')

    saved_model_state = torch.load(
        os.path.join(old_checkpnts_dir, 'ModelParameters',
                     str(kwargs['checkpoint']) + ".pth"))
    model.load_state_dict(saved_model_state["model_state_dict"])
    epoch = saved_model_state['epoch']

    if eval_cameras:
        data = torch.load(
            os.path.join(old_checkpnts_dir, 'CamParameters',
                         str(kwargs['checkpoint']) + ".pth"))
        pose_vecs.load_state_dict(data["pose_vecs_state_dict"])

    ####################################################################################################################
    print("evaluating...")

    model.eval()
    if eval_cameras:
        pose_vecs.eval()

    with torch.no_grad():
        if eval_cameras:
            gt_Rs = gt_pose[:, :3, :3].double()
            gt_ts = gt_pose[:, :3, 3].double()

            pred_Rs = rend_util.quat_to_rot(
                pose_vecs.weight.data[:, :4]).cpu().double()
            pred_ts = pose_vecs.weight.data[:, 4:].cpu().double()

            R_opt, t_opt, c_opt, R_fixed, t_fixed = get_cameras_accuracy(
                pred_Rs, gt_Rs, pred_ts, gt_ts)

            cams_transformation = np.eye(4, dtype=np.double)
            cams_transformation[:3, :3] = c_opt * R_opt
            cams_transformation[:3, 3] = t_opt

        mesh = plt.get_surface_mesh(
            sdf=lambda x: model.geometry_network(x)[:, 0],
            resolution=kwargs['resolution'])

        # Transform to world coordinates
        if eval_cameras:
            mesh.apply_transform(cams_transformation)
        else:
            mesh.apply_transform(scale_mat)

        # Taking the biggest connected component
        components = mesh.split(only_watertight=False)
        areas = np.array([c.area for c in components], dtype=np.float64)  # np.float was removed in NumPy 1.24
        mesh_clean = components[areas.argmax()]
        mesh_clean.export(
            '{0}/surface_world_coordinates_{1}.ply'.format(evaldir, epoch),
            'ply')

    if eval_rendering:
        images_dir = '{0}/rendering'.format(evaldir)
        utils.mkdir_ifnotexists(images_dir)

        psnrs = []
        for data_index, (indices, model_input,
                         ground_truth) in enumerate(eval_dataloader):
            model_input["intrinsics"] = utils.to_cuda(
                model_input["intrinsics"])
            model_input["uv"] = utils.to_cuda(model_input["uv"])
            model_input["object_mask"] = utils.to_cuda(
                model_input["object_mask"])

            if eval_cameras:
                pose_input = pose_vecs(utils.to_cuda(indices))
                model_input['pose'] = pose_input
            else:
                model_input['pose'] = utils.to_cuda(model_input['pose'])

            split = utils.split_input(model_input, total_pixels)
            res = []
            for s in split:
                out = model(s)
                res.append({
                    'rgb_values': out['rgb_values'].detach(),
                })

            batch_size = ground_truth['rgb'].shape[0]
            model_outputs = utils.merge_output(res, total_pixels, batch_size)
            rgb_eval = model_outputs['rgb_values']
            rgb_eval = rgb_eval.reshape(batch_size, total_pixels, 3)

            rgb_eval = (rgb_eval + 1.) / 2.
            rgb_eval = plt.lin2img(rgb_eval, img_res).detach().cpu().numpy()[0]
            rgb_eval = rgb_eval.transpose(1, 2, 0)
            img = Image.fromarray((rgb_eval * 255).astype(np.uint8))
            img.save('{0}/eval_{1}.png'.format(images_dir,
                                               '%03d' % indices[0]))

            rgb_gt = ground_truth['rgb']
            rgb_gt = (rgb_gt + 1.) / 2.
            rgb_gt = plt.lin2img(rgb_gt, img_res).numpy()[0]
            rgb_gt = rgb_gt.transpose(1, 2, 0)

            mask = model_input['object_mask']
            mask = plt.lin2img(mask.unsqueeze(-1), img_res).cpu().numpy()[0]
            mask = mask.transpose(1, 2, 0)

            rgb_eval_masked = rgb_eval * mask
            rgb_gt_masked = rgb_gt * mask

            psnr = calculate_psnr(rgb_eval_masked, rgb_gt_masked, mask)
            psnrs.append(psnr)

        psnrs = np.array(psnrs).astype(np.float64)
        print("RENDERING EVALUATION {2}: psnr mean = {0} ; psnr std = {1}".
              format("%.2f" % psnrs.mean(), "%.2f" % psnrs.std(), scene_id))
Example #25
def evaluate(**kwargs):
    torch.set_default_dtype(torch.float32)

    conf = ConfigFactory.parse_file(kwargs['conf'])
    exps_folder_name = kwargs['exps_folder_name']
    evals_folder_name = kwargs['evals_folder_name']
    timestamp = '2020'
    checkpoint = '2000'

    expname = conf.get_string('train.expname')

    geometry_id = kwargs['geometry_id']
    appearance_id = kwargs['appearance_id']

    utils.mkdir_ifnotexists(os.path.join('../', evals_folder_name))
    expdir_geometry = os.path.join('../', exps_folder_name,
                                   expname + '_{0}'.format(geometry_id))
    expdir_appearance = os.path.join('../', exps_folder_name,
                                     expname + '_{0}'.format(appearance_id))
    evaldir = os.path.join(
        '../', evals_folder_name,
        expname + '_{0}_{1}'.format(geometry_id, appearance_id))
    utils.mkdir_ifnotexists(evaldir)

    model = utils.get_class(
        conf.get_string('train.model_class'))(conf=conf.get_config('model'))
    if torch.cuda.is_available():
        model.cuda()

    # Load geometry network model
    old_checkpnts_dir = os.path.join(expdir_geometry, timestamp, 'checkpoints')
    saved_model_state = torch.load(
        os.path.join(old_checkpnts_dir, 'ModelParameters',
                     checkpoint + ".pth"))
    model.load_state_dict(saved_model_state["model_state_dict"])

    # Load rendering network model
    model_fake = utils.get_class(
        conf.get_string('train.model_class'))(conf=conf.get_config('model'))
    if torch.cuda.is_available():
        model_fake.cuda()
    old_checkpnts_dir = os.path.join(expdir_appearance, timestamp,
                                     'checkpoints')
    saved_model_state = torch.load(
        os.path.join(old_checkpnts_dir, 'ModelParameters',
                     checkpoint + ".pth"))
    model_fake.load_state_dict(saved_model_state["model_state_dict"])

    model.rendering_network = model_fake.rendering_network

    dataset_conf = conf.get_config('dataset')
    dataset_conf['scene_id'] = geometry_id
    eval_dataset = utils.get_class(conf.get_string('train.dataset_class'))(
        False, **dataset_conf)

    eval_dataloader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=1,
        shuffle=True,
        collate_fn=eval_dataset.collate_fn)
    total_pixels = eval_dataset.total_pixels
    img_res = eval_dataset.img_res

    ####################################################################################################################
    print("evaluating...")

    model.eval()

    gt_pose = utils.to_cuda(eval_dataset.get_gt_pose(scaled=True))
    gt_quat = rend_util.rot_to_quat(gt_pose[:, :3, :3])
    gt_pose_vec = torch.cat([gt_quat, gt_pose[:, :3, 3]], 1)

    indices_all = [11, 16, 34, 28, 11]
    pose = gt_pose_vec[indices_all, :]
    t_in = np.array([0, 2, 3, 5, 6]).astype(np.float32)

    n_inter = 5
    t_out = np.linspace(t_in[0], t_in[-1],
                        n_inter * t_in[-1]).astype(np.float32)

    scales = np.array([4.2, 4.2, 3.8, 3.8, 4.2]).astype(np.float32)

    s_new = CubicSpline(t_in, scales, bc_type='periodic')
    s_new = s_new(t_out)

    q_new = CubicSpline(t_in,
                        pose[:, :4].detach().cpu().numpy(),
                        bc_type='periodic')
    q_new = q_new(t_out)
    q_new = q_new / np.linalg.norm(q_new, 2, 1)[:, None]
    q_new = utils.to_cuda(torch.from_numpy(q_new)).float()

    images_dir = '{0}/novel_views_rendering'.format(evaldir)
    utils.mkdir_ifnotexists(images_dir)

    indices, model_input, ground_truth = next(iter(eval_dataloader))

    for i, (new_q, scale) in enumerate(zip(q_new, s_new)):
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        new_q = new_q.unsqueeze(0)
        new_t = -rend_util.quat_to_rot(new_q)[:, :, 2] * scale

        new_p = utils.to_cuda(torch.eye(4).float()).unsqueeze(0)
        new_p[:, :3, :3] = rend_util.quat_to_rot(new_q)
        new_p[:, :3, 3] = new_t

        sample = {
            "object_mask":
            utils.to_cuda(torch.zeros_like(model_input['object_mask'])).bool(),
            "uv":
            utils.to_cuda(model_input['uv']),
            "intrinsics":
            utils.to_cuda(model_input['intrinsics']),
            "pose":
            new_p
        }

        split = utils.split_input(sample, total_pixels)
        res = []
        for s in split:
            out = model(s)
            res.append({
                'rgb_values': out['rgb_values'].detach(),
            })

        batch_size = 1
        model_outputs = utils.merge_output(res, total_pixels, batch_size)
        rgb_eval = model_outputs['rgb_values']
        rgb_eval = rgb_eval.reshape(batch_size, total_pixels, 3)

        rgb_eval = (rgb_eval + 1.) / 2.
        rgb_eval = plt.lin2img(rgb_eval, img_res).detach().cpu().numpy()[0]
        rgb_eval = rgb_eval.transpose(1, 2, 0)
        img = Image.fromarray((rgb_eval * 255).astype(np.uint8))
        img.save('{0}/eval_{1}.png'.format(images_dir, '%03d' % i))
Example #26
def get_surface_high_res_mesh(sdf, resolution=100):
    # get low res mesh to sample point cloud
    mesh_low_res = get_surface_mesh(sdf, 100)

    recon_pc = trimesh.sample.sample_surface(mesh_low_res, 10000)[0]
    recon_pc = utils.to_cuda(torch.from_numpy(recon_pc).float())

    # Center and align the recon pc
    s_mean = recon_pc.mean(dim=0)
    s_cov = recon_pc - s_mean
    s_cov = torch.mm(s_cov.transpose(0, 1), s_cov)
    # torch.eig was removed from recent PyTorch; torch.linalg.eigh handles the
    # symmetric covariance and likewise returns eigenvectors as columns.
    vecs = torch.linalg.eigh(s_cov)[1].transpose(0, 1)
    if torch.det(vecs) < 0:
        vecs = torch.mm(
            utils.to_cuda(torch.tensor([[1, 0, 0], [0, 0, 1],
                                        [0, 1, 0]])).float(), vecs)
    helper = torch.bmm(
        vecs.unsqueeze(0).repeat(recon_pc.shape[0], 1, 1),
        (recon_pc - s_mean).unsqueeze(-1)).squeeze()

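    # Build a regular grid in the PCA-aligned frame, then rotate its points
    # back to world coordinates in chunks.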
    grid_aligned = get_grid(helper.cpu(), resolution)

    grid_points = grid_aligned['grid_points'].cpu()
    s_mean = s_mean.cpu()
    vecs = vecs.cpu()
    g = []
    for i, pnts in enumerate(torch.split(grid_points, 10000, dim=0)):
        g.append(
            torch.bmm(
                vecs.unsqueeze(0).repeat(pnts.shape[0], 1, 1).transpose(1, 2),
                pnts.unsqueeze(-1)).squeeze() + s_mean)
    grid_points = torch.cat(g, dim=0)

    # Evaluate the SDF over the new grid and run marching cubes on it
    points = grid_points
    z = []
    for i, pnts in enumerate(torch.split(points, 10000, dim=0)):
        z.append(sdf(utils.to_cuda(pnts)).detach().cpu().numpy())
    z = np.concatenate(z, axis=0)

    meshexport = None
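    # Marching cubes needs a sign change; skip if the SDF is strictly positive
    # or strictly negative everywhere on the grid.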
    if np.min(z) <= 0 <= np.max(z):

        z = z.astype(np.float32)

        # marching_cubes_lewiner was removed from skimage; measure.marching_cubes
        # runs the same Lewiner algorithm by default. Note that the step of the
        # first grid axis is reused for all three spacings, i.e. uniform grid
        # spacing is assumed.
        verts, faces, normals, values = measure.marching_cubes(
            volume=z.reshape(grid_aligned['xyz'][1].shape[0],
                             grid_aligned['xyz'][0].shape[0],
                             grid_aligned['xyz'][2].shape[0]).transpose(
                                 [1, 0, 2]),
            level=0,
            spacing=(grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1],
                     grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1],
                     grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1]))

        verts = utils.to_cuda(torch.from_numpy(verts)).float()
        vecs = utils.to_cuda(vecs)
        verts = torch.bmm(
            vecs.unsqueeze(0).repeat(verts.shape[0], 1, 1).transpose(1, 2),
            verts.unsqueeze(-1)).squeeze()
        verts = verts.cpu()
        verts = (verts + grid_points[0]).numpy()

        meshexport = trimesh.Trimesh(verts, faces)

    return meshexport
Example #27
0
def plot(model, indices, model_outputs, pose, rgb_gt, path, epoch, img_res,
         plot_nimgs, max_depth, resolution):
    # arrange data to plot
    batch_size, num_samples, _ = rgb_gt.shape

    network_object_mask = model_outputs['network_object_mask']
    points = model_outputs['points'].reshape(batch_size, num_samples, 3)
    rgb_eval = model_outputs['rgb_values']
    rgb_eval = rgb_eval.reshape(batch_size, num_samples, 3)

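    # Depth defaults to max_depth for background pixels; pixels the network
    # marked as object get their depth from the predicted surface points.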
    depth = utils.to_cuda(torch.ones(
        batch_size * num_samples)).float() * max_depth
    depth[network_object_mask] = rend_util.get_depth(
        points, pose).reshape(-1)[network_object_mask]
    depth = depth.reshape(batch_size, num_samples, 1)
    network_object_mask = network_object_mask.reshape(batch_size, -1)

    cam_loc, cam_dir = rend_util.get_camera_for_plot(pose)

    # plot rendered images
    plot_images(rgb_eval, rgb_gt, path, epoch, plot_nimgs, img_res)

    # plot depth maps
    plot_depth_maps(depth, path, epoch, plot_nimgs, img_res)

    data = []

    # plot surface
    surface_traces = get_surface_trace(
        path=path,
        epoch=epoch,
        sdf=lambda x: model.geometry_network(x)[:, 0],
        resolution=resolution)
    data.append(surface_traces[0])

    # plot camera locations and viewing directions
    for i, loc, dir in zip(indices, cam_loc, cam_dir):
        data.append(
            get_3D_quiver_trace(loc.unsqueeze(0),
                                dir.unsqueeze(0),
                                name='camera_{0}'.format(i)))

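    # For each view, subsample up to 2048 predicted surface points and query
    # the SDF so the scatter captions show how close they lie to the zero set.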
    for i, p, m in zip(indices, points, network_object_mask):
        p = p[m]
        sampling_idx = torch.randperm(p.shape[0])[:2048]
        p = p[sampling_idx, :]

        val = model.geometry_network(p)
        caption = ["sdf: {0} ".format(v[0].item()) for v in val]

        data.append(
            get_3D_scatter_trace(p,
                                 name='intersection_points_{0}'.format(i),
                                 caption=caption))

    fig = go.Figure(data=data)
    scene_dict = dict(xaxis=dict(range=[-3, 3], autorange=False),
                      yaxis=dict(range=[-3, 3], autorange=False),
                      zaxis=dict(range=[-3, 3], autorange=False),
                      aspectratio=dict(x=1, y=1, z=1))
    fig.update_layout(scene=scene_dict,
                      width=1400,
                      height=1400,
                      showlegend=True)
    filename = '{0}/surface_{1}.html'.format(path, epoch)
    offline.plot(fig, filename=filename, auto_open=False)
Example #28
0
    def get_eikonal_loss(self, grad_theta):
        if grad_theta.shape[0] == 0:
            return utils.to_cuda(torch.tensor(0.0)).float()

        eikonal_loss = ((grad_theta.norm(2, dim=1) - 1)**2).mean()
        return eikonal_loss
Example #29
0
    def forward(self, input):

        # Parse model input
        intrinsics = input["intrinsics"]
        uv = input["uv"]
        pose = input["pose"]
        object_mask = input["object_mask"].reshape(-1)

        ray_dirs, cam_loc = rend_util.get_camera_params(uv, pose, intrinsics)

        batch_size, num_pixels, _ = ray_dirs.shape

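        # Ray trace the surface with the geometry network in eval mode and no
        # gradients; differentiability w.r.t. the geometry is restored below
        # via sample_network (training branch only).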
        self.geometry_network.eval()
        with torch.no_grad():
            points, network_object_mask, dists = self.ray_tracer(sdf=lambda x: self.geometry_network(x)[:, 0],
                                                                 cam_loc=cam_loc,
                                                                 object_mask=object_mask,
                                                                 ray_directions=ray_dirs)
        self.geometry_network.train()

        points = (cam_loc.unsqueeze(1) + dists.reshape(batch_size, num_pixels, 1) * ray_dirs).reshape(-1, 3)

        sdf_output = self.geometry_network(points)[:, 0:1]
        ray_dirs = ray_dirs.reshape(-1, 3)

        if self.training:
            surface_mask = network_object_mask & object_mask
            surface_points = points[surface_mask]
            surface_dists = dists[surface_mask].unsqueeze(-1)
            surface_ray_dirs = ray_dirs[surface_mask]
            surface_cam_loc = cam_loc.unsqueeze(1).repeat(1, num_pixels, 1).reshape(-1, 3)[surface_mask]
            surface_output = sdf_output[surface_mask]
            N = surface_points.shape[0]

            # Sample points for the eikonal loss
            eik_bounding_box = self.object_bounding_sphere
            n_eik_points = batch_size * num_pixels // 2
            eikonal_points = utils.to_cuda(torch.empty(n_eik_points, 3).uniform_(-eik_bounding_box, eik_bounding_box))
            eikonal_pixel_points = points.clone()
            eikonal_pixel_points = eikonal_pixel_points.detach()
            eikonal_points = torch.cat([eikonal_points, eikonal_pixel_points], 0)

            points_all = torch.cat([surface_points, eikonal_points], dim=0)

            output = self.geometry_network(surface_points)
            surface_sdf_values = output[:N, 0:1].detach()

            g = self.geometry_network.gradient(points_all)
            surface_points_grad = g[:N, 0, :].clone().detach()
            grad_theta = g[N:, 0, :]

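            # Re-parameterize the intersection points so gradients can flow to
            # the geometry network through the surface (a first-order
            # re-parameterization of the ray/surface intersection).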
            differentiable_surface_points = self.sample_network(surface_output,
                                                                surface_sdf_values,
                                                                surface_points_grad,
                                                                surface_dists,
                                                                surface_cam_loc,
                                                                surface_ray_dirs)

        else:
            surface_mask = network_object_mask
            differentiable_surface_points = points[surface_mask]
            grad_theta = None

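        # Viewing direction points from the surface back toward the camera.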
        view = -ray_dirs[surface_mask]

        rgb_values = utils.to_cuda(torch.ones_like(points).float())
        if differentiable_surface_points.shape[0] > 0:
            rgb_values[surface_mask] = self.get_rbg_value(differentiable_surface_points, view)

        output = {
            'points': points,
            'rgb_values': rgb_values,
            'sdf_output': sdf_output,
            'network_object_mask': network_object_mask,
            'object_mask': object_mask,
            'grad_theta': grad_theta
        }

        return output
Example #30
0
    def forward(self, sdf, cam_loc, object_mask, ray_directions):

        batch_size, num_pixels, _ = ray_directions.shape

        sphere_intersections, mask_intersect = rend_util.get_sphere_intersection(
            cam_loc, ray_directions, r=self.object_bounding_sphere)

        curr_start_points, unfinished_mask_start, acc_start_dis, acc_end_dis, min_dis, max_dis = \
            self.sphere_tracing(batch_size, num_pixels, sdf, cam_loc, ray_directions, mask_intersect, sphere_intersections)

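        # A ray is marked as hitting the object as long as the distance traced
        # from the camera side stays below the distance traced from the far
        # side of the sphere.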
        network_object_mask = (acc_start_dis < acc_end_dis)

        # Rays for which sphere tracing did not converge are handed off to the sampler
        sampler_mask = unfinished_mask_start
        sampler_net_obj_mask = utils.to_cuda(
            torch.zeros_like(sampler_mask).bool())
        if sampler_mask.sum() > 0:
            sampler_min_max = utils.to_cuda(
                torch.zeros((batch_size, num_pixels, 2)))
            sampler_min_max.reshape(-1, 2)[sampler_mask,
                                           0] = acc_start_dis[sampler_mask]
            sampler_min_max.reshape(-1, 2)[sampler_mask,
                                           1] = acc_end_dis[sampler_mask]

            sampler_pts, sampler_net_obj_mask, sampler_dists = self.ray_sampler(
                sdf, cam_loc, object_mask, ray_directions, sampler_min_max,
                sampler_mask)

            curr_start_points[sampler_mask] = sampler_pts[sampler_mask]
            acc_start_dis[sampler_mask] = sampler_dists[sampler_mask]
            network_object_mask[sampler_mask] = sampler_net_obj_mask[
                sampler_mask]

        if self.verbose:
            print(
                '----------------------------------------------------------------'
            )
            print('RayTracing: object = {0}/{1}, secant on {2}/{3}.'.format(
                network_object_mask.sum(), len(network_object_mask),
                sampler_net_obj_mask.sum(), sampler_mask.sum()))
            print(
                '----------------------------------------------------------------'
            )

        if not self.training:
            return curr_start_points, \
                   network_object_mask, \
                   acc_start_dis

        ray_directions = ray_directions.reshape(-1, 3)
        mask_intersect = mask_intersect.reshape(-1)

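        # in_mask: object pixels the tracer missed; out_mask: background
        # pixels. Both exclude rays already handled by the sampler above.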
        in_mask = ~network_object_mask & object_mask & ~sampler_mask
        out_mask = ~object_mask & ~sampler_mask

        mask_left_out = (in_mask | out_mask) & ~mask_intersect
        # For rays that miss the bounding sphere entirely, use the point on the
        # ray closest to the sphere center (the origin) as the surface estimate.
        if mask_left_out.sum() > 0:
            cam_left_out = cam_loc.unsqueeze(1).repeat(
                1, num_pixels, 1).reshape(-1, 3)[mask_left_out]
            rays_left_out = ray_directions[mask_left_out]
            acc_start_dis[mask_left_out] = -torch.bmm(
                rays_left_out.view(-1, 1, 3), cam_left_out.view(-1, 3,
                                                                1)).squeeze()
            curr_start_points[mask_left_out] = cam_left_out + acc_start_dis[
                mask_left_out].unsqueeze(1) * rays_left_out

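        # Rays that intersect the sphere but found no surface: fall back to
        # the point of minimal SDF value along the ray segment.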
        mask = (in_mask | out_mask) & mask_intersect

        if mask.sum() > 0:
            min_dis[network_object_mask
                    & out_mask] = acc_start_dis[network_object_mask & out_mask]

            min_mask_points, min_mask_dist = self.minimal_sdf_points(
                num_pixels, sdf, cam_loc, ray_directions, mask, min_dis,
                max_dis)

            curr_start_points[mask] = min_mask_points
            acc_start_dis[mask] = min_mask_dist

        return curr_start_points, \
               network_object_mask, \
               acc_start_dis