Example #1
def evaluate(network, experiment_directory, conf, checkpoint, split_file, epoch, resolution, uniform_grid):

    my_path = os.path.join(experiment_directory, 'evaluation', str(checkpoint))

    utils.mkdir_ifnotexists(os.path.join(experiment_directory, 'evaluation'))
    utils.mkdir_ifnotexists(my_path)

    with open(split_file, "r") as f:
        split = json.load(f)

    ds = utils.get_class(conf.get_string('train.dataset'))(split=split, dataset_path=conf.get_string('train.dataset_path'), with_normals=True)

    total_files = len(ds)
    print("total files : {0}".format(total_files))
    counter = 0
    dataloader = torch.utils.data.DataLoader(ds, batch_size=1, shuffle=True, num_workers=1, drop_last=False, pin_memory=True)

    for (input_pc, normals, index) in dataloader:

        input_pc = input_pc.cuda().squeeze()
        normals = normals.cuda().squeeze()

        print(counter)
        counter = counter + 1

        network.train()

        latent = optimize_latent(input_pc, normals, conf, 800, network, lr=5e-3)

        all_latent = latent.repeat(input_pc.shape[0], 1)

        points = torch.cat([all_latent, input_pc], dim=-1)

        shapename = str.join('_', ds.get_info(index))

        with torch.no_grad():

            network.eval()

            plt.plot_surface(with_points=True,
                             points=points,
                             decoder=network,
                             latent=latent,
                             path=my_path,
                             epoch=epoch,
                             shapename=shapename,
                             resolution=resolution,
                             mc_value=0,
                             is_uniform_grid=uniform_grid,
                             verbose=True,
                             save_html=True,
                             save_ply=True,
                             overwrite=True,
                             connected=True)
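optimize_latent is not included in this listing. A minimal sketch of what it plausibly does, given how it is called here: fit a single latent code to one point cloud by driving the predicted SDF at the surface points toward zero, with the decoder consuming a concatenated [latent, xyz] input. The config key 'train.latent_size' is an assumption, and the normals argument is accepted but unused in this sketch (a normals term could be added along the lines of the training loss in Example #4):

import torch

def optimize_latent(points, normals, conf, num_iterations, network, lr=5e-3):
    # Hypothetical sketch, not the original implementation.
    latent = torch.zeros(conf.get_int('train.latent_size')).cuda()
    latent.requires_grad_()
    optimizer = torch.optim.Adam([latent], lr=lr)

    for _ in range(num_iterations):
        optimizer.zero_grad()
        all_latent = latent.unsqueeze(0).repeat(points.shape[0], 1)
        pred = network(torch.cat([all_latent, points], dim=-1))
        loss = pred.abs().mean()  # surface points should evaluate to SDF ~ 0
        loss.backward()
        optimizer.step()

    return latent.detach()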
Example #2
def interpolate(network, conf, interval, experiment_directory, checkpoint,
                split_file, epoch, resolution, uniform_grid):

    with open(split_file, "r") as f:
        split = json.load(f)

    ds = utils.get_class(conf.get_string('train.dataset'))(
        split=split,
        dataset_path=conf.get_string('train.dataset_path'),
        with_normals=True)

    points_1, normals_1, index_1 = ds[0]
    points_2, normals_2, index_2 = ds[1]

    pnts = torch.cat([points_1, points_2], dim=0).cuda()

    name_1 = str.join('_', ds.get_info(0))
    name_2 = str.join('_', ds.get_info(1))

    name = name_1 + '_and_' + name_2

    utils.mkdir_ifnotexists(os.path.join(experiment_directory, 'interpolate'))
    utils.mkdir_ifnotexists(
        os.path.join(experiment_directory, 'interpolate', str(checkpoint)))
    utils.mkdir_ifnotexists(
        os.path.join(experiment_directory, 'interpolate', str(checkpoint),
                     name))

    my_path = os.path.join(experiment_directory, 'interpolate',
                           str(checkpoint), name)

    latent_1 = optimize_latent(points_1.cuda(), normals_1.cuda(), conf, 800,
                               network, 5e-3)
    latent_2 = optimize_latent(points_2.cuda(), normals_2.cuda(), conf, 800,
                               network, 5e-3)

    pnts = torch.cat([latent_1.repeat(pnts.shape[0], 1), pnts], dim=-1)

    with torch.no_grad():
        network.eval()

        for alpha in np.linspace(0, 1, interval):

            latent = (latent_1 * (1 - alpha)) + (latent_2 * alpha)

            plt.plot_surface(with_points=False,
                             points=pnts,
                             decoder=network,
                             latent=latent,
                             path=my_path,
                             epoch=epoch,
                             shapename=str(alpha),
                             resolution=resolution,
                             mc_value=0,
                             is_uniform_grid=uniform_grid,
                             verbose=True,
                             save_html=False,
                             save_ply=True,
                             overwrite=True,
                             connected=True)
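utils.mkdir_ifnotexists appears throughout these examples and is always called one path level at a time, which suggests a thin non-recursive wrapper. A sketch consistent with that usage:

import os

def mkdir_ifnotexists(directory):
    # Create a single directory level; do nothing if it already exists.
    if not os.path.exists(directory):
        os.mkdir(directory)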
Example #3
def evaluate(**kwargs):
    torch.set_default_dtype(torch.float32)

    conf = ConfigFactory.parse_file(kwargs['conf'])
    exps_folder_name = kwargs['exps_folder_name']
    evals_folder_name = kwargs['evals_folder_name']
    eval_rendering = kwargs['eval_rendering']
    eval_animation = kwargs['eval_animation']

    expname = conf.get_string('train.expname') + kwargs['expname']
    scan_id = kwargs['scan_id'] if kwargs['scan_id'] != -1 else conf.get_int(
        'dataset.scan_id', default=-1)
    if scan_id != -1:
        expname = expname + '_{0}'.format(scan_id)

    if kwargs['timestamp'] == 'latest':
        if os.path.exists(
                os.path.join('../', kwargs['exps_folder_name'], expname)):
            timestamps = os.listdir(
                os.path.join('../', kwargs['exps_folder_name'], expname))
            if (len(timestamps)) == 0:
                print('WRONG EXP FOLDER')
                exit()
            else:
                timestamp = sorted(timestamps)[-1]
        else:
            print('WRONG EXP FOLDER')
            exit()
    else:
        timestamp = kwargs['timestamp']

    utils.mkdir_ifnotexists(os.path.join('../', evals_folder_name))
    expdir = os.path.join('../', exps_folder_name, expname)
    evaldir = os.path.join('../', evals_folder_name, expname)
    utils.mkdir_ifnotexists(evaldir)

    dataset_conf = conf.get_config('dataset')
    model = utils.get_class(conf.get_string('train.model_class'))(
        conf=conf.get_config('model'), id=scan_id, datadir=dataset_conf['data_dir'])
    if torch.cuda.is_available():
        model.cuda()

    if kwargs['scan_id'] != -1:
        dataset_conf['scan_id'] = kwargs['scan_id']
    eval_dataset = utils.get_class(conf.get_string('train.dataset_class'))(
        False, **dataset_conf)

    if eval_rendering:
        eval_dataloader = torch.utils.data.DataLoader(
            eval_dataset,
            batch_size=1,
            shuffle=False,
            collate_fn=eval_dataset.collate_fn)
        total_pixels = eval_dataset.total_pixels
        img_res = eval_dataset.img_res

    old_checkpnts_dir = os.path.join(expdir, timestamp, 'checkpoints')

    saved_model_state = torch.load(
        os.path.join(old_checkpnts_dir, 'ModelParameters',
                     str(kwargs['checkpoint']) + ".pth"))
    model.load_state_dict(saved_model_state["model_state_dict"])
    epoch = saved_model_state['epoch']

    ####################################################################################################################
    print("evaluating...")

    model.eval()

    detail_3dmm, detail_3dmm_subdivision_full = plt.get_displacement_mesh(
        model)
    detail_3dmm.export('{0}/Detailed_3dmm_{1}.obj'.format(evaldir, epoch),
                       'obj')
    detail_3dmm_subdivision_full.export(
        '{0}/Subdivide_full_{1}.obj'.format(evaldir, epoch), 'obj')

    if eval_animation:
        sdf_np0, sdf_np1 = plt.get_displacement_animation(model)
        np.save('{0}/Cropped_Detailed_sdf_{1}.npy'.format(evaldir, epoch),
                sdf_np0)
        np.save('{0}/Cropped_Subdivide_full_{1}.npy'.format(evaldir, epoch),
                sdf_np1)

    if eval_rendering:
        images_dir = '{0}/rendering'.format(evaldir)
        utils.mkdir_ifnotexists(images_dir)

        psnrs = []
        for data_index, (indices, model_input,
                         ground_truth) in enumerate(eval_dataloader):
            model_input["intrinsics"] = model_input["intrinsics"].cuda()
            model_input["uv"] = model_input["uv"].cuda()
            model_input["object_mask"] = model_input["object_mask"].cuda()
            model_input['pose'] = model_input['pose'].cuda()

            split = utils.split_input(model_input, total_pixels)
            res = []
            for s in split:
                out = model(s)
                res.append({
                    'rgb_values': out['rgb_values'].detach(),
                    'diffuse_values': out['diffuse_values'].detach(),
                    'specular_values': out['specular_values'].detach(),
                    'albedo_values': out['albedo_values'].detach(),
                })

            batch_size = ground_truth['rgb'].shape[0]
            model_outputs = utils.merge_output(res, total_pixels, batch_size)
            rgb_eval = model_outputs['rgb_values']
            rgb_eval = rgb_eval.reshape(batch_size, total_pixels, 3)
            rgb_eval = (rgb_eval + 1.) / 2.
            rgb_eval = plt.lin2img(rgb_eval, img_res).detach().cpu().numpy()[0]
            rgb_eval = rgb_eval.transpose(1, 2, 0)
            img = Image.fromarray((rgb_eval * 255).astype(np.uint8))
            img.save('{0}/eval_{1}.png'.format(images_dir,
                                               '%03d' % indices[0]))

            diffuse_eval = model_outputs['diffuse_values']
            diffuse_eval = diffuse_eval.reshape(batch_size, total_pixels, 3)
            diffuse_eval = (diffuse_eval + 1.) / 2.
            diffuse_eval = plt.lin2img(diffuse_eval,
                                       img_res).detach().cpu().numpy()[0]
            diffuse_eval = diffuse_eval.transpose(1, 2, 0)
            img = Image.fromarray((diffuse_eval * 255).astype(np.uint8))
            img.save('{0}/eval_{1}_diffuse.png'.format(images_dir,
                                                       '%03d' % indices[0]))

            specular_eval = model_outputs['specular_values']
            specular_eval = specular_eval.reshape(batch_size, total_pixels, 3)
            specular_eval = (specular_eval + 1.) / 2.
            specular_eval = plt.lin2img(specular_eval,
                                        img_res).detach().cpu().numpy()[0]
            specular_eval = specular_eval.transpose(1, 2, 0)
            img = Image.fromarray((specular_eval * 255).astype(np.uint8))
            img.save('{0}/eval_{1}_specular.png'.format(
                images_dir, '%03d' % indices[0]))

            albedo_eval = model_outputs['albedo_values']
            albedo_eval = albedo_eval.reshape(batch_size, total_pixels, 3)
            albedo_eval = (albedo_eval + 1.) / 2.
            albedo_eval = plt.lin2img(albedo_eval,
                                      img_res).detach().cpu().numpy()[0]
            albedo_eval = albedo_eval.transpose(1, 2, 0)
            img = Image.fromarray((albedo_eval * 255).astype(np.uint8))
            img.save('{0}/eval_{1}_albedo.png'.format(images_dir,
                                                      '%03d' % indices[0]))

            rgb_gt = ground_truth['rgb']
            rgb_gt = (rgb_gt + 1.) / 2.
            rgb_gt = plt.lin2img(rgb_gt, img_res).numpy()[0]
            rgb_gt = rgb_gt.transpose(1, 2, 0)

            mask = model_input['object_mask']
            mask = plt.lin2img(mask.unsqueeze(-1), img_res).cpu().numpy()[0]
            mask = mask.transpose(1, 2, 0)

            rgb_eval_masked = rgb_eval * mask
            rgb_gt_masked = rgb_gt * mask

            psnr = calculate_psnr(rgb_eval_masked, rgb_gt_masked, mask)
            psnrs.append(psnr)

        psnrs = np.array(psnrs).astype(np.float64)
        print("RENDERING EVALUATION {2}: psnr mean = {0} ; psnr std = {1}".
              format("%.2f" % psnrs.mean(), "%.2f" % psnrs.std(), scan_id))
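calculate_psnr is assumed rather than shown. Since the caller pre-multiplies both images by the mask and the images are in [0, 1], a plausible masked PSNR averages the squared error over the masked pixels only (the broadcast of the single-channel mask over RGB is an assumption):

import numpy as np

def calculate_psnr(img_pred, img_gt, mask):
    # Hypothetical sketch: MSE over masked pixels, PSNR for a peak value of 1.
    mse = ((img_pred - img_gt) ** 2).sum() / (mask.sum() * 3.0)
    return -10.0 * np.log10(mse)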
Example #4
    def run(self):

        print("running")

        self.data = self.data.cuda()
        self.data.requires_grad_()

        if self.eval:

            print("evaluating epoch: {0}".format(self.startepoch))
            my_path = os.path.join(self.cur_exp_dir, 'evaluation', str(self.startepoch))

            utils.mkdir_ifnotexists(os.path.join(self.cur_exp_dir, 'evaluation'))
            utils.mkdir_ifnotexists(my_path)
            self.plot_shapes(epoch=self.startepoch, path=my_path, with_cuts=True)
            return

        print("training")

        for epoch in range(self.startepoch, self.nepochs + 1):

            indices = torch.tensor(np.random.choice(self.data.shape[0], self.points_batch, False))

            cur_data = self.data[indices]

            mnfld_pnts = cur_data[:, :self.d_in]
            mnfld_sigma = self.local_sigma[indices]

            if epoch % self.conf.get_int('train.checkpoint_frequency') == 0:
                print('saving checkpoint: ', epoch)
                self.save_checkpoints(epoch)
                print('plot validation epoch: ', epoch)
                self.plot_shapes(epoch)

            # change back to train mode
            self.network.train()
            self.adjust_learning_rate(epoch)

            nonmnfld_pnts = self.sampler.get_points(mnfld_pnts.unsqueeze(0), mnfld_sigma.unsqueeze(0)).squeeze()

            # forward pass

            mnfld_pred = self.network(mnfld_pnts)
            nonmnfld_pred = self.network(nonmnfld_pnts)

            # compute grad

            mnfld_grad = gradient(mnfld_pnts, mnfld_pred)
            nonmnfld_grad = gradient(nonmnfld_pnts, nonmnfld_pred)

            # manifold loss

            mnfld_loss = (mnfld_pred.abs()).mean()

            # eikonal loss

            grad_loss = ((nonmnfld_grad.norm(2, dim=-1) - 1) ** 2).mean()

            loss = mnfld_loss + self.grad_lambda * grad_loss

            # normals loss

            if self.with_normals:
                normals = cur_data[:, -self.d_in:]
                normals_loss = ((mnfld_grad - normals).abs()).norm(2, dim=1).mean()
                loss = loss + self.normals_lambda * normals_loss
            else:
                normals_loss = torch.zeros(1)

            # back propagation

            self.optimizer.zero_grad()

            loss.backward()

            self.optimizer.step()

            if epoch % self.conf.get_int('train.status_frequency') == 0:
                print('Train Epoch: [{}/{} ({:.0f}%)]\tTrain Loss: {:.6f}\tManifold loss: {:.6f}'
                    '\tGrad loss: {:.6f}\tNormals Loss: {:.6f}'.format(
                    epoch, self.nepochs, 100. * epoch / self.nepochs,
                    loss.item(), mnfld_loss.item(), grad_loss.item(), normals_loss.item()))
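gradient(inputs, outputs) is not defined in this listing. For the eikonal term to train the network it has to be an autograd gradient taken with create_graph=True; a minimal sketch:

import torch

def gradient(inputs, outputs):
    # d(outputs)/d(inputs); create_graph=True keeps the gradient itself
    # differentiable so the eikonal penalty can backpropagate into the network.
    d_points = torch.ones_like(outputs)
    return torch.autograd.grad(outputs=outputs,
                               inputs=inputs,
                               grad_outputs=d_points,
                               create_graph=True,
                               retain_graph=True)[0]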
Example #5
    def __init__(self, **kwargs):

        self.home_dir = os.path.abspath(os.pardir)

        # config setting

        if isinstance(kwargs['conf'], str):
            self.conf_filename = './reconstruction/' + kwargs['conf']
            self.conf = ConfigFactory.parse_file(self.conf_filename)
        else:
            self.conf = kwargs['conf']

        self.expname = kwargs['expname']

        # GPU settings

        self.GPU_INDEX = kwargs['gpu_index']

        if not self.GPU_INDEX == 'ignore':
            os.environ["CUDA_VISIBLE_DEVICES"] = '{0}'.format(self.GPU_INDEX)

        self.num_of_gpus = torch.cuda.device_count()

        self.eval = kwargs['eval']

        # settings for loading an existing experiment

        if (kwargs['is_continue'] or self.eval) and kwargs['timestamp'] == 'latest':
            if os.path.exists(os.path.join(self.home_dir, 'exps', self.expname)):
                timestamps = os.listdir(os.path.join(self.home_dir, 'exps', self.expname))
                if (len(timestamps)) == 0:
                    is_continue = False
                    timestamp = None
                else:
                    timestamp = sorted(timestamps)[-1]
                    is_continue = True
            else:
                is_continue = False
                timestamp = None
        else:
            timestamp = kwargs['timestamp']
            is_continue = kwargs['is_continue'] or self.eval

        self.exps_folder_name = 'exps'

        utils.mkdir_ifnotexists(utils.concat_home_dir(os.path.join(self.home_dir, self.exps_folder_name)))

        self.input_file = self.conf.get_string('train.input_path')
        self.data = utils.load_point_cloud_by_file_extension(self.input_file)

        sigma_set = []
        ptree = cKDTree(self.data)

        # per-point sigma: distance to the 50th nearest neighbor, computed in chunks
        for p in np.array_split(self.data, 100, axis=0):
            d = ptree.query(p, 50 + 1)
            sigma_set.append(d[0][:, -1])

        sigmas = np.concatenate(sigma_set)
        self.local_sigma = torch.from_numpy(sigmas).float().cuda()

        self.expdir = utils.concat_home_dir(os.path.join(self.home_dir, self.exps_folder_name, self.expname))
        utils.mkdir_ifnotexists(self.expdir)

        if is_continue:
            self.timestamp = timestamp
        else:
            self.timestamp = '{:%Y_%m_%d_%H_%M_%S}'.format(datetime.now())

        self.cur_exp_dir = os.path.join(self.expdir, self.timestamp)
        utils.mkdir_ifnotexists(self.cur_exp_dir)

        self.plots_dir = os.path.join(self.cur_exp_dir, 'plots')
        utils.mkdir_ifnotexists(self.plots_dir)

        self.checkpoints_path = os.path.join(self.cur_exp_dir, 'checkpoints')
        utils.mkdir_ifnotexists(self.checkpoints_path)


        self.model_params_subdir = "ModelParameters"
        self.optimizer_params_subdir = "OptimizerParameters"

        utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.model_params_subdir))
        utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.optimizer_params_subdir))

        self.nepochs = kwargs['nepochs']

        self.points_batch = kwargs['points_batch']

        self.global_sigma = self.conf.get_float('network.sampler.properties.global_sigma')
        self.sampler = Sampler.get_sampler(self.conf.get_string('network.sampler.sampler_type'))(self.global_sigma,
                                                                                                 self.local_sigma)
        self.grad_lambda = self.conf.get_float('network.loss.lambda')
        self.normals_lambda = self.conf.get_float('network.loss.normals_lambda')

        self.with_normals = self.normals_lambda > 0

        self.d_in = self.conf.get_int('train.d_in')

        self.network = utils.get_class(self.conf.get_string('train.network_class'))(d_in=self.d_in,
                                                                                    **self.conf.get_config(
                                                                                        'network.inputs'))

        if torch.cuda.is_available():
            self.network.cuda()

        self.lr_schedules = self.get_learning_rate_schedules(self.conf.get_list('train.learning_rate_schedule'))
        self.weight_decay = self.conf.get_float('train.weight_decay')

        self.startepoch = 0

        self.optimizer = torch.optim.Adam(
            [
                {
                    "params": self.network.parameters(),
                    "lr": self.lr_schedules[0].get_learning_rate(0),
                    "weight_decay": self.weight_decay
                },
            ])

        # if continue load checkpoints

        if is_continue:
            old_checkpnts_dir = os.path.join(self.expdir, timestamp, 'checkpoints')

            saved_model_state = torch.load(
                os.path.join(old_checkpnts_dir, 'ModelParameters', str(kwargs['checkpoint']) + ".pth"))
            self.network.load_state_dict(saved_model_state["model_state_dict"])

            data = torch.load(
                os.path.join(old_checkpnts_dir, 'OptimizerParameters', str(kwargs['checkpoint']) + ".pth"))
            self.optimizer.load_state_dict(data["optimizer_state_dict"])
            self.startepoch = saved_model_state['epoch']
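Sampler.get_sampler presumably maps the 'network.sampler.sampler_type' string to a class exposing get_points(points, local_sigma). A sketch of one plausible implementation, where off-surface samples are drawn from per-point Gaussians around the cloud plus a uniform global component; the class name and the 1/8 global fraction are assumptions:

import torch

class NormalPerPointSampler:
    def __init__(self, global_sigma, local_sigma):
        self.global_sigma = global_sigma
        self.local_sigma = local_sigma  # scalar fallback; a per-point tensor may be passed instead

    def get_points(self, pc_input, local_sigma=None):
        # pc_input: (batch, n_points, dim); local_sigma: (batch, n_points) or None
        batch_size, sample_size, dim = pc_input.shape
        sigma = local_sigma.unsqueeze(-1) if local_sigma is not None else self.local_sigma
        sample_local = pc_input + torch.randn_like(pc_input) * sigma
        sample_global = (torch.rand(batch_size, sample_size // 8, dim,
                                    device=pc_input.device) * 2 - 1) * self.global_sigma
        return torch.cat([sample_local, sample_global], dim=1)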
Example #6
def evaluate(**kwargs):
    torch.set_default_dtype(torch.float32)

    conf = ConfigFactory.parse_file(kwargs['conf'])
    exps_folder_name = kwargs['exps_folder_name']
    evals_folder_name = kwargs['evals_folder_name']
    timestamp = '2020'
    checkpoint = '2000'

    expname = conf.get_string('train.expname')

    geometry_id = kwargs['geometry_id']
    appearance_id = kwargs['appearance_id']

    utils.mkdir_ifnotexists(os.path.join('../', evals_folder_name))
    expdir_geometry = os.path.join('../', exps_folder_name, expname + '_{0}'.format(geometry_id))
    expdir_appearance = os.path.join('../', exps_folder_name, expname + '_{0}'.format(appearance_id))
    evaldir = os.path.join('../', evals_folder_name, expname + '_{0}_{1}'.format(geometry_id, appearance_id))
    utils.mkdir_ifnotexists(evaldir)

    model = utils.get_class(conf.get_string('train.model_class'))(conf=conf.get_config('model'))
    if torch.cuda.is_available():
        model.cuda()

    # Load geometry network model
    old_checkpnts_dir = os.path.join(expdir_geometry, timestamp, 'checkpoints')
    saved_model_state = torch.load(os.path.join(old_checkpnts_dir, 'ModelParameters', checkpoint + ".pth"))
    model.load_state_dict(saved_model_state["model_state_dict"])

    # Load rendering network model
    model_fake = utils.get_class(conf.get_string('train.model_class'))(conf=conf.get_config('model'))
    if torch.cuda.is_available():
        model_fake.cuda()
    old_checkpnts_dir = os.path.join(expdir_appearance, timestamp, 'checkpoints')
    saved_model_state = torch.load(os.path.join(old_checkpnts_dir, 'ModelParameters', checkpoint + ".pth"))
    model_fake.load_state_dict(saved_model_state["model_state_dict"])

    model.rendering_network = model_fake.rendering_network

    dataset_conf = conf.get_config('dataset')
    dataset_conf['scan_id'] = geometry_id
    eval_dataset = utils.get_class(conf.get_string('train.dataset_class'))(False, **dataset_conf)

    eval_dataloader = torch.utils.data.DataLoader(eval_dataset,
                                                  batch_size=1,
                                                  shuffle=True,
                                                  collate_fn=eval_dataset.collate_fn
                                                  )
    total_pixels = eval_dataset.total_pixels
    img_res = eval_dataset.img_res

    ####################################################################################################################
    print("evaluating...")

    model.eval()

    gt_pose = eval_dataset.get_gt_pose(scaled=True).cuda()
    gt_quat = rend_util.rot_to_quat(gt_pose[:, :3, :3])
    gt_pose_vec = torch.cat([gt_quat, gt_pose[:, :3, 3]], 1)

    indices_all = [11, 16, 34, 28, 11]
    pose = gt_pose_vec[indices_all, :]
    t_in = np.array([0, 2, 3, 5, 6]).astype(np.float32)

    n_inter = 5
    t_out = np.linspace(t_in[0], t_in[-1], int(n_inter * t_in[-1])).astype(np.float32)

    scales = np.array([4.2, 4.2, 3.8, 3.8, 4.2]).astype(np.float32)

    s_new = CubicSpline(t_in, scales, bc_type='periodic')
    s_new = s_new(t_out)

    q_new = CubicSpline(t_in, pose[:, :4].detach().cpu().numpy(), bc_type='periodic')
    q_new = q_new(t_out)
    q_new = q_new / np.linalg.norm(q_new, 2, 1)[:, None]
    q_new = torch.from_numpy(q_new).cuda().float()

    images_dir = '{0}/novel_views_rendering'.format(evaldir)
    utils.mkdir_ifnotexists(images_dir)

    indices, model_input, ground_truth = next(iter(eval_dataloader))

    for i, (new_q, scale) in enumerate(zip(q_new, s_new)):
        torch.cuda.empty_cache()

        new_q = new_q.unsqueeze(0)
        new_t = -rend_util.quat_to_rot(new_q)[:, :, 2] * scale

        new_p = torch.eye(4).float().cuda().unsqueeze(0)
        new_p[:, :3, :3] = rend_util.quat_to_rot(new_q)
        new_p[:, :3, 3] = new_t

        sample = {
            "object_mask": torch.zeros_like(model_input['object_mask']).cuda().bool(),
            "uv": model_input['uv'].cuda(),
            "intrinsics": model_input['intrinsics'].cuda(),
            "pose": new_p
        }

        split = utils.split_input(sample, total_pixels)
        res = []
        for s in split:
            out = model(s)
            res.append({
                'rgb_values': out['rgb_values'].detach(),
            })

        batch_size = 1
        model_outputs = utils.merge_output(res, total_pixels, batch_size)
        rgb_eval = model_outputs['rgb_values']
        rgb_eval = rgb_eval.reshape(batch_size, total_pixels, 3)

        rgb_eval = (rgb_eval + 1.) / 2.
        rgb_eval = plt.lin2img(rgb_eval, img_res).detach().cpu().numpy()[0]
        rgb_eval = rgb_eval.transpose(1, 2, 0)
        img = Image.fromarray((rgb_eval * 255).astype(np.uint8))
        img.save('{0}/eval_{1}.png'.format(images_dir, '%03d' % i))
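rend_util.quat_to_rot (and its inverse rot_to_quat above) is used but not shown. A standard batched conversion, assuming the (w, x, y, z) component order:

import torch

def quat_to_rot(q):
    # (B, 4) unit quaternions (w, x, y, z) -> (B, 3, 3) rotation matrices.
    q = q / q.norm(dim=-1, keepdim=True)
    w, x, y, z = q[:, 0], q[:, 1], q[:, 2], q[:, 3]
    return torch.stack([
        1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y),
        2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x),
        2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y),
    ], dim=-1).reshape(-1, 3, 3)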
Example #7
    def __init__(self, **kwargs):
        torch.set_default_dtype(torch.float32)
        torch.set_num_threads(1)

        self.conf = ConfigFactory.parse_file(kwargs['conf'])
        self.batch_size = kwargs['batch_size']
        self.nepochs = kwargs['nepochs']
        self.exps_folder_name = kwargs['exps_folder_name']
        self.GPU_INDEX = kwargs['gpu_index']
        self.train_cameras = kwargs['train_cameras']

        self.expname = self.conf.get_string(
            'train.expname') + kwargs['expname']
        scan_id = kwargs[
            'scan_id'] if kwargs['scan_id'] != -1 else self.conf.get_int(
                'dataset.scan_id', default=-1)
        if scan_id != -1:
            self.expname = self.expname + '_{0}'.format(scan_id)

        if kwargs['is_continue'] and kwargs['timestamp'] == 'latest':
            if os.path.exists(
                    os.path.join('../', kwargs['exps_folder_name'],
                                 self.expname)):
                timestamps = os.listdir(
                    os.path.join('../', kwargs['exps_folder_name'],
                                 self.expname))
                if (len(timestamps)) == 0:
                    is_continue = False
                    timestamp = None
                else:
                    timestamp = sorted(timestamps)[-1]
                    is_continue = True
            else:
                is_continue = False
                timestamp = None
        else:
            timestamp = kwargs['timestamp']
            is_continue = kwargs['is_continue']

        utils.mkdir_ifnotexists(os.path.join('../', self.exps_folder_name))
        self.expdir = os.path.join('../', self.exps_folder_name, self.expname)
        utils.mkdir_ifnotexists(self.expdir)
        self.timestamp = '{:%Y_%m_%d_%H_%M_%S}'.format(datetime.now())
        utils.mkdir_ifnotexists(os.path.join(self.expdir, self.timestamp))

        self.plots_dir = os.path.join(self.expdir, self.timestamp, 'plots')
        utils.mkdir_ifnotexists(self.plots_dir)

        # create checkpoints dirs
        self.checkpoints_path = os.path.join(self.expdir, self.timestamp,
                                             'checkpoints')
        utils.mkdir_ifnotexists(self.checkpoints_path)
        self.model_params_subdir = "ModelParameters"
        self.optimizer_params_subdir = "OptimizerParameters"
        self.scheduler_params_subdir = "SchedulerParameters"

        utils.mkdir_ifnotexists(
            os.path.join(self.checkpoints_path, self.model_params_subdir))
        utils.mkdir_ifnotexists(
            os.path.join(self.checkpoints_path, self.optimizer_params_subdir))
        utils.mkdir_ifnotexists(
            os.path.join(self.checkpoints_path, self.scheduler_params_subdir))

        if self.train_cameras:
            self.optimizer_cam_params_subdir = "OptimizerCamParameters"
            self.cam_params_subdir = "CamParameters"

            utils.mkdir_ifnotexists(
                os.path.join(self.checkpoints_path,
                             self.optimizer_cam_params_subdir))
            utils.mkdir_ifnotexists(
                os.path.join(self.checkpoints_path, self.cam_params_subdir))

        os.system("""cp -r {0} "{1}" """.format(
            kwargs['conf'],
            os.path.join(self.expdir, self.timestamp, 'runconf.conf')))

        if (not self.GPU_INDEX == 'ignore'):
            os.environ["CUDA_VISIBLE_DEVICES"] = '{0}'.format(self.GPU_INDEX)

        print('shell command : {0}'.format(' '.join(sys.argv)))

        print('Loading data ...')

        dataset_conf = self.conf.get_config('dataset')
        if kwargs['scan_id'] != -1:
            dataset_conf['scan_id'] = kwargs['scan_id']

        # use train_set / test_set subsets instead of training on all images
        try:
            dataset_conf['train_set'] = self.conf.get_list('train.train_set')
        except Exception:
            dataset_conf['train_set'] = None

        try:
            dataset_conf['test_set'] = self.conf.get_list('train.test_set')
        except Exception:
            dataset_conf['test_set'] = None

        self.train_dataset = utils.get_class(
            self.conf.get_string('train.dataset_class'))(self.train_cameras,
                                                         **dataset_conf)

        print('Finish loading data ...')

        self.train_dataloader = torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            collate_fn=self.train_dataset.collate_fn)
        self.plot_dataloader = torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.conf.get_int('plot.plot_nimgs'),
            shuffle=True,
            collate_fn=self.train_dataset.collate_fn)

        self.model = utils.get_class(
            self.conf.get_string('train.model_class'))(
                conf=self.conf.get_config('model'))
        if torch.cuda.is_available():
            self.model.cuda()

        self.loss = utils.get_class(self.conf.get_string('train.loss_class'))(
            **self.conf.get_config('loss'))

        self.lr = self.conf.get_float('train.learning_rate')
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        self.sched_milestones = self.conf.get_list('train.sched_milestones',
                                                   default=[])
        self.sched_factor = self.conf.get_float('train.sched_factor',
                                                default=0.0)
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer, self.sched_milestones, gamma=self.sched_factor)

        # settings for camera optimization
        if self.train_cameras:
            num_images = len(self.train_dataset)
            self.pose_vecs = torch.nn.Embedding(num_images, 7,
                                                sparse=True).cuda()
            self.pose_vecs.weight.data.copy_(
                self.train_dataset.get_pose_init())

            self.optimizer_cam = torch.optim.SparseAdam(
                self.pose_vecs.parameters(),
                self.conf.get_float('train.learning_rate_cam'))

        self.start_epoch = 0
        if is_continue:
            old_checkpnts_dir = os.path.join(self.expdir, timestamp,
                                             'checkpoints')

            saved_model_state = torch.load(
                os.path.join(old_checkpnts_dir, 'ModelParameters',
                             str(kwargs['checkpoint']) + ".pth"))
            self.model.load_state_dict(saved_model_state["model_state_dict"])
            self.start_epoch = saved_model_state['epoch']

            data = torch.load(
                os.path.join(old_checkpnts_dir, 'OptimizerParameters',
                             str(kwargs['checkpoint']) + ".pth"))
            self.optimizer.load_state_dict(data["optimizer_state_dict"])

            data = torch.load(
                os.path.join(old_checkpnts_dir, self.scheduler_params_subdir,
                             str(kwargs['checkpoint']) + ".pth"))
            self.scheduler.load_state_dict(data["scheduler_state_dict"])

            if self.train_cameras:
                data = torch.load(
                    os.path.join(old_checkpnts_dir,
                                 self.optimizer_cam_params_subdir,
                                 str(kwargs['checkpoint']) + ".pth"))
                self.optimizer_cam.load_state_dict(
                    data["optimizer_cam_state_dict"])

                data = torch.load(
                    os.path.join(old_checkpnts_dir, self.cam_params_subdir,
                                 str(kwargs['checkpoint']) + ".pth"))
                self.pose_vecs.load_state_dict(data["pose_vecs_state_dict"])

        self.num_pixels = self.conf.get_int('train.num_pixels')
        self.total_pixels = self.train_dataset.total_pixels
        self.img_res = self.train_dataset.img_res
        self.n_batches = len(self.train_dataloader)
        self.plot_freq = self.conf.get_int('train.plot_freq')
        self.plot_conf = self.conf.get_config('plot')

        self.alpha_milestones = self.conf.get_list('train.alpha_milestones',
                                                   default=[])
        self.alpha_factor = self.conf.get_float('train.alpha_factor',
                                                default=0.0)
        for acc in self.alpha_milestones:
            if self.start_epoch > acc:
                self.loss.alpha = self.loss.alpha * self.alpha_factor
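Every hyperparameter above is read from a HOCON file via pyhocon. A skeletal config showing the keys this constructor touches; all values and class paths below are illustrative placeholders, not the settings of any actual experiment:

from pyhocon import ConfigFactory

conf = ConfigFactory.parse_string("""
train {
    expname = my_experiment
    dataset_class = datasets.scene_dataset.SceneDataset  # placeholder
    model_class = model.network.MyModel                  # placeholder
    loss_class = model.loss.MyLoss                       # placeholder
    learning_rate = 1.0e-4
    learning_rate_cam = 1.0e-4
    num_pixels = 2048
    plot_freq = 100
    alpha_milestones = [250, 500, 750]
    alpha_factor = 2.0
    sched_milestones = [1000, 1500]
    sched_factor = 0.5
}
plot { plot_nimgs = 1 }
loss { }
model { }
dataset { scan_id = 65 }
""")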
Example #8
        scale = 1

        # os.chdir('/home/atzmonm/data/')
        for ds, cat_det in train_split['scans'].items():
            if names and ds not in names:
                continue
            print("ds :{0} , a:{1}".format(ds, countera))
            countera = countera + 1
            counterb = 0

            for cat, shapes in cat_det.items():
                print("cat {0} : b{1}".format(cat, counterb))
                counterb = counterb + 1
                source = os.path.abspath(os.path.join(args.src_path, 'scans', ds, cat))
                output = os.path.abspath(os.path.join(args.out_path, 'dfaust_processed'))
                utils.mkdir_ifnotexists(output)
                utils.mkdir_ifnotexists(os.path.join(output, ds))
                utils.mkdir_ifnotexists(os.path.join(output, ds, cat))
                counterc = 0

                for item, shape in enumerate(shapes):
                    print("item {0} : c{1}".format(item, counterc))
                    counterc = counterc + 1
                    output_file = os.path.join(output, ds, cat, shape)
                    print(output_file)
                    if not (args.skip and os.path.isfile(output_file + '.npy')):
                        print('loading : {0}'.format(os.path.join(source, shape)))
                        mesh = trimesh.load(os.path.join(source, shape) + '.ply')
                        sample = sample_surface(mesh, SAMPLES)
                        pnts = sample[0]
                        normals = mesh.face_normals[sample[1]]
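The loop is cut off after the sampling step. sample_surface here is presumably trimesh.sample.sample_surface, which returns (points, face_indices); the face indices are what make the face_normals lookup work. A hypothetical continuation, assuming the preprocessing simply stores points and normals side by side in the .npy file that the skip check above tests for:

import numpy as np

# Assumed continuation, not part of the original listing.
point_set = np.concatenate([pnts, normals], axis=-1)  # (SAMPLES, 6)
np.save(output_file + '.npy', point_set)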
Example #9
    def __init__(self, **kwargs):

        # config setting

        self.home_dir = os.path.abspath(os.pardir)

        if isinstance(kwargs['conf'], str):
            self.conf_filename = os.path.abspath(kwargs['conf'])
            self.conf = ConfigFactory.parse_file(self.conf_filename)
        else:
            self.conf = kwargs['conf']

        self.expname = kwargs['expname']

        # GPU settings

        self.GPU_INDEX = kwargs['gpu_index']

        if not self.GPU_INDEX == 'ignore':
            os.environ["CUDA_VISIBLE_DEVICES"] = '{0}'.format(self.GPU_INDEX)

        self.num_of_gpus = torch.cuda.device_count()

        # settings for loading an existing experiment

        if kwargs['is_continue'] and kwargs['timestamp'] == 'latest':
            if os.path.exists(os.path.join(self.home_dir, 'exps',
                                           self.expname)):
                timestamps = os.listdir(
                    os.path.join(self.home_dir, 'exps', self.expname))
                if (len(timestamps)) == 0:
                    is_continue = False
                    timestamp = None
                else:
                    timestamp = sorted(timestamps)[-1]
                    is_continue = True
            else:
                is_continue = False
                timestamp = None
        else:
            timestamp = kwargs['timestamp']
            is_continue = kwargs['is_continue']

        self.exps_folder_name = 'exps'

        utils.mkdir_ifnotexists(
            utils.concat_home_dir(
                os.path.join(self.home_dir, self.exps_folder_name)))

        self.expdir = utils.concat_home_dir(
            os.path.join(self.home_dir, self.exps_folder_name, self.expname))
        utils.mkdir_ifnotexists(self.expdir)

        if is_continue:
            self.timestamp = timestamp
        else:
            self.timestamp = '{:%Y_%m_%d_%H_%M_%S}'.format(datetime.now())

        self.cur_exp_dir = self.timestamp
        utils.mkdir_ifnotexists(os.path.join(self.expdir, self.cur_exp_dir))

        self.plots_dir = os.path.join(self.expdir, self.cur_exp_dir, 'plots')
        utils.mkdir_ifnotexists(self.plots_dir)

        self.checkpoints_path = os.path.join(self.expdir, self.cur_exp_dir,
                                             'checkpoints')
        utils.mkdir_ifnotexists(self.checkpoints_path)


        self.model_params_subdir = "ModelParameters"
        self.optimizer_params_subdir = "OptimizerParameters"
        self.latent_codes_subdir = "LatentCodes"

        utils.mkdir_ifnotexists(
            os.path.join(self.checkpoints_path, self.model_params_subdir))
        utils.mkdir_ifnotexists(
            os.path.join(self.checkpoints_path, self.optimizer_params_subdir))
        utils.mkdir_ifnotexists(
            os.path.join(self.checkpoints_path, self.latent_codes_subdir))

        self.nepochs = kwargs['nepochs']

        self.batch_size = kwargs['batch_size']

        if self.num_of_gpus > 0:
            self.batch_size *= self.num_of_gpus

        self.parallel = self.num_of_gpus > 1

        self.global_sigma = self.conf.get_float(
            'network.sampler.properties.global_sigma')
        self.local_sigma = self.conf.get_float(
            'network.sampler.properties.local_sigma')
        self.sampler = Sampler.get_sampler(
            self.conf.get_string('network.sampler.sampler_type'))(
                self.global_sigma, self.local_sigma)

        train_split_file = os.path.abspath(kwargs['split_file'])
        print(f'Loading split file {train_split_file}')
        with open(train_split_file, "r") as f:
            train_split = json.load(f)
        print(f'Size of the split: {len(train_split)} samples')

        self.d_in = self.conf.get_int('train.d_in')

        # latent preprocessing

        self.latent_size = self.conf.get_int('train.latent_size')

        self.latent_lambda = self.conf.get_float('network.loss.latent_lambda')
        self.grad_lambda = self.conf.get_float('network.loss.lambda')
        self.normals_lambda = self.conf.get_float(
            'network.loss.normals_lambda')

        self.with_normals = self.normals_lambda > 0

        self.ds = utils.get_class(self.conf.get_string('train.dataset'))(
            split=train_split,
            with_normals=self.with_normals,
            dataset_path=self.conf.get_string('train.dataset_path'),
            points_batch=kwargs['points_batch'],
        )

        self.num_scenes = len(self.ds)

        self.train_dataloader = torch.utils.data.DataLoader(
            self.ds,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=kwargs['threads'],
            drop_last=True,
            pin_memory=True)
        self.eval_dataloader = torch.utils.data.DataLoader(self.ds,
                                                           batch_size=1,
                                                           shuffle=False,
                                                           num_workers=0,
                                                           drop_last=True)

        self.network = utils.get_class(
            self.conf.get_string('train.network_class'))(
                d_in=(self.d_in + self.latent_size),
                **self.conf.get_config('network.inputs'))

        if self.parallel:
            self.network = torch.nn.DataParallel(self.network)

        if torch.cuda.is_available():
            self.network.cuda()

        self.lr_schedules = self.get_learning_rate_schedules(
            self.conf.get_list('train.learning_rate_schedule'))
        self.weight_decay = self.conf.get_float('train.weight_decay')

        # optimizer and latent settings

        self.startepoch = 0

        self.lat_vecs = utils.to_cuda(
            torch.zeros(self.num_scenes, self.latent_size))
        self.lat_vecs.requires_grad_()

        self.optimizer = torch.optim.Adam([
            {
                "params": self.network.parameters(),
                "lr": self.lr_schedules[0].get_learning_rate(0),
                "weight_decay": self.weight_decay
            },
            {
                "params": self.lat_vecs,
                "lr": self.lr_schedules[1].get_learning_rate(0)
            },
        ])

        # if continue load checkpoints

        if is_continue:
            old_checkpnts_dir = os.path.join(self.expdir, timestamp,
                                             'checkpoints')

            data = torch.load(
                os.path.join(old_checkpnts_dir, self.latent_codes_subdir,
                             str(kwargs['checkpoint']) + '.pth'))

            self.lat_vecs = utils.to_cuda(data["latent_codes"])

            saved_model_state = torch.load(
                os.path.join(old_checkpnts_dir, 'ModelParameters',
                             str(kwargs['checkpoint']) + ".pth"))
            self.network.load_state_dict(saved_model_state["model_state_dict"])

            data = torch.load(
                os.path.join(old_checkpnts_dir, 'OptimizerParameters',
                             str(kwargs['checkpoint']) + ".pth"))
            self.optimizer.load_state_dict(data["optimizer_state_dict"])
            self.startepoch = saved_model_state['epoch']
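get_learning_rate_schedules is not shown; the only contract the trainer relies on is that each schedule (one for the network, one for the latent codes) exposes get_learning_rate(epoch). A sketch assuming DeepSDF-style step decay:

class StepLearningRateSchedule:
    def __init__(self, initial, interval, factor):
        self.initial = initial
        self.interval = interval
        self.factor = factor

    def get_learning_rate(self, epoch):
        # Multiply the base rate by `factor` once every `interval` epochs.
        return self.initial * (self.factor ** (epoch // self.interval))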
Example #10
def evaluate(**kwargs):
    torch.set_default_dtype(torch.float32)

    conf = ConfigFactory.parse_file(kwargs['conf'])
    exps_folder_name = kwargs['exps_folder_name']
    evals_folder_name = kwargs['evals_folder_name']
    eval_cameras = kwargs['eval_cameras']
    eval_rendering = kwargs['eval_rendering']

    expname = conf.get_string('train.expname') + kwargs['expname']
    scan_id = kwargs['scan_id'] if kwargs['scan_id'] != -1 else conf.get_int(
        'dataset.scan_id', default=-1)
    if scan_id != -1:
        expname = expname + '_{0}'.format(scan_id)

    if kwargs['timestamp'] == 'latest':
        if os.path.exists(
                os.path.join('../', kwargs['exps_folder_name'], expname)):
            timestamps = os.listdir(
                os.path.join('../', kwargs['exps_folder_name'], expname))
            if (len(timestamps)) == 0:
                print('WRONG EXP FOLDER')
                exit()
            else:
                timestamp = sorted(timestamps)[-1]
        else:
            print('WRONG EXP FOLDER')
            exit()
    else:
        timestamp = kwargs['timestamp']

    utils.mkdir_ifnotexists(os.path.join('../', evals_folder_name))
    expdir = os.path.join('../', exps_folder_name, expname)
    evaldir = os.path.join('../', evals_folder_name, expname)
    utils.mkdir_ifnotexists(evaldir)

    model = utils.get_class(
        conf.get_string('train.model_class'))(conf=conf.get_config('model'))
    if torch.cuda.is_available():
        model.cuda()

    dataset_conf = conf.get_config('dataset')
    if kwargs['scan_id'] != -1:
        dataset_conf['scan_id'] = kwargs['scan_id']
    eval_dataset = utils.get_class(conf.get_string('train.dataset_class'))(
        eval_cameras, **dataset_conf)

    # settings for camera optimization
    scale_mat = eval_dataset.get_scale_mat()
    if eval_cameras:
        num_images = len(eval_dataset)
        pose_vecs = torch.nn.Embedding(num_images, 7, sparse=True).cuda()
        pose_vecs.weight.data.copy_(eval_dataset.get_pose_init())

        gt_pose = eval_dataset.get_gt_pose()

    if eval_rendering:
        eval_dataloader = torch.utils.data.DataLoader(
            eval_dataset,
            batch_size=1,
            shuffle=False,
            collate_fn=eval_dataset.collate_fn)
        total_pixels = eval_dataset.total_pixels
        img_res = eval_dataset.img_res

    old_checkpnts_dir = os.path.join(expdir, timestamp, 'checkpoints')

    saved_model_state = torch.load(
        os.path.join(old_checkpnts_dir, 'ModelParameters',
                     str(kwargs['checkpoint']) + ".pth"))
    model.load_state_dict(saved_model_state["model_state_dict"])
    epoch = saved_model_state['epoch']

    if eval_cameras:
        data = torch.load(
            os.path.join(old_checkpnts_dir, 'CamParameters',
                         str(kwargs['checkpoint']) + ".pth"))
        pose_vecs.load_state_dict(data["pose_vecs_state_dict"])

    ####################################################################################################################
    print("evaluating...")

    model.eval()
    if eval_cameras:
        pose_vecs.eval()

    with torch.no_grad():
        if eval_cameras:
            gt_Rs = gt_pose[:, :3, :3].double()
            gt_ts = gt_pose[:, :3, 3].double()

            pred_Rs = rend_util.quat_to_rot(
                pose_vecs.weight.data[:, :4]).cpu().double()
            pred_ts = pose_vecs.weight.data[:, 4:].cpu().double()

            R_opt, t_opt, c_opt, R_fixed, t_fixed = get_cameras_accuracy(
                pred_Rs, gt_Rs, pred_ts, gt_ts)

            cams_transformation = np.eye(4, dtype=np.double)
            cams_transformation[:3, :3] = c_opt * R_opt
            cams_transformation[:3, 3] = t_opt

        mesh = plt.get_surface_high_res_mesh(
            sdf=lambda x: model.implicit_network(x)[:, 0],
            resolution=kwargs['resolution'])

        # Transform to world coordinates
        if eval_cameras:
            mesh.apply_transform(cams_transformation)
        else:
            mesh.apply_transform(scale_mat)

        # Taking the biggest connected component
        components = mesh.split(only_watertight=False)
        areas = np.array([c.area for c in components], dtype=float)
        mesh_clean = components[areas.argmax()]
        mesh_clean.export(
            '{0}/surface_world_coordinates_{1}.ply'.format(evaldir, epoch),
            'ply')

    if eval_rendering:
        images_dir = '{0}/rendering'.format(evaldir)
        utils.mkdir_ifnotexists(images_dir)

        psnrs = []
        for data_index, (indices, model_input,
                         ground_truth) in enumerate(eval_dataloader):
            model_input["intrinsics"] = model_input["intrinsics"].cuda()
            model_input["uv"] = model_input["uv"].cuda()
            model_input["object_mask"] = model_input["object_mask"].cuda()

            if eval_cameras:
                pose_input = pose_vecs(indices.cuda())
                model_input['pose'] = pose_input
            else:
                model_input['pose'] = model_input['pose'].cuda()

            split = utils.split_input(model_input, total_pixels)
            res = []
            for s in split:
                out = model(s)
                res.append({
                    'rgb_values': out['rgb_values'].detach(),
                })

            batch_size = ground_truth['rgb'].shape[0]
            model_outputs = utils.merge_output(res, total_pixels, batch_size)
            rgb_eval = model_outputs['rgb_values']
            rgb_eval = rgb_eval.reshape(batch_size, total_pixels, 3)

            rgb_eval = (rgb_eval + 1.) / 2.
            rgb_eval = plt.lin2img(rgb_eval, img_res).detach().cpu().numpy()[0]
            rgb_eval = rgb_eval.transpose(1, 2, 0)
            img = Image.fromarray((rgb_eval * 255).astype(np.uint8))
            img.save('{0}/eval_{1}.png'.format(images_dir,
                                               '%03d' % indices[0]))

            rgb_gt = ground_truth['rgb']
            rgb_gt = (rgb_gt + 1.) / 2.
            rgb_gt = plt.lin2img(rgb_gt, img_res).numpy()[0]
            rgb_gt = rgb_gt.transpose(1, 2, 0)

            mask = model_input['object_mask']
            mask = plt.lin2img(mask.unsqueeze(-1), img_res).cpu().numpy()[0]
            mask = mask.transpose(1, 2, 0)

            rgb_eval_masked = rgb_eval * mask
            rgb_gt_masked = rgb_gt * mask

            psnr = calculate_psnr(rgb_eval_masked, rgb_gt_masked, mask)
            psnrs.append(psnr)

        psnrs = np.array(psnrs).astype(np.float64)
        print("RENDERING EVALUATION {2}: psnr mean = {0} ; psnr std = {1}".
              format("%.2f" % psnrs.mean(), "%.2f" % psnrs.std(), scan_id))
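plt in these examples is the project's plotting module, not matplotlib. Its lin2img helper, used in every rendering loop above, just folds the flat pixel dimension back into image layout; a sketch consistent with those call sites:

def lin2img(tensor, img_res):
    # (batch, H*W, channels) -> (batch, channels, H, W)
    batch_size, num_samples, channels = tensor.shape
    return tensor.permute(0, 2, 1).reshape(batch_size, channels,
                                           img_res[0], img_res[1])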
Example #11
    opt = parser.parse_args()

    with open(opt.split, "r") as f:
        train_split = json.load(f)

    shapeindex = opt.shapeindex

    global_shape_index = 0
    for human, scans in train_split['scans'].items():
        for pose, shapes in scans.items():

            source = '{0}/{1}/{2}/{3}'.format(opt.datapath, 'scans', human,
                                              pose)
            output = opt.datapath + '_processed'
            utils.mkdir_ifnotexists(output)
            utils.mkdir_ifnotexists(os.path.join(output, human))
            utils.mkdir_ifnotexists(os.path.join(output, human, pose))

            for shape in shapes:

                if (shapeindex == global_shape_index or shapeindex == -1):
                    print("found!")
                    output_file = os.path.join(output, human, pose, shape)
                    print(output_file)
                    if (not opt.skip
                            or not os.path.isfile(output_file +
                                                  '_dist_triangle.npy')):

                        print('loading : {0}'.format(
                            os.path.join(source, shape)))
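The nested loops imply the structure of the split file: a top-level 'scans' mapping of subject to pose to a list of shape names. A hypothetical minimal split (the DFAUST-style identifiers are illustrative only):

train_split = {
    "scans": {
        "50002": {
            "chicken_wings": [
                "50002.chicken_wings.000041",
                "50002.chicken_wings.000085"
            ]
        }
    }
}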
Example #12
def evaluate(network, exps_dir, experiment_name, timestamp, split_filename,
             epoch, conf, with_opt, resolution, compute_dist_to_gt):

    utils.mkdir_ifnotexists(
        os.path.join('../', exps_dir, experiment_name, timestamp,
                     'evaluation'))
    utils.mkdir_ifnotexists(
        os.path.join('../', exps_dir, experiment_name, timestamp, 'evaluation',
                     split_filename.split('/')[-1].split('.json')[0]))
    path = os.path.join('../', exps_dir, experiment_name, timestamp,
                        'evaluation',
                        split_filename.split('/')[-1].split('.json')[0],
                        str(epoch))
    utils.mkdir_ifnotexists(path)

    dataset_path = conf.get_string('train.dataset_path')
    train_data_split = conf.get_string('train.data_split')
    latent_size = conf.get_int('train.latent_size')

    chamfer_results = []
    plot_cmpr = True

    if train_data_split == 'none':
        ds = ReconDataSet(split=None,
                          dataset_path=dataset_path,
                          dist_file_name=None)
    else:
        dist_file_name = conf.get_string('train.dist_file_name')
        with open(split_filename, "r") as f:
            split = json.load(f)

        ds = DFaustDataSet(split=split,
                           dataset_path=dataset_path,
                           dist_file_name=dist_file_name,
                           with_gt=True)

    total_files = len(ds)
    logging.info("total files : {0}".format(total_files))
    counter = 0
    dataloader = torch.utils.data.DataLoader(ds,
                                             batch_size=1,
                                             shuffle=True,
                                             num_workers=1,
                                             drop_last=False,
                                             pin_memory=True)

    for data in dataloader:

        counter = counter + 1

        logging.info("evaluating " + ds.npyfiles_mnfld[data[-1]])

        input_pc = data[0].cuda()
        if latent_size > 0:
            latent = network.encoder(input_pc)
            if (type(latent) is tuple):
                latent = latent[0]
            points = torch.cat(
                [latent.unsqueeze(1).repeat(1, input_pc.shape[1], 1), input_pc],
                dim=-1)[0]
        else:
            latent = None
            points = input_pc[0]

        reconstruction = plt.plot_surface(
            with_points=False,
            points=points,
            decoder=network.decoder,
            latent=latent,
            path=path,
            epoch=epoch,
            in_epoch=ds.npyfiles_mnfld[data[-1].item()].split('/')[-3] + '_' +
            ds.npyfiles_mnfld[data[-1].item()].split('/')[-1].split('.npy')[0]
            + '_before',
            shapefile=ds.npyfiles_mnfld[data[-1].item()],
            resolution=resolution,
            mc_value=0,
            is_uniform_grid=True,
            verbose=True,
            save_html=False,
            save_ply=True,
            overwrite=True)
        if (with_opt):
            recon_after_latentopt = optimize_latent(latent, ds, data[-1],
                                                    network.decoder, path,
                                                    epoch, resolution, conf)

        if compute_dist_to_gt:
            gt_mesh_filename = ds.gt_files[data[-1]]
            normalization_params_filename = ds.normalization_files[data[-1]]

            logging.debug("normalization params are " +
                          normalization_params_filename)

            ground_truth_points = trimesh.Trimesh(
                trimesh.sample.sample_surface(trimesh.load(gt_mesh_filename),
                                              30000)[0])

            normalization_params = np.load(normalization_params_filename,
                                           allow_pickle=True)

            scale = normalization_params.item()['scale']
            center = normalization_params.item()['center']

            chamfer_dist = utils.compute_trimesh_chamfer(
                gt_points=ground_truth_points,
                gen_mesh=reconstruction,
                offset=-center,
                scale=1. / scale,
            )

            chamfer_dist_scan = utils.compute_trimesh_chamfer(
                gt_points=trimesh.Trimesh(input_pc[0].cpu().numpy()),
                gen_mesh=reconstruction,
                offset=0,
                scale=1.,
                one_side=True)

            logging.debug("chamfer distance: " + str(chamfer_dist))

            if (with_opt):
                chamfer_dist_after_opt = utils.compute_trimesh_chamfer(
                    gt_points=ground_truth_points,
                    gen_mesh=recon_after_latentopt,
                    offset=-center,
                    scale=1. / scale,
                )

                chamfer_dist_scan_after_opt = utils.compute_trimesh_chamfer(
                    gt_points=trimesh.Trimesh(input_pc[0].cpu().numpy()),
                    gen_mesh=recon_after_latentopt,
                    offset=0,
                    scale=1.,
                    one_side=True)

                chamfer_results.append(
                    (ds.gt_files[data[-1]], chamfer_dist, chamfer_dist_scan,
                     chamfer_dist_after_opt, chamfer_dist_scan_after_opt))
            else:
                chamfer_results.append(
                    (ds.gt_files[data[-1]], chamfer_dist, chamfer_dist_scan))

            if (plot_cmpr):
                if (with_opt):
                    fig = make_subplots(rows=2,
                                        cols=2,
                                        specs=[[{
                                            "type": "scene"
                                        }, {
                                            "type": "scene"
                                        }],
                                               [{
                                                   "type": "scene"
                                               }, {
                                                   "type": "scene"
                                               }]],
                                        subplot_titles=[
                                            "Input", "Registration", "Ours",
                                            "Ours after opt"
                                        ])

                else:
                    fig = make_subplots(rows=1,
                                        cols=3,
                                        specs=[[{
                                            "type": "scene"
                                        }, {
                                            "type": "scene"
                                        }, {
                                            "type": "scene"
                                        }]],
                                        subplot_titles=("input pc", "Ours",
                                                        "Registration"))

                # Lock every subplot scene to the same cube so the meshes
                # are directly comparable.
                axis_layout = dict(
                    xaxis=dict(range=[-1.5, 1.5], autorange=False),
                    yaxis=dict(range=[-1.5, 1.5], autorange=False),
                    zaxis=dict(range=[-1.5, 1.5], autorange=False),
                    aspectratio=dict(x=1, y=1, z=1))
                fig.layout.scene.update(axis_layout)
                fig.layout.scene2.update(axis_layout)
                fig.layout.scene3.update(axis_layout)
                if (with_opt):
                    fig.layout.scene4.update(axis_layout)

                scan_mesh = trimesh.load(ds.scans_files[data[-1]])

                # Shift the raw scan into the normalized frame shared by the
                # reconstruction and the ground truth.
                scan_mesh.vertices = scan_mesh.vertices - center

                def tri_indices(simplices):
                    # Split the (n, 3) face array into the three per-corner
                    # index sequences that go.Mesh3d expects as i, j, k.
                    return ([triplet[c] for triplet in simplices]
                            for c in range(3))

                I, J, K = tri_indices(scan_mesh.faces)
                color = '#ffffff'
                trace = go.Mesh3d(x=scan_mesh.vertices[:, 0],
                                  y=scan_mesh.vertices[:, 1],
                                  z=scan_mesh.vertices[:, 2],
                                  i=I,
                                  j=J,
                                  k=K,
                                  name='scan',
                                  color=color,
                                  opacity=1.0,
                                  flatshading=False,
                                  lighting=dict(diffuse=1,
                                                ambient=0,
                                                specular=0),
                                  lightposition=dict(x=0, y=0, z=-1))
                fig.add_trace(trace, row=1, col=1)

                I, J, K = tri_indices(reconstruction.faces)
                color = '#ffffff'
                trace = go.Mesh3d(x=reconstruction.vertices[:, 0],
                                  y=reconstruction.vertices[:, 1],
                                  z=reconstruction.vertices[:, 2],
                                  i=I,
                                  j=J,
                                  k=K,
                                  name='our',
                                  color=color,
                                  opacity=1.0,
                                  flatshading=False,
                                  lighting=dict(diffuse=1,
                                                ambient=0,
                                                specular=0),
                                  lightposition=dict(x=0, y=0, z=-1))
                if (with_opt):
                    fig.add_trace(trace, row=2, col=1)

                    I, J, K = tri_indices(recon_after_latentopt.faces)
                    color = '#ffffff'
                    trace = go.Mesh3d(x=recon_after_latentopt.vertices[:, 0],
                                      y=recon_after_latentopt.vertices[:, 1],
                                      z=recon_after_latentopt.vertices[:, 2],
                                      i=I,
                                      j=J,
                                      k=K,
                                      name='our_after_opt',
                                      color=color,
                                      opacity=1.0,
                                      flatshading=False,
                                      lighting=dict(diffuse=1,
                                                    ambient=0,
                                                    specular=0),
                                      lightposition=dict(x=0, y=0, z=-1))
                    fig.add_trace(trace, row=2, col=2)
                else:
                    fig.add_trace(trace, row=1, col=2)

                gtmesh = trimesh.load(gt_mesh_filename)
                gtmesh.vertices = gtmesh.vertices - center
                I, J, K = tri_indices(gtmesh.faces)
                trace = go.Mesh3d(x=gtmesh.vertices[:, 0],
                                  y=gtmesh.vertices[:, 1],
                                  z=gtmesh.vertices[:, 2],
                                  i=I,
                                  j=J,
                                  k=K,
                                  name='gt',
                                  color=color,
                                  opacity=1.0,
                                  flatshading=False,
                                  lighting=dict(diffuse=1,
                                                ambient=0,
                                                specular=0),
                                  lightposition=dict(x=0, y=0, z=-1))
                if (with_opt):
                    fig.add_trace(trace, row=1, col=2)
                else:
                    fig.add_trace(trace, row=1, col=3)

                div = offline.plot(fig,
                                   include_plotlyjs=False,
                                   output_type='div',
                                   auto_open=False)
                # Extract the div id that plotly assigned, so the camera-sync
                # script below can look the element up by id.
                div_id = div.split('=')[1].split()[0].replace("'", "").replace(
                    '"', '')
                if (with_opt):
                    js = '''
                    <script>
                    var gd = document.getElementById('{div_id}');
                    var isUnderRelayout = false;

                    // Mirror the first scene's camera onto the other three
                    // scenes so all subplots rotate together.
                    gd.on('plotly_relayout', () => {{
                      console.log('relayout', isUnderRelayout);
                      if (!isUnderRelayout) {{
                        Plotly.relayout(gd, 'scene2.camera', gd.layout.scene.camera)
                          .then(() => {{ isUnderRelayout = false }});
                        Plotly.relayout(gd, 'scene3.camera', gd.layout.scene.camera)
                          .then(() => {{ isUnderRelayout = false }});
                        Plotly.relayout(gd, 'scene4.camera', gd.layout.scene.camera)
                          .then(() => {{ isUnderRelayout = false }});
                      }}

                      isUnderRelayout = true;
                    }})
                    </script>'''.format(div_id=div_id)
                else:
                    js = '''
                    <script>
                    var gd = document.getElementById('{div_id}');
                    var isUnderRelayout = false;

                    // Sync the cameras of scene2 and scene3 to the first
                    // scene in the three-panel layout.
                    gd.on('plotly_relayout', () => {{
                      console.log('relayout', isUnderRelayout);
                      if (!isUnderRelayout) {{
                        Plotly.relayout(gd, 'scene2.camera', gd.layout.scene.camera)
                          .then(() => {{ isUnderRelayout = false }});
                        Plotly.relayout(gd, 'scene3.camera', gd.layout.scene.camera)
                          .then(() => {{ isUnderRelayout = false }});
                      }}

                      isUnderRelayout = true;
                    }})
                    </script>'''.format(div_id=div_id)
                # merge everything
                div = '<script src="https://cdn.plot.ly/plotly-latest.min.js"></script>' + div + js
                print(ds.shapenames[data[-1]])
                with open(
                        os.path.join(
                            path, "compare_{0}.html".format(
                                ds.shapenames[data[-1]])), "w") as text_file:
                    text_file.write(div)

    if compute_dist_to_gt:
        with open(
                os.path.join(path, "chamfer.csv"),
                "w",
        ) as f:
            if (with_opt):
                f.write(
                    "shape, chamfer_dist, chamfer scan dist, after opt chamfer dist, after opt chamfer scan dist\n"
                )
                for result in chamfer_results:
                    # Write all five fields; the original format string only
                    # emitted three, silently dropping the post-optimization
                    # metrics.
                    f.write("{}, {}, {}, {}, {}\n".format(result[0], result[1],
                                                          result[2], result[3],
                                                          result[4]))
            else:
                f.write("shape, chamfer_dist, chamfer scan dist\n")
                for result in chamfer_results:
                    f.write("{}, {}, {}\n".format(result[0], result[1],
                                                  result[2]))
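
A note on the metric above: compute_trimesh_chamfer is used in both a symmetric form (reconstruction vs. ground truth) and a one-sided form (scan vs. reconstruction). Below is a minimal sketch of such a chamfer computation, assuming scipy and trimesh are available; the offset/scale handling mirrors the call sites above and is a hypothetical denormalization, not a confirmed implementation of the repository's utility.

import numpy as np
import trimesh
from scipy.spatial import cKDTree


def chamfer_sketch(gt_points, gen_mesh, offset, scale, one_side=False,
                   num_samples=30000):
    # Sample the generated surface and map it into the ground-truth frame
    # (hypothetical denormalization; the real utility may differ).
    gen_points = trimesh.sample.sample_surface(gen_mesh, num_samples)[0]
    gen_points = (gen_points + offset) * scale

    gt = gt_points.vertices

    # gt -> gen: mean squared distance from each GT point to its nearest sample.
    gt_to_gen = np.mean(np.square(cKDTree(gen_points).query(gt)[0]))
    if one_side:
        return gt_to_gen

    # gen -> gt: the symmetric term.
    gen_to_gt = np.mean(np.square(cKDTree(gt).query(gen_points)[0]))
    return gt_to_gen + gen_to_gt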
Example #13
File: crx.py Project: eram1205/IGR
    args = parser.parse_args()

    with open(args.dataset_path) as f:
        dataset = json.load(f)

    # Symmetrize DB
    samples_dir = os.path.join(os.path.dirname(args.dataset_path), 'cases')
    symmetric_samples = []
    print('Symmetrizing DB')
    for sample in tqdm(dataset['database']['samples']):
        symmetric_sample = {}

        # Create symmetric sample directory
        symmetric_sample['case_identifier'] = sample['case_identifier'] + '_sym'
        sample_dir = os.path.join(samples_dir, symmetric_sample['case_identifier'])
        utils.mkdir_ifnotexists(sample_dir)

        # Symmetrize mesh
        mesh = trimesh.load_mesh(sample['mesh'])
        mesh_sym = symmetrize_mesh(mesh)
        symmetric_sample['mesh'] = os.path.join(sample_dir, 'mesh.obj')
        symmetric_sample['mesh_bounds'] = mesh_sym.bounds.tolist()
        mesh_sym.export(symmetric_sample['mesh'])

        # Symmetrize yaw angles
        symmetric_sample['yaw_angles'] = (-np.array(sample['yaw_angles'])).tolist()

        symmetric_samples.append(symmetric_sample)

    dataset['database']['samples'] += symmetric_samples
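
The snippet above relies on a symmetrize_mesh helper that is not shown. A plausible sketch, assuming the symmetry plane is x = 0 (consistent with the yaw-angle negation above); this is a hypothetical implementation, not the project's confirmed one.

import trimesh


def symmetrize_mesh(mesh):
    # Reflect vertices across the y-z plane (x = 0).
    mirrored_vertices = mesh.vertices.copy()
    mirrored_vertices[:, 0] *= -1.0

    # Reflection flips the triangle winding; reversing each face's vertex
    # order restores outward-facing normals.
    mirrored_faces = mesh.faces[:, ::-1]

    return trimesh.Trimesh(vertices=mirrored_vertices,
                           faces=mirrored_faces,
                           process=False)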
Example #14
    def __init__(self,**kwargs):

        if isinstance(kwargs['conf'], str):
            self.conf = ConfigFactory.parse_file(kwargs['conf'])
            self.conf_filename = kwargs['conf']
        else:
            self.conf = kwargs['conf']
        self.batch_size = kwargs['batch_size']
        self.nepochs = kwargs['nepochs']
        self.expnameraw = self.conf.get_string('train.expname')
        self.expname = self.conf.get_string('train.expname') + kwargs['expname']

        if kwargs['is_continue'] and kwargs['timestamp'] == 'latest':
            if os.path.exists(os.path.join('../',kwargs['exps_folder_name'],self.expname)):
                timestamps = os.listdir(os.path.join('../',kwargs['exps_folder_name'],self.expname))
                if (len(timestamps)) == 0:
                    is_continue = False
                    timestamp = None
                else:
                    timestamp = sorted(timestamps)[-1]
                    is_continue = True
            else:
                is_continue = False
                timestamp = None
        else:
            timestamp = kwargs['timestamp']
            is_continue = kwargs['is_continue']

        self.adjust_lr = self.conf.get_bool('train.adjust_lr')
        self.GPU_INDEX = kwargs['gpu_index']
        self.exps_folder_name = kwargs['exps_folder_name']

        utils.mkdir_ifnotexists(os.path.join('../',self.exps_folder_name))

        self.expdir = os.path.join('../', self.exps_folder_name, self.expname)
        utils.mkdir_ifnotexists(self.expdir)
        self.timestamp = '{:%Y_%m_%d_%H_%M_%S}'.format(datetime.now())
        utils.mkdir_ifnotexists(os.path.join(self.expdir, self.timestamp))
        log_dir = os.path.join(self.expdir, self.timestamp, 'log')
        self.log_dir = log_dir
        utils.mkdir_ifnotexists(log_dir)
        utils.configure_logging(kwargs['debug'],kwargs['quiet'],os.path.join(self.log_dir,'log.txt'))

        self.plots_dir = os.path.join(self.expdir, self.timestamp, 'plots')
        utils.mkdir_ifnotexists(self.plots_dir)

        self.checkpoints_path = os.path.join(self.expdir, self.timestamp, 'checkpoints')
        utils.mkdir_ifnotexists(self.checkpoints_path)

        self.model_params_subdir = "ModelParameters"
        self.optimizer_params_subdir = "OptimizerParameters"

        utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path,self.model_params_subdir))
        utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.optimizer_params_subdir))

        if self.GPU_INDEX != 'all':
            os.environ["CUDA_VISIBLE_DEVICES"] = '{0}'.format(self.GPU_INDEX)

        # Backup code
        self.code_path = os.path.join(self.expdir, self.timestamp, 'code')
        utils.mkdir_ifnotexists(self.code_path)
        for folder in ['training','preprocess','utils','model','datasets','confs']:
            utils.mkdir_ifnotexists(os.path.join(self.code_path, folder))
            os.system("""cp -r ./{0}/* "{1}" """.format(folder,os.path.join(self.code_path, folder)))

        os.system("""cp -r {0} "{1}" """.format(kwargs['conf'], os.path.join(self.code_path, 'confs/runconf.conf')))

        logging.info('shell command : {0}'.format(' '.join(sys.argv)))

        if (self.conf.get_string('train.data_split') == 'none'):
            self.ds = utils.get_class(self.conf.get_string('train.dataset'))(split=None, dataset_path=self.conf.get_string('train.dataset_path'), dist_file_name=None)
        else:
            train_split_file = './confs/splits/{0}'.format(self.conf.get_string('train.data_split'))

            with open(train_split_file, "r") as f:
                train_split = json.load(f)

            self.ds = utils.get_class(self.conf.get_string('train.dataset'))(split=train_split,
                                                                             dataset_path=self.conf.get_string('train.dataset_path'),
                                                                             dist_file_name=self.conf.get_string('train.dist_file_name'))

        self.dataloader = torch.utils.data.DataLoader(self.ds,
                                                      batch_size=self.batch_size,
                                                      shuffle=True,
                                                      num_workers=kwargs['workers'],drop_last=True,pin_memory=True)
        self.eval_dataloader = torch.utils.data.DataLoader(self.ds,
                                                           batch_size=1,
                                                           shuffle=True,
                                                           num_workers=0, drop_last=True)

        self.latent_size = self.conf.get_int('train.latent_size')

        self.network = utils.get_class(self.conf.get_string('train.network_class'))(conf=self.conf.get_config('network'),
                                                                                    latent_size=self.latent_size)
        if kwargs['parallel']:
            self.network = torch.nn.DataParallel(self.network)

        if torch.cuda.is_available():
            self.network.cuda()

        self.parallel = kwargs['parallel']
        self.loss = utils.get_class(self.conf.get_string('network.loss.loss_type'))(**self.conf.get_config('network.loss.properties'))
        self.lr_schedules = BaseTrainRunner.get_learning_rate_schedules(self.conf.get_list('train.learning_rate_schedule'))

        self.optimizer = torch.optim.Adam([{
            "params": self.network.parameters(),
            "lr": self.lr_schedules[0].get_learning_rate(0),
        }])

        self.start_epoch = 0
        if is_continue:
            # Resume network and optimizer state from the chosen timestamp
            # and checkpoint.
            old_checkpnts_dir = os.path.join(self.expdir, timestamp, 'checkpoints')

            saved_model_state = torch.load(os.path.join(old_checkpnts_dir, 'ModelParameters', str(kwargs['checkpoint']) + ".pth"))
            self.network.load_state_dict(saved_model_state["model_state_dict"])

            data = torch.load(os.path.join(old_checkpnts_dir, 'OptimizerParameters', str(kwargs['checkpoint']) + ".pth"))
            self.optimizer.load_state_dict(data["optimizer_state_dict"])
            self.start_epoch = saved_model_state['epoch']
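
The optimizer above only requires each schedule returned by get_learning_rate_schedules to expose get_learning_rate(epoch). Below is a minimal step-decay sketch satisfying that interface; the class name, its parameters, and the adjust_learning_rate helper are hypothetical, not the repository's confirmed implementation.

class StepLearningRateSchedule:
    def __init__(self, initial, interval, factor):
        self.initial = initial    # starting learning rate
        self.interval = interval  # epochs between decays
        self.factor = factor      # multiplicative decay per interval

    def get_learning_rate(self, epoch):
        return self.initial * (self.factor ** (epoch // self.interval))


def adjust_learning_rate(lr_schedules, optimizer, epoch):
    # Refresh each parameter group's lr once per epoch when
    # train.adjust_lr is enabled.
    for i, param_group in enumerate(optimizer.param_groups):
        param_group["lr"] = lr_schedules[i].get_learning_rate(epoch)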