示例#1
0
文件: dfaustdataset.py 项目: lepy/IGR
    def __init__(self, dataset_path, split, points_batch=16384, d_in=3, with_gt=False, with_normals=False):
        """Collect the file lists that make up one split of the dataset.

        Args:
            dataset_path: Dataset root directory; it is made absolute before
                being handed to ``get_instance_filenames``.
            split: Split description forwarded unchanged to
                ``get_instance_filenames``.
            points_batch: Stored as ``self.points_batch``.
            d_in: Stored as ``self.d_in``.
            with_gt: When true, the scan ('ply') and script ('obj') file lists
                are collected as well and ``self.shapenames`` is derived from
                the scan file names.
            with_normals: Stored as ``self.with_normals``.
        """
        base_dir = os.path.abspath(dataset_path)
        self.npyfiles_mnfld = get_instance_filenames(base_dir, split)
        self.points_batch = points_batch
        self.with_normals = with_normals
        self.d_in = d_in

        if with_gt:
            self.scans_files = get_instance_filenames(
                utils.concat_home_dir('datasets/dfaust/scans'), split, '', 'ply')
            self.scripts_files = get_instance_filenames(
                utils.concat_home_dir('datasets/dfaust/scripts'), split, '', 'obj')
            # os.path.basename honours the platform's path separators, whereas
            # splitting on '/' returns the whole path when the separator is '\\'.
            self.shapenames = [
                os.path.basename(scan_path).split('.ply')[0]
                for scan_path in self.scans_files
            ]
示例#2
0
    def __init__(self, split, dataset_path, dist_file_name, with_gt=False):
        """Collect the manifold and distance file lists for one split.

        Args:
            split: Split description forwarded unchanged to
                ``self.get_instance_filenames``.
            dataset_path: Dataset root directory, used as given.
            dist_file_name: Third argument of the distance-file lookup
                (presumably a file-name suffix -- confirm against
                ``get_instance_filenames``).
            with_gt: When true, normalization, ground-truth ('obj') and scan
                ('ply') file lists are collected too and ``self.shapenames``
                is derived from the ground-truth file names.
        """
        # Function-scope import: this method needs os.path.basename and the
        # module's import block is not guaranteed to provide ``os``.
        import os

        base_dir = dataset_path
        self.npyfiles_mnfld = self.get_instance_filenames(base_dir, split)
        self.npyfiles_dist = self.get_instance_filenames(
            base_dir, split, dist_file_name)

        if with_gt:
            # Used only for evaluation
            self.normalization_files = self.get_instance_filenames(
                base_dir, split, '_normalization')
            self.gt_files = self.get_instance_filenames(
                utils.concat_home_dir('datasets/dfaust/scripts'), split, '',
                'obj')
            self.scans_files = self.get_instance_filenames(
                utils.concat_home_dir('datasets/dfaust/scans'), split, '',
                'ply')
            # os.path.basename honours the platform's path separators, whereas
            # splitting on '/' returns the whole path when the separator is '\\'.
            self.shapenames = [
                os.path.basename(gt_path).split('.obj')[0]
                for gt_path in self.gt_files
            ]
示例#3
0
    def __init__(self, **kwargs):
        """Prepare a single-shape reconstruction run.

        Loads the configuration and the input point cloud, creates the
        experiment directory tree, builds the sampler, network and optimizer
        and, when continuing, restores model and optimizer state.

        Keyword Args:
            conf: Either a config file name (resolved under
                ``./reconstruction/``) or an already parsed config object.
            expname: Name of the experiment directory under ``exps``.
            gpu_index: Value written to ``CUDA_VISIBLE_DEVICES``; the string
                ``'ignore'`` leaves the environment untouched.
            eval: Stored as ``self.eval``; a true value also triggers loading
                of an existing experiment.
            is_continue: Whether to resume from an existing experiment.
            timestamp: Timestamp directory to resume from, or ``'latest'`` to
                pick the last entry (in sorted order) of the experiment
                directory.
            checkpoint: Checkpoint name (without ``.pth``); only read when an
                existing experiment is loaded.
            nepochs: Stored as ``self.nepochs``.
            points_batch: Stored as ``self.points_batch``.

        Raises:
            KeyError: If one of the keyword arguments read here is missing.
        """
        self.home_dir = os.path.abspath(os.pardir)

        # config setting

        if isinstance(kwargs['conf'], str):
            self.conf_filename = './reconstruction/' + kwargs['conf']
            self.conf = ConfigFactory.parse_file(self.conf_filename)
        else:
            self.conf = kwargs['conf']

        self.expname = kwargs['expname']

        # GPU settings

        self.GPU_INDEX = kwargs['gpu_index']

        # Must happen before the first CUDA query below so that the device
        # count reflects the restricted set of devices.
        if self.GPU_INDEX != 'ignore':
            os.environ["CUDA_VISIBLE_DEVICES"] = '{0}'.format(self.GPU_INDEX)

        self.num_of_gpus = torch.cuda.device_count()

        self.eval = kwargs['eval']

        # settings for loading an existing experiment

        if (kwargs['is_continue'] or self.eval) and kwargs['timestamp'] == 'latest':
            exp_root = os.path.join(self.home_dir, 'exps', self.expname)
            timestamps = os.listdir(exp_root) if os.path.exists(exp_root) else []
            if timestamps:
                timestamp = sorted(timestamps)[-1]
                is_continue = True
            else:
                # Nothing to resume from: fall back to a fresh run.
                timestamp = None
                is_continue = False
        else:
            timestamp = kwargs['timestamp']
            is_continue = kwargs['is_continue'] or self.eval

        self.exps_folder_name = 'exps'

        utils.mkdir_ifnotexists(utils.concat_home_dir(os.path.join(self.home_dir, self.exps_folder_name)))

        self.input_file = self.conf.get_string('train.input_path')
        self.data = utils.load_point_cloud_by_file_extension(self.input_file)

        # Per-point local sigma: distance to the 50th nearest neighbour (the
        # query asks for 51 because each point is its own nearest neighbour).
        # The queries are issued in 100 chunks rather than in one call.
        ptree = cKDTree(self.data)
        sigma_set = [
            ptree.query(chunk, 50 + 1)[0][:, -1]
            for chunk in np.array_split(self.data, 100, axis=0)
        ]
        sigmas = np.concatenate(sigma_set)

        # Only move to the GPU when one is available, matching the guard used
        # for the network below; an unconditional .cuda() fails on CPU hosts.
        self.local_sigma = torch.from_numpy(sigmas).float()
        if torch.cuda.is_available():
            self.local_sigma = self.local_sigma.cuda()

        self.expdir = utils.concat_home_dir(os.path.join(self.home_dir, self.exps_folder_name, self.expname))
        utils.mkdir_ifnotexists(self.expdir)

        if is_continue:
            self.timestamp = timestamp
        else:
            self.timestamp = '{:%Y_%m_%d_%H_%M_%S}'.format(datetime.now())

        self.cur_exp_dir = os.path.join(self.expdir, self.timestamp)
        utils.mkdir_ifnotexists(self.cur_exp_dir)

        self.plots_dir = os.path.join(self.cur_exp_dir, 'plots')
        utils.mkdir_ifnotexists(self.plots_dir)

        self.checkpoints_path = os.path.join(self.cur_exp_dir, 'checkpoints')
        utils.mkdir_ifnotexists(self.checkpoints_path)

        self.model_params_subdir = "ModelParameters"
        self.optimizer_params_subdir = "OptimizerParameters"

        utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.model_params_subdir))
        utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.optimizer_params_subdir))

        self.nepochs = kwargs['nepochs']

        self.points_batch = kwargs['points_batch']

        self.global_sigma = self.conf.get_float('network.sampler.properties.global_sigma')
        sampler_class = Sampler.get_sampler(self.conf.get_string('network.sampler.sampler_type'))
        self.sampler = sampler_class(self.global_sigma, self.local_sigma)

        self.grad_lambda = self.conf.get_float('network.loss.lambda')
        self.normals_lambda = self.conf.get_float('network.loss.normals_lambda')

        # Normals are only needed when their loss term has a positive weight.
        self.with_normals = self.normals_lambda > 0

        self.d_in = self.conf.get_int('train.d_in')

        network_class = utils.get_class(self.conf.get_string('train.network_class'))
        self.network = network_class(d_in=self.d_in, **self.conf.get_config('network.inputs'))

        if torch.cuda.is_available():
            self.network.cuda()

        self.lr_schedules = self.get_learning_rate_schedules(self.conf.get_list('train.learning_rate_schedule'))
        self.weight_decay = self.conf.get_float('train.weight_decay')

        self.startepoch = 0

        self.optimizer = torch.optim.Adam(
            [
                {
                    "params": self.network.parameters(),
                    "lr": self.lr_schedules[0].get_learning_rate(0),
                    "weight_decay": self.weight_decay
                },
            ])

        # if continue load checkpoints

        if is_continue:
            old_checkpnts_dir = os.path.join(self.expdir, timestamp, 'checkpoints')
            checkpoint_file = str(kwargs['checkpoint']) + ".pth"

            saved_model_state = torch.load(
                os.path.join(old_checkpnts_dir, 'ModelParameters', checkpoint_file))
            self.network.load_state_dict(saved_model_state["model_state_dict"])

            saved_optimizer_state = torch.load(
                os.path.join(old_checkpnts_dir, 'OptimizerParameters', checkpoint_file))
            self.optimizer.load_state_dict(saved_optimizer_state["optimizer_state_dict"])
            self.startepoch = saved_model_state['epoch']
示例#4
0
文件: train.py 项目: eram1205/IGR
    def __init__(self, **kwargs):
        """Prepare a latent-code training run over a dataset split.

        Loads the configuration, creates the experiment directory tree, builds
        the dataset, data loaders, sampler, network, latent codes and
        optimizer and, when continuing, restores latent codes, model and
        optimizer state.

        Keyword Args:
            conf: Either a path to a config file or an already parsed config
                object.
            expname: Name of the experiment directory under ``exps``.
            gpu_index: Value written to ``CUDA_VISIBLE_DEVICES``; the string
                ``'ignore'`` leaves the environment untouched.
            is_continue: Whether to resume from an existing experiment.
            timestamp: Timestamp directory to resume from, or ``'latest'`` to
                pick the last entry (in sorted order) of the experiment
                directory.
            checkpoint: Checkpoint name (without ``.pth``); only read when
                resuming.
            nepochs: Stored as ``self.nepochs``.
            batch_size: Per-GPU batch size; multiplied by the number of
                visible GPUs when there is at least one.
            split_file: Path of the JSON file describing the training split.
            points_batch: Forwarded to the dataset constructor.
            threads: Number of worker processes of the training data loader.

        Raises:
            KeyError: If one of the keyword arguments read here is missing.
        """
        # config setting

        self.home_dir = os.path.abspath(os.pardir)

        if isinstance(kwargs['conf'], str):
            self.conf_filename = os.path.abspath(kwargs['conf'])
            self.conf = ConfigFactory.parse_file(self.conf_filename)
        else:
            self.conf = kwargs['conf']

        self.expname = kwargs['expname']

        # GPU settings

        self.GPU_INDEX = kwargs['gpu_index']

        # Must happen before the first CUDA query below so that the device
        # count reflects the restricted set of devices.
        if self.GPU_INDEX != 'ignore':
            os.environ["CUDA_VISIBLE_DEVICES"] = '{0}'.format(self.GPU_INDEX)

        self.num_of_gpus = torch.cuda.device_count()

        # settings for loading an existing experiment

        if kwargs['is_continue'] and kwargs['timestamp'] == 'latest':
            exp_root = os.path.join(self.home_dir, 'exps', self.expname)
            timestamps = os.listdir(exp_root) if os.path.exists(exp_root) else []
            if timestamps:
                timestamp = sorted(timestamps)[-1]
                is_continue = True
            else:
                # Nothing to resume from: fall back to a fresh run.
                timestamp = None
                is_continue = False
        else:
            timestamp = kwargs['timestamp']
            is_continue = kwargs['is_continue']

        self.exps_folder_name = 'exps'

        utils.mkdir_ifnotexists(
            utils.concat_home_dir(
                os.path.join(self.home_dir, self.exps_folder_name)))

        self.expdir = utils.concat_home_dir(
            os.path.join(self.home_dir, self.exps_folder_name, self.expname))
        utils.mkdir_ifnotexists(self.expdir)

        if is_continue:
            self.timestamp = timestamp
        else:
            self.timestamp = '{:%Y_%m_%d_%H_%M_%S}'.format(datetime.now())

        # Here cur_exp_dir is only the timestamp component; full paths are
        # always built by joining it onto expdir.
        self.cur_exp_dir = self.timestamp
        utils.mkdir_ifnotexists(os.path.join(self.expdir, self.cur_exp_dir))

        self.plots_dir = os.path.join(self.expdir, self.cur_exp_dir, 'plots')
        utils.mkdir_ifnotexists(self.plots_dir)

        self.checkpoints_path = os.path.join(self.expdir, self.cur_exp_dir,
                                             'checkpoints')
        utils.mkdir_ifnotexists(self.checkpoints_path)

        self.model_params_subdir = "ModelParameters"
        self.optimizer_params_subdir = "OptimizerParameters"
        self.latent_codes_subdir = "LatentCodes"

        utils.mkdir_ifnotexists(
            os.path.join(self.checkpoints_path, self.model_params_subdir))
        utils.mkdir_ifnotexists(
            os.path.join(self.checkpoints_path, self.optimizer_params_subdir))
        utils.mkdir_ifnotexists(
            os.path.join(self.checkpoints_path, self.latent_codes_subdir))

        self.nepochs = kwargs['nepochs']

        self.batch_size = kwargs['batch_size']

        # The configured batch size is per GPU; scale it to the global size.
        if self.num_of_gpus > 0:
            self.batch_size *= self.num_of_gpus

        self.parallel = self.num_of_gpus > 1

        self.global_sigma = self.conf.get_float(
            'network.sampler.properties.global_sigma')
        self.local_sigma = self.conf.get_float(
            'network.sampler.properties.local_sigma')
        sampler_class = Sampler.get_sampler(
            self.conf.get_string('network.sampler.sampler_type'))
        self.sampler = sampler_class(self.global_sigma, self.local_sigma)

        train_split_file = os.path.abspath(kwargs['split_file'])
        print(f'Loading split file {train_split_file}')
        with open(train_split_file, "r") as f:
            train_split = json.load(f)
        print(f'Size of the split: {len(train_split)} samples')

        self.d_in = self.conf.get_int('train.d_in')

        # latent preprocessing

        self.latent_size = self.conf.get_int('train.latent_size')

        self.latent_lambda = self.conf.get_float('network.loss.latent_lambda')
        self.grad_lambda = self.conf.get_float('network.loss.lambda')
        self.normals_lambda = self.conf.get_float(
            'network.loss.normals_lambda')

        # Normals are only needed when their loss term has a positive weight.
        self.with_normals = self.normals_lambda > 0

        dataset_class = utils.get_class(self.conf.get_string('train.dataset'))
        self.ds = dataset_class(
            split=train_split,
            with_normals=self.with_normals,
            dataset_path=self.conf.get_string('train.dataset_path'),
            points_batch=kwargs['points_batch'],
        )

        self.num_scenes = len(self.ds)

        self.train_dataloader = torch.utils.data.DataLoader(
            self.ds,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=kwargs['threads'],
            drop_last=True,
            pin_memory=True)
        self.eval_dataloader = torch.utils.data.DataLoader(self.ds,
                                                           batch_size=1,
                                                           shuffle=False,
                                                           num_workers=0,
                                                           drop_last=True)

        # The network input is a point concatenated with its latent code.
        network_class = utils.get_class(
            self.conf.get_string('train.network_class'))
        self.network = network_class(
            d_in=(self.d_in + self.latent_size),
            **self.conf.get_config('network.inputs'))

        if self.parallel:
            self.network = torch.nn.DataParallel(self.network)

        if torch.cuda.is_available():
            self.network.cuda()

        self.lr_schedules = self.get_learning_rate_schedules(
            self.conf.get_list('train.learning_rate_schedule'))
        self.weight_decay = self.conf.get_float('train.weight_decay')

        # optimizer and latent settings

        self.startepoch = 0

        # One latent code per dataset item, optimised jointly with the network.
        self.lat_vecs = utils.to_cuda(
            torch.zeros(self.num_scenes, self.latent_size))
        self.lat_vecs.requires_grad_()

        self.optimizer = torch.optim.Adam([
            {
                "params": self.network.parameters(),
                "lr": self.lr_schedules[0].get_learning_rate(0),
                "weight_decay": self.weight_decay
            },
            {
                "params": self.lat_vecs,
                "lr": self.lr_schedules[1].get_learning_rate(0)
            },
        ])

        # if continue load checkpoints

        if is_continue:
            old_checkpnts_dir = os.path.join(self.expdir, timestamp,
                                             'checkpoints')
            checkpoint_file = str(kwargs['checkpoint']) + '.pth'

            saved_latent_state = torch.load(
                os.path.join(old_checkpnts_dir, self.latent_codes_subdir,
                             checkpoint_file))

            # Restore the codes in place. The optimizer above holds a
            # reference to this exact tensor; rebinding self.lat_vecs to the
            # loaded tensor would leave the optimizer updating the stale
            # zero-initialised one and the restored codes would never train.
            # copy_ also moves the data to the device of self.lat_vecs and
            # raises if the stored shape does not fit the current dataset.
            with torch.no_grad():
                self.lat_vecs.copy_(saved_latent_state["latent_codes"])

            saved_model_state = torch.load(
                os.path.join(old_checkpnts_dir, 'ModelParameters',
                             checkpoint_file))
            self.network.load_state_dict(saved_model_state["model_state_dict"])

            saved_optimizer_state = torch.load(
                os.path.join(old_checkpnts_dir, 'OptimizerParameters',
                             checkpoint_file))
            self.optimizer.load_state_dict(
                saved_optimizer_state["optimizer_state_dict"])
            self.startepoch = saved_model_state['epoch']