Example #1
    def plot_shapes(self, epoch, path=None, with_cuts=False):
        # plot network validation shapes
        with torch.no_grad():

            self.network.eval()

            if not path:
                path = self.plots_dir

            indices = torch.tensor(np.random.choice(self.data.shape[0], self.points_batch, False))

            pnts = self.data[indices, :3]

            plot_surface(with_points=True,
                         points=pnts,
                         decoder=self.network,
                         path=path,
                         epoch=epoch,
                         shapename=self.expname,
                         **self.conf.get_config('plot'))

            if with_cuts:
                plot_cuts(points=pnts,
                          decoder=self.network,
                          path=path,
                          epoch=epoch,
                          near_zero=False)
Example #2
    def plot_validation_shapes(self, epoch, with_cuts=False):
        # plot network validation shapes
        with torch.no_grad():

            print('plot validation epoch: ', epoch)

            self.network.eval()
            pnts, normals, idx = next(iter(self.eval_dataloader))
            pnts = utils.to_cuda(pnts)

            pnts = self.add_latent(pnts, idx)
            latent = self.lat_vecs[idx[0]]

            shapename = str.join('_', self.ds.get_info(idx))

            plot_surface(with_points=True,
                         points=pnts,
                         decoder=self.network,
                         latent=latent,
                         path=self.plots_dir,
                         epoch=epoch,
                         shapename=shapename,
                         **self.conf.get_config('plot'))

            if with_cuts:
                plot_cuts(points=pnts,
                          decoder=self.network,
                          latent=latent,
                          path=self.plots_dir,
                          epoch=epoch,
                          near_zero=False)
Example #3
def interpolate(network, interval, experiment_directory, conf, checkpoint,
                split_file, epoch, resolution, uniform_grid):

    with open(split_file, "r") as f:
        split = json.load(f)

    ds = utils.get_class(conf.get_string('train.dataset'))(
        split=split,
        dataset_path=conf.get_string('train.dataset_path'),
        with_normals=True)

    points_1, normals_1, index_1 = ds[0]
    points_2, normals_2, index_2 = ds[1]

    pnts = torch.cat([points_1, points_2], dim=0).cuda()

    name_1 = str.join('_', ds.get_info(0))
    name_2 = str.join('_', ds.get_info(1))

    name = name_1 + '_and_' + name_2

    utils.mkdir_ifnotexists(os.path.join(experiment_directory, 'interpolate'))
    utils.mkdir_ifnotexists(
        os.path.join(experiment_directory, 'interpolate', str(checkpoint)))
    utils.mkdir_ifnotexists(
        os.path.join(experiment_directory, 'interpolate', str(checkpoint),
                     name))

    my_path = os.path.join(experiment_directory, 'interpolate',
                           str(checkpoint), name)

    latent_1 = optimize_latent(points_1.cuda(), normals_1.cuda(), conf, 800,
                               network, 5e-3)
    latent_2 = optimize_latent(points_2.cuda(), normals_2.cuda(), conf, 800,
                               network, 5e-3)

    pnts = torch.cat([latent_1.repeat(pnts.shape[0], 1), pnts], dim=-1)

    with torch.no_grad():
        network.eval()

        for alpha in np.linspace(0, 1, interval):

            latent = (latent_1 * (1 - alpha)) + (latent_2 * alpha)

            plt.plot_surface(with_points=False,
                             points=pnts,
                             decoder=network,
                             latent=latent,
                             path=my_path,
                             epoch=epoch,
                             shapename=str(alpha),
                             resolution=resolution,
                             mc_value=0,
                             is_uniform_grid=uniform_grid,
                             verbose=True,
                             save_html=False,
                             save_ply=True,
                             overwrite=True,
                             connected=True)
Example #4
def evaluate(network, experiment_directory, conf, checkpoint, split_file, epoch, resolution, uniform_grid):

    my_path = os.path.join(experiment_directory, 'evaluation', str(checkpoint))

    utils.mkdir_ifnotexists(os.path.join(experiment_directory, 'evaluation'))
    utils.mkdir_ifnotexists(my_path)

    with open(split_file, "r") as f:
        split = json.load(f)

    ds = utils.get_class(conf.get_string('train.dataset'))(split=split, dataset_path=conf.get_string('train.dataset_path'), with_normals=True)

    total_files = len(ds)
    print("total files : {0}".format(total_files))
    counter = 0
    dataloader = torch.utils.data.DataLoader(ds, batch_size=1, shuffle=True, num_workers=1, drop_last=False, pin_memory=True)

    for (input_pc, normals, index) in dataloader:

        input_pc = input_pc.cuda().squeeze()
        normals = normals.cuda().squeeze()

        print(counter)
        counter = counter + 1

        network.train()

        latent = optimize_latent(input_pc, normals, conf, 800, network, lr=5e-3)

        all_latent = latent.repeat(input_pc.shape[0], 1)

        points = torch.cat([all_latent,input_pc], dim=-1)

        shapename = str.join('_', ds.get_info(index))

        with torch.no_grad():

            network.eval()

            plt.plot_surface(with_points=True,
                             points=points,
                             decoder=network,
                             latent=latent,
                             path=my_path,
                             epoch=epoch,
                             shapename=shapename,
                             resolution=resolution,
                             mc_value=0,
                             is_uniform_grid=uniform_grid,
                             verbose=True,
                             save_html=True,
                             save_ply=True,
                             overwrite=True,
                             connected=True)
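
Examples #3 and #4 both call optimize_latent(points, normals, conf, num_iterations, network, lr), which is defined elsewhere in the repository. The sketch below is only a rough guess at such a routine, under the assumption (suggested by the surrounding calls) that the network maps concatenated (latent, xyz) inputs to implicit values; the initialization, objective, and regularizer are assumptions, not the repository's implementation.

import torch

def optimize_latent(points, normals, conf, num_iterations, network, lr):
    # Hypothetical sketch matching the call sites in Examples #3 and #4.
    # `normals` is accepted to match those calls but unused here; the real
    # objective presumably makes use of it.
    latent_size = conf.get_int('train.latent_size')
    latent = torch.zeros(1, latent_size, device=points.device).normal_(0, 0.01)
    latent.requires_grad_(True)

    optimizer = torch.optim.Adam([latent], lr=lr)

    for _ in range(num_iterations):
        optimizer.zero_grad()

        # Attach the latent code to every surface sample and query the network.
        latent_rep = latent.expand(points.shape[0], -1)
        pred = network(torch.cat([latent_rep, points], dim=-1))

        # Assumed objective: pull the implicit value at surface samples towards
        # zero, with a small regularizer on the latent norm.
        loss = pred.abs().mean() + 1e-3 * latent.norm()
        loss.backward()
        optimizer.step()

    return latent.detach()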
Example #5
    def run(self):
        timing_log = []
        for epoch in range(self.start_epoch, self.nepochs + 2):
            start = time.time()

            if epoch % 100 == 0:
                self.save_checkpoints(epoch)
            if epoch % self.conf.get_int(
                    'train.plot_frequency') == 0 and epoch >= 0:
                with torch.no_grad():

                    self.network.eval()

                    pnts, _, idx = next(iter(self.eval_dataloader))
                    pnts = pnts.cuda()

                    if (self.parallel):
                        decoder = self.network.module.decoder
                        encoder = self.network.module.encoder
                    else:
                        decoder = self.network.decoder
                        encoder = self.network.encoder

                    if self.latent_size > 0:
                        latent = encoder(pnts)[0]

                        if (type(latent) is tuple):
                            latent = latent[0]
                        pnts = torch.cat([
                            latent.unsqueeze(1).repeat(1, pnts.shape[1], 1),
                            pnts
                        ],
                                         dim=-1)[0]
                    else:
                        latent = None
                        pnts = pnts[0]

                    plot_surface(with_points=True,
                                 points=pnts,
                                 decoder=decoder,
                                 latent=latent,
                                 path=self.plots_dir,
                                 epoch=epoch,
                                 in_epoch=0,
                                 shapefile=self.ds.npyfiles_mnfld[idx],
                                 **self.conf.get_config('plot'))

            self.network.train()
            if (self.adjust_lr):
                self.adjust_learning_rate(epoch)
            for data_index, (pnts_mnfld, sample_nonmnfld,
                             indices) in enumerate(self.dataloader):

                pnts_mnfld = pnts_mnfld.cuda()
                sample_nonmnfld = sample_nonmnfld.cuda()
                xyz_nonmnfld = sample_nonmnfld[:, :, :3]
                dist_nonmnfld = sample_nonmnfld[:, :, 3].reshape(-1)

                outputs = self.network(xyz_nonmnfld, pnts_mnfld)
                loss_res = self.loss(
                    manifold_pnts_pred=outputs['manifold_pnts_pred'],
                    nonmanifold_pnts_pred=outputs['nonmanifold_pnts_pred'],
                    nonmanifold_gt=dist_nonmnfld,
                    weight=None,
                    latent_reg=outputs["latent_reg"])
                loss = loss_res["loss"]

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                logging.debug("expname : {0}".format(self.expname))
                logging.debug(
                    "timestamp: {0} , epoch : {1}, data_index : {2} , loss : {3}, reconstruction loss : {4} , vae loss : {5} "
                    .format(self.timestamp, epoch, data_index,
                            loss_res['loss'].item(),
                            loss_res['recon_term'].item(),
                            loss_res['reg_term'].item()))
                for param_group in self.optimizer.param_groups:
                    logging.debug("param group lr : {0}".format(
                        param_group["lr"]))

            end = time.time()
            seconds_elapsed_epoch = end - start
            timing_log.append(seconds_elapsed_epoch)
Example #6
def optimize_latent(latent, ds, itemindex, decoder, path, epoch, resolution,
                    conf):
    latent.detach_()
    latent.requires_grad_()
    lr = 1.0e-3
    optimizer = torch.optim.Adam([latent], lr=lr)
    loss_func = utils.get_class(conf.get_string('network.loss.loss_type'))(
        **conf.get_config('network.loss.properties'))

    num_iterations = 800

    decreased_by = 10
    adjust_lr_every = int(num_iterations / 2)
    for e in range(num_iterations):
        input_pc, sample_nonmnfld, _ = ds[itemindex]
        input_pc = utils.get_cuda_ifavailable(input_pc).unsqueeze(0)
        sample_nonmnfld = utils.get_cuda_ifavailable(
            sample_nonmnfld).unsqueeze(0)

        non_mnfld_pnts = sample_nonmnfld[:, :, :3]
        dist_nonmnfld = sample_nonmnfld[:, :, 3].reshape(-1)

        adjust_learning_rate(lr, optimizer, e, decreased_by, adjust_lr_every)

        optimizer.zero_grad()
        non_mnfld_pnts_with_latent = torch.cat([
            latent.unsqueeze(1).repeat(1, non_mnfld_pnts.shape[1], 1),
            non_mnfld_pnts
        ],
                                               dim=-1)
        nonmanifold_pnts_pred = decoder(
            non_mnfld_pnts_with_latent.view(
                -1, non_mnfld_pnts_with_latent.shape[-1]))

        loss_res = loss_func(manifold_pnts_pred=None,
                             nonmanifold_pnts_pred=nonmanifold_pnts_pred,
                             nonmanifold_gt=dist_nonmnfld,
                             weight=None)
        loss = loss_res["loss"]

        loss.backward()
        optimizer.step()
        print("iteration : {0} , loss {1}".format(e, loss.item()))
        print("mean {0} , std {1}".format(latent.mean().item(),
                                          latent.std().item()))

    with torch.no_grad():
        reconstruction = plt.plot_surface(
            with_points=False,
            points=torch.cat([
                latent.unsqueeze(1).repeat(1, input_pc.shape[1], 1), input_pc
            ],
                             dim=-1)[0],
            decoder=decoder,
            latent=latent,
            path=path,
            epoch=epoch,
            in_epoch=ds.npyfiles_mnfld[itemindex].split('/')[-3] + '_' +
            ds.npyfiles_mnfld[itemindex].split('/')[-1].split('.npy')[0] +
            '_after',
            shapefile=ds.npyfiles_mnfld[itemindex],
            resolution=resolution,
            mc_value=0,
            is_uniform_grid=True,
            verbose=True,
            save_html=False,
            save_ply=True,
            overwrite=True)
        return reconstruction
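
Example #6 relies on an adjust_learning_rate(initial_lr, optimizer, iteration, decreased_by, adjust_lr_every) helper that is not shown above. A minimal sketch, assuming the step-decay schedule implied by the decreased_by / adjust_lr_every arguments (the exact schedule is an assumption):

def adjust_learning_rate(initial_lr, optimizer, iteration, decreased_by,
                         adjust_lr_every):
    # Assumed step decay: divide the learning rate by `decreased_by`
    # every `adjust_lr_every` iterations.
    lr = initial_lr * ((1.0 / decreased_by) ** (iteration // adjust_lr_every))
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr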
Example #7
def evaluate(network, exps_dir, experiment_name, timestamp, split_filename,
             epoch, conf, with_opt, resolution, compute_dist_to_gt):

    utils.mkdir_ifnotexists(
        os.path.join('../', exps_dir, experiment_name, timestamp,
                     'evaluation'))
    utils.mkdir_ifnotexists(
        os.path.join('../', exps_dir, experiment_name, timestamp, 'evaluation',
                     split_filename.split('/')[-1].split('.json')[0]))
    path = os.path.join('../', exps_dir, experiment_name, timestamp,
                        'evaluation',
                        split_filename.split('/')[-1].split('.json')[0],
                        str(epoch))
    utils.mkdir_ifnotexists(path)

    dataset_path = conf.get_string('train.dataset_path')
    train_data_split = conf.get_string('train.data_split')
    latent_size = conf.get_int('train.latent_size')

    # defined up front so the chamfer/plot bookkeeping below is valid for
    # both dataset types
    chamfer_results = []
    plot_cmpr = True

    if (train_data_split == 'none'):
        ds = ReconDataSet(split=None,
                          dataset_path=dataset_path,
                          dist_file_name=None)
    else:
        dist_file_name = conf.get_string('train.dist_file_name')
        with open(split_filename, "r") as f:
            split = json.load(f)

        ds = DFaustDataSet(split=split,
                           dataset_path=dataset_path,
                           dist_file_name=dist_file_name,
                           with_gt=True)

    total_files = len(ds)
    logging.info("total files : {0}".format(total_files))
    counter = 0
    dataloader = torch.utils.data.DataLoader(ds,
                                             batch_size=1,
                                             shuffle=True,
                                             num_workers=1,
                                             drop_last=False,
                                             pin_memory=True)

    for data in dataloader:

        counter = counter + 1

        logging.info("evaluating " + ds.npyfiles_mnfld[data[-1]])

        input_pc = data[0].cuda()
        if latent_size > 0:
            latent = network.encoder(input_pc)
            if (type(latent) is tuple):
                latent = latent[0]
            points = torch.cat([
                latent.unsqueeze(1).repeat(1, input_pc.shape[1], 1), input_pc
            ],
                               dim=-1)[0]
        else:
            latent = None
            points = input_pc[0]

        reconstruction = plt.plot_surface(
            with_points=False,
            points=points,
            decoder=network.decoder,
            latent=latent,
            path=path,
            epoch=epoch,
            in_epoch=ds.npyfiles_mnfld[data[-1].item()].split('/')[-3] + '_' +
            ds.npyfiles_mnfld[data[-1].item()].split('/')[-1].split('.npy')[0]
            + '_before',
            shapefile=ds.npyfiles_mnfld[data[-1].item()],
            resolution=resolution,
            mc_value=0,
            is_uniform_grid=True,
            verbose=True,
            save_html=False,
            save_ply=True,
            overwrite=True)
        if (with_opt):
            recon_after_latentopt = optimize_latent(latent, ds, data[-1],
                                                    network.decoder, path,
                                                    epoch, resolution, conf)

        if compute_dist_to_gt:
            gt_mesh_filename = ds.gt_files[data[-1]]
            normalization_params_filename = ds.normalization_files[data[-1]]

            logging.debug("normalization params are " +
                          normalization_params_filename)

            ground_truth_points = trimesh.Trimesh(
                trimesh.sample.sample_surface(trimesh.load(gt_mesh_filename),
                                              30000)[0])

            normalization_params = np.load(normalization_params_filename,
                                           allow_pickle=True)

            scale = normalization_params.item()['scale']
            center = normalization_params.item()['center']

            chamfer_dist = utils.compute_trimesh_chamfer(
                gt_points=ground_truth_points,
                gen_mesh=reconstruction,
                offset=-center,
                scale=1. / scale,
            )

            chamfer_dist_scan = utils.compute_trimesh_chamfer(
                gt_points=trimesh.Trimesh(input_pc[0].cpu().numpy()),
                gen_mesh=reconstruction,
                offset=0,
                scale=1.,
                one_side=True)

            logging.debug("chamfer distance: " + str(chamfer_dist))

            if (with_opt):
                chamfer_dist_after_opt = utils.compute_trimesh_chamfer(
                    gt_points=ground_truth_points,
                    gen_mesh=recon_after_latentopt,
                    offset=-center,
                    scale=1. / scale,
                )

                chamfer_dist_scan_after_opt = utils.compute_trimesh_chamfer(
                    gt_points=trimesh.Trimesh(input_pc[0].cpu().numpy()),
                    gen_mesh=recon_after_latentopt,
                    offset=0,
                    scale=1.,
                    one_side=True)

                chamfer_results.append(
                    (ds.gt_files[data[-1]], chamfer_dist, chamfer_dist_scan,
                     chamfer_dist_after_opt, chamfer_dist_scan_after_opt))
            else:
                chamfer_results.append(
                    (ds.gt_files[data[-1]], chamfer_dist, chamfer_dist_scan))

            if (plot_cmpr):
                if (with_opt):
                    fig = make_subplots(rows=2,
                                        cols=2,
                                        specs=[[{
                                            "type": "scene"
                                        }, {
                                            "type": "scene"
                                        }],
                                               [{
                                                   "type": "scene"
                                               }, {
                                                   "type": "scene"
                                               }]],
                                        subplot_titles=[
                                            "Input", "Registration", "Ours",
                                            "Ours after opt"
                                        ])

                else:
                    fig = make_subplots(rows=1,
                                        cols=3,
                                        specs=[[{
                                            "type": "scene"
                                        }, {
                                            "type": "scene"
                                        }, {
                                            "type": "scene"
                                        }]],
                                        subplot_titles=("input pc", "Ours",
                                                        "Registration"))

                fig.layout.scene.update(
                    dict(xaxis=dict(range=[-1.5, 1.5], autorange=False),
                         yaxis=dict(range=[-1.5, 1.5], autorange=False),
                         zaxis=dict(range=[-1.5, 1.5], autorange=False),
                         aspectratio=dict(x=1, y=1, z=1)))
                fig.layout.scene2.update(
                    dict(xaxis=dict(range=[-1.5, 1.5], autorange=False),
                         yaxis=dict(range=[-1.5, 1.5], autorange=False),
                         zaxis=dict(range=[-1.5, 1.5], autorange=False),
                         aspectratio=dict(x=1, y=1, z=1)))
                fig.layout.scene3.update(
                    dict(xaxis=dict(range=[-1.5, 1.5], autorange=False),
                         yaxis=dict(range=[-1.5, 1.5], autorange=False),
                         zaxis=dict(range=[-1.5, 1.5], autorange=False),
                         aspectratio=dict(x=1, y=1, z=1)))
                if (with_opt):
                    fig.layout.scene4.update(
                        dict(xaxis=dict(range=[-1.5, 1.5], autorange=False),
                             yaxis=dict(range=[-1.5, 1.5], autorange=False),
                             zaxis=dict(range=[-1.5, 1.5], autorange=False),
                             aspectratio=dict(x=1, y=1, z=1)))

                scan_mesh = trimesh.load(ds.scans_files[data[-1]])

                scan_mesh.vertices = scan_mesh.vertices - center

                def tri_indices(simplices):
                    return ([triplet[c] for triplet in simplices]
                            for c in range(3))

                I, J, K = tri_indices(scan_mesh.faces)
                color = '#ffffff'
                trace = go.Mesh3d(x=scan_mesh.vertices[:, 0],
                                  y=scan_mesh.vertices[:, 1],
                                  z=scan_mesh.vertices[:, 2],
                                  i=I,
                                  j=J,
                                  k=K,
                                  name='scan',
                                  color=color,
                                  opacity=1.0,
                                  flatshading=False,
                                  lighting=dict(diffuse=1,
                                                ambient=0,
                                                specular=0),
                                  lightposition=dict(x=0, y=0, z=-1))
                fig.add_trace(trace, row=1, col=1)

                I, J, K = tri_indices(reconstruction.faces)
                color = '#ffffff'
                trace = go.Mesh3d(x=reconstruction.vertices[:, 0],
                                  y=reconstruction.vertices[:, 1],
                                  z=reconstruction.vertices[:, 2],
                                  i=I,
                                  j=J,
                                  k=K,
                                  name='our',
                                  color=color,
                                  opacity=1.0,
                                  flatshading=False,
                                  lighting=dict(diffuse=1,
                                                ambient=0,
                                                specular=0),
                                  lightposition=dict(x=0, y=0, z=-1))
                if (with_opt):
                    fig.add_trace(trace, row=2, col=1)

                    I, J, K = tri_indices(recon_after_latentopt.faces)
                    color = '#ffffff'
                    trace = go.Mesh3d(x=recon_after_latentopt.vertices[:, 0],
                                      y=recon_after_latentopt.vertices[:, 1],
                                      z=recon_after_latentopt.vertices[:, 2],
                                      i=I,
                                      j=J,
                                      k=K,
                                      name='our_after_opt',
                                      color=color,
                                      opacity=1.0,
                                      flatshading=False,
                                      lighting=dict(diffuse=1,
                                                    ambient=0,
                                                    specular=0),
                                      lightposition=dict(x=0, y=0, z=-1))
                    fig.add_trace(trace, row=2, col=2)
                else:
                    fig.add_trace(trace, row=1, col=2)

                gtmesh = trimesh.load(gt_mesh_filename)
                gtmesh.vertices = gtmesh.vertices - center
                I, J, K = tri_indices(gtmesh.faces)
                trace = go.Mesh3d(x=gtmesh.vertices[:, 0],
                                  y=gtmesh.vertices[:, 1],
                                  z=gtmesh.vertices[:, 2],
                                  i=I,
                                  j=J,
                                  k=K,
                                  name='gt',
                                  color=color,
                                  opacity=1.0,
                                  flatshading=False,
                                  lighting=dict(diffuse=1,
                                                ambient=0,
                                                specular=0),
                                  lightposition=dict(x=0, y=0, z=-1))
                if (with_opt):
                    fig.add_trace(trace, row=1, col=2)
                else:
                    fig.add_trace(trace, row=1, col=3)

                div = offline.plot(fig,
                                   include_plotlyjs=False,
                                   output_type='div',
                                   auto_open=False)
                div_id = div.split('=')[1].split()[0].replace("'", "").replace(
                    '"', '')
                if (with_opt):
                    js = '''
                                                    <script>
                                                    var gd = document.getElementById('{div_id}');
                                                    var isUnderRelayout = false
    
                                                    gd.on('plotly_relayout', () => {{
                                                      console.log('relayout', isUnderRelayout)
                                                      if (!isUnderRelayout) {{
                                                            Plotly.relayout(gd, 'scene2.camera', gd.layout.scene.camera)
                                                              .then(() => {{ isUnderRelayout = false }}  )
                                                            Plotly.relayout(gd, 'scene3.camera', gd.layout.scene.camera)
                                                              .then(() => {{ isUnderRelayout = false }}  )
                                                            Plotly.relayout(gd, 'scene4.camera', gd.layout.scene.camera)
                                                              .then(() => {{ isUnderRelayout = false }}  )
                                                          }}
    
                                                      isUnderRelayout = true;
                                                    }})
                                                    </script>'''.format(
                        div_id=div_id)
                else:
                    js = '''
                                    <script>
                                    var gd = document.getElementById('{div_id}');
                                    var isUnderRelayout = false
        
                                    gd.on('plotly_relayout', () => {{
                                      console.log('relayout', isUnderRelayout)
                                      if (!isUnderRelayout) {{
                                            Plotly.relayout(gd, 'scene2.camera', gd.layout.scene.camera)
                                              .then(() => {{ isUnderRelayout = false }}  )
                                            Plotly.relayout(gd, 'scene3.camera', gd.layout.scene.camera)
                                              .then(() => {{ isUnderRelayout = false }}  )
                                          }}
        
                                      isUnderRelayout = true;
                                    }})
                                    </script>'''.format(div_id=div_id)
                # merge everything
                div = '<script src="https://cdn.plot.ly/plotly-latest.min.js"></script>' + div + js
                print(ds.shapenames[data[-1]])
                with open(
                        os.path.join(
                            path, "compare_{0}.html".format(
                                ds.shapenames[data[-1]])), "w") as text_file:
                    text_file.write(div)

    if compute_dist_to_gt:
        with open(
                os.path.join(path, "chamfer.csv"),
                "w",
        ) as f:
            if (with_opt):
                f.write(
                    "shape, chamfer_dist, chamfer scan dist, after opt chamfer dist, after opt chamfer scan dist\n"
                )
                for result in chamfer_results:
                    f.write("{}, {}, {}, {}, {}\n".format(
                        result[0], result[1], result[2], result[3],
                        result[4]))
            else:
                f.write("shape, chamfer_dist, chamfer scan dist\n")
                for result in chamfer_results:
                    f.write("{}, {}, {}\n".format(result[0], result[1],
                                                  result[2]))
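
Example #7 measures reconstruction quality with utils.compute_trimesh_chamfer(gt_points, gen_mesh, offset, scale, one_side=...). The rough sketch below shows one way such a metric can be computed with a KD-tree; the sampling count and the offset/scale convention applied to the generated samples are assumptions, not the repository's implementation.

import numpy as np
import trimesh
from scipy.spatial import cKDTree

def compute_trimesh_chamfer(gt_points, gen_mesh, offset, scale,
                            one_side=False, num_samples=30000):
    # Hypothetical sketch matching the call sites in Example #7.
    # Sample the generated mesh and map the samples into the ground-truth
    # frame (the exact offset/scale convention is an assumption).
    gen_samples = trimesh.sample.sample_surface(gen_mesh, num_samples)[0]
    gen_samples = (gen_samples + offset) * scale

    gt = gt_points.vertices

    # ground truth -> generated term
    d_gt_to_gen, _ = cKDTree(gen_samples).query(gt)
    chamfer = np.mean(np.square(d_gt_to_gen))

    if not one_side:
        # generated -> ground truth term
        d_gen_to_gt, _ = cKDTree(gt).query(gen_samples)
        chamfer = chamfer + np.mean(np.square(d_gen_to_gt))

    return chamfer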