Example 1
    def __init__(self, **kwargs):

        # config settings

        self.home_dir = os.path.abspath(os.pardir)

        if isinstance(kwargs['conf'], str):
            self.conf_filename = os.path.abspath(kwargs['conf'])
            self.conf = ConfigFactory.parse_file(self.conf_filename)
        else:
            self.conf = kwargs['conf']

        self.expname = kwargs['expname']

        # GPU settings

        self.GPU_INDEX = kwargs['gpu_index']

        if self.GPU_INDEX != 'ignore':
            os.environ["CUDA_VISIBLE_DEVICES"] = str(self.GPU_INDEX)

        self.num_of_gpus = torch.cuda.device_count()

        # settings for loading an existing experiment

        # resolve 'latest' to the most recent run under exps/<expname>
        if kwargs['is_continue'] and kwargs['timestamp'] == 'latest':
            if os.path.exists(os.path.join(self.home_dir, 'exps',
                                           self.expname)):
                timestamps = os.listdir(
                    os.path.join(self.home_dir, 'exps', self.expname))
                if len(timestamps) == 0:
                    is_continue = False
                    timestamp = None
                else:
                    timestamp = sorted(timestamps)[-1]
                    is_continue = True
            else:
                is_continue = False
                timestamp = None
        else:
            timestamp = kwargs['timestamp']
            is_continue = kwargs['is_continue']

        self.exps_folder_name = 'exps'

        utils.mkdir_ifnotexists(
            utils.concat_home_dir(
                os.path.join(self.home_dir, self.exps_folder_name)))

        self.expdir = utils.concat_home_dir(
            os.path.join(self.home_dir, self.exps_folder_name, self.expname))
        utils.mkdir_ifnotexists(self.expdir)

        if is_continue:
            self.timestamp = timestamp
        else:
            self.timestamp = '{:%Y_%m_%d_%H_%M_%S}'.format(datetime.now())

        self.cur_exp_dir = self.timestamp
        utils.mkdir_ifnotexists(os.path.join(self.expdir, self.cur_exp_dir))

        self.plots_dir = os.path.join(self.expdir, self.cur_exp_dir, 'plots')
        utils.mkdir_ifnotexists(self.plots_dir)

        self.checkpoints_path = os.path.join(self.expdir, self.cur_exp_dir,
                                             'checkpoints')
        utils.mkdir_ifnotexists(self.checkpoints_path)

        self.model_params_subdir = "ModelParameters"
        self.optimizer_params_subdir = "OptimizerParameters"
        self.latent_codes_subdir = "LatentCodes"

        utils.mkdir_ifnotexists(
            os.path.join(self.checkpoints_path, self.model_params_subdir))
        utils.mkdir_ifnotexists(
            os.path.join(self.checkpoints_path, self.optimizer_params_subdir))
        utils.mkdir_ifnotexists(
            os.path.join(self.checkpoints_path, self.latent_codes_subdir))

        self.nepochs = kwargs['nepochs']

        self.batch_size = kwargs['batch_size']

        if self.num_of_gpus > 0:
            # DataParallel splits each batch across GPUs, so scale it up
            self.batch_size *= self.num_of_gpus

        self.parallel = self.num_of_gpus > 1

        self.global_sigma = self.conf.get_float(
            'network.sampler.properties.global_sigma')
        self.local_sigma = self.conf.get_float(
            'network.sampler.properties.local_sigma')
        self.sampler = Sampler.get_sampler(
            self.conf.get_string('network.sampler.sampler_type'))(
                self.global_sigma, self.local_sigma)

        train_split_file = os.path.abspath(kwargs['split_file'])
        print(f'Loading split file {train_split_file}')
        with open(train_split_file, "r") as f:
            train_split = json.load(f)
        print(f'Size of the split: {len(train_split)} samples')

        self.d_in = self.conf.get_int('train.d_in')

        # latent preprocessing

        self.latent_size = self.conf.get_int('train.latent_size')

        self.latent_lambda = self.conf.get_float('network.loss.latent_lambda')
        self.grad_lambda = self.conf.get_float('network.loss.lambda')
        self.normals_lambda = self.conf.get_float(
            'network.loss.normals_lambda')

        self.with_normals = self.normals_lambda > 0

        self.ds = utils.get_class(self.conf.get_string('train.dataset'))(
            split=train_split,
            with_normals=self.with_normals,
            dataset_path=self.conf.get_string('train.dataset_path'),
            points_batch=kwargs['points_batch'],
        )

        self.num_scenes = len(self.ds)

        self.train_dataloader = torch.utils.data.DataLoader(
            self.ds,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=kwargs['threads'],
            drop_last=True,
            pin_memory=True)
        self.eval_dataloader = torch.utils.data.DataLoader(self.ds,
                                                           batch_size=1,
                                                           shuffle=False,
                                                           num_workers=0,
                                                           drop_last=True)

        self.network = utils.get_class(
            self.conf.get_string('train.network_class'))(
                d_in=(self.d_in + self.latent_size),
                **self.conf.get_config('network.inputs'))

        if self.parallel:
            self.network = torch.nn.DataParallel(self.network)

        if torch.cuda.is_available():
            self.network.cuda()

        self.lr_schedules = self.get_learning_rate_schedules(
            self.conf.get_list('train.learning_rate_schedule'))
        self.weight_decay = self.conf.get_float('train.weight_decay')

        # optimizer and latent settings

        self.startepoch = 0

        # one trainable latent code per scene, optimized jointly with the network
        self.lat_vecs = utils.to_cuda(
            torch.zeros(self.num_scenes, self.latent_size))
        self.lat_vecs.requires_grad_()

        # two parameter groups: network weights (with weight decay)
        # and the per-scene latent codes (without)
        self.optimizer = torch.optim.Adam([
            {
                "params": self.network.parameters(),
                "lr": self.lr_schedules[0].get_learning_rate(0),
                "weight_decay": self.weight_decay
            },
            {
                "params": self.lat_vecs,
                "lr": self.lr_schedules[1].get_learning_rate(0)
            },
        ])

        # if continuing, load the saved checkpoints

        if is_continue:
            old_checkpnts_dir = os.path.join(self.expdir, timestamp,
                                             'checkpoints')

            data = torch.load(
                os.path.join(old_checkpnts_dir, self.latent_codes_subdir,
                             str(kwargs['checkpoint']) + '.pth'))

            self.lat_vecs = utils.to_cuda(data["latent_codes"])

            saved_model_state = torch.load(
                os.path.join(old_checkpnts_dir, 'ModelParameters',
                             str(kwargs['checkpoint']) + ".pth"))
            self.network.load_state_dict(saved_model_state["model_state_dict"])

            data = torch.load(
                os.path.join(old_checkpnts_dir, 'OptimizerParameters',
                             str(kwargs['checkpoint']) + ".pth"))
            self.optimizer.load_state_dict(data["optimizer_state_dict"])
            self.startepoch = saved_model_state['epoch']
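
Everything in this constructor is driven by keyword arguments. As a rough usage sketch (the class name `ShapeSpaceRunner` and all argument values are hypothetical; only the keyword names come from the code above, and the inline glosses are best-effort readings of it):

    runner = ShapeSpaceRunner(
        conf='shapespace/setup.conf',    # config file path, or an already-parsed config object
        expname='my_experiment',         # outputs land under exps/<expname>/<timestamp>/
        gpu_index='0',                   # sets CUDA_VISIBLE_DEVICES; 'ignore' leaves it alone
        is_continue=False,               # resume from a previous run?
        timestamp='latest',              # which previous run to resume
        checkpoint='latest',             # which checkpoint to load when resuming
        nepochs=2000,
        batch_size=16,                   # multiplied by the number of visible GPUs
        points_batch=8192,               # points drawn per scene by the dataset
        threads=8,                       # DataLoader workers
        split_file='splits/train.json',  # JSON list of scenes to train on
    )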
Example 2
    def __init__(self, **kwargs):

        self.home_dir = os.path.abspath(os.pardir)

        # config settings

        if isinstance(kwargs['conf'], str):
            self.conf_filename = './reconstruction/' + kwargs['conf']
            self.conf = ConfigFactory.parse_file(self.conf_filename)
        else:
            self.conf = kwargs['conf']

        self.expname = kwargs['expname']

        # GPU settings

        self.GPU_INDEX = kwargs['gpu_index']

        if self.GPU_INDEX != 'ignore':
            os.environ["CUDA_VISIBLE_DEVICES"] = str(self.GPU_INDEX)

        self.num_of_gpus = torch.cuda.device_count()

        self.eval = kwargs['eval']

        # settings for loading an existing experiment

        # resolve 'latest' to the most recent run under exps/<expname>
        if (kwargs['is_continue'] or self.eval) and kwargs['timestamp'] == 'latest':
            if os.path.exists(os.path.join(self.home_dir, 'exps', self.expname)):
                timestamps = os.listdir(os.path.join(self.home_dir, 'exps', self.expname))
                if len(timestamps) == 0:
                    is_continue = False
                    timestamp = None
                else:
                    timestamp = sorted(timestamps)[-1]
                    is_continue = True
            else:
                is_continue = False
                timestamp = None
        else:
            timestamp = kwargs['timestamp']
            is_continue = kwargs['is_continue'] or self.eval

        self.exps_folder_name = 'exps'

        utils.mkdir_ifnotexists(utils.concat_home_dir(os.path.join(self.home_dir, self.exps_folder_name)))

        self.input_file = self.conf.get_string('train.input_path')
        self.data = utils.load_point_cloud_by_file_extension(self.input_file)

        # per-point local sigma: distance to the 50th nearest neighbor
        # (k = 50 + 1 because each point's nearest "neighbor" is itself)
        sigma_set = []
        ptree = cKDTree(self.data)

        for p in np.array_split(self.data, 100, axis=0):
            d = ptree.query(p, 50 + 1)
            sigma_set.append(d[0][:, -1])

        sigmas = np.concatenate(sigma_set)
        self.local_sigma = torch.from_numpy(sigmas).float().cuda()

        self.expdir = utils.concat_home_dir(os.path.join(self.home_dir, self.exps_folder_name, self.expname))
        utils.mkdir_ifnotexists(self.expdir)

        if is_continue:
            self.timestamp = timestamp
        else:
            self.timestamp = '{:%Y_%m_%d_%H_%M_%S}'.format(datetime.now())

        self.cur_exp_dir = os.path.join(self.expdir, self.timestamp)
        utils.mkdir_ifnotexists(self.cur_exp_dir)

        self.plots_dir = os.path.join(self.cur_exp_dir, 'plots')
        utils.mkdir_ifnotexists(self.plots_dir)

        self.checkpoints_path = os.path.join(self.cur_exp_dir, 'checkpoints')
        utils.mkdir_ifnotexists(self.checkpoints_path)

        self.model_params_subdir = "ModelParameters"
        self.optimizer_params_subdir = "OptimizerParameters"

        utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.model_params_subdir))
        utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.optimizer_params_subdir))

        self.nepochs = kwargs['nepochs']

        self.points_batch = kwargs['points_batch']

        self.global_sigma = self.conf.get_float('network.sampler.properties.global_sigma')
        self.sampler = Sampler.get_sampler(
            self.conf.get_string('network.sampler.sampler_type'))(
                self.global_sigma, self.local_sigma)
        self.grad_lambda = self.conf.get_float('network.loss.lambda')
        self.normals_lambda = self.conf.get_float('network.loss.normals_lambda')

        self.with_normals = self.normals_lambda > 0

        self.d_in = self.conf.get_int('train.d_in')

        self.network = utils.get_class(
            self.conf.get_string('train.network_class'))(
                d_in=self.d_in,
                **self.conf.get_config('network.inputs'))

        if torch.cuda.is_available():
            self.network.cuda()

        self.lr_schedules = self.get_learning_rate_schedules(self.conf.get_list('train.learning_rate_schedule'))
        self.weight_decay = self.conf.get_float('train.weight_decay')

        self.startepoch = 0

        self.optimizer = torch.optim.Adam(
            [
                {
                    "params": self.network.parameters(),
                    "lr": self.lr_schedules[0].get_learning_rate(0),
                    "weight_decay": self.weight_decay
                },
            ])

        # if continuing, load the saved checkpoints

        if is_continue:
            old_checkpnts_dir = os.path.join(self.expdir, timestamp, 'checkpoints')

            saved_model_state = torch.load(
                os.path.join(old_checkpnts_dir, 'ModelParameters', str(kwargs['checkpoint']) + ".pth"))
            self.network.load_state_dict(saved_model_state["model_state_dict"])

            data = torch.load(
                os.path.join(old_checkpnts_dir, 'OptimizerParameters', str(kwargs['checkpoint']) + ".pth"))
            self.optimizer.load_state_dict(data["optimizer_state_dict"])
            self.startepoch = saved_model_state['epoch']
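
The cKDTree block above derives a per-point noise scale for sampling near the surface. A self-contained sketch of the same computation on stand-in data (NumPy/SciPy only; k = 50 follows the code above):

    import numpy as np
    from scipy.spatial import cKDTree

    points = np.random.rand(10000, 3)          # stand-in point cloud
    tree = cKDTree(points)

    sigma_set = []
    for chunk in np.array_split(points, 100, axis=0):
        # query returns (distances, indices); k = 51 because the nearest
        # "neighbor" of a point in its own cloud is itself, at distance 0
        distances, _ = tree.query(chunk, k=50 + 1)
        sigma_set.append(distances[:, -1])     # distance to the 50th true neighbor

    local_sigma = np.concatenate(sigma_set)    # one sigma per input point
    print(local_sigma.shape)                   # (10000,)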
Example 3
def optimize_latent(points,
                    normals,
                    conf,
                    num_of_iterations,
                    network,
                    lr=1.0e-2):

    latent_size = conf.get_int('train.latent_size')
    global_sigma = conf.get_float('network.sampler.properties.global_sigma')
    local_sigma = conf.get_float('network.sampler.properties.local_sigma')
    sampler = Sampler.get_sampler(
        conf.get_string('network.sampler.sampler_type'))(global_sigma,
                                                         local_sigma)

    latent_lambda = conf.get_float('network.loss.latent_lambda')

    normals_lambda = conf.get_float('network.loss.normals_lambda')

    grad_lambda = conf.get_float('network.loss.lambda')

    num_of_points, dim = points.shape

    # initialize the latent code with small Gaussian noise (std = 1 / latent_size)
    latent = torch.ones(latent_size).normal_(0, 1 / latent_size).cuda()
    # latent = torch.zeros(latent_size).cuda()

    latent.requires_grad = True

    optimizer = torch.optim.Adam([latent], lr=lr)

    for i in range(num_of_iterations):

        sample = sampler.get_points(points.unsqueeze(0)).squeeze()

        # pair every surface point and every sampled point with the same latent code
        latent_all = latent.expand(num_of_points, -1)
        surface_pnts = torch.cat([latent_all, points], dim=1)

        sample_latent_all = latent.expand(sample.shape[0], -1)
        nonsurface_pnts = torch.cat([sample_latent_all, sample], dim=1)

        surface_pnts.requires_grad_()
        nonsurface_pnts.requires_grad_()

        surface_pred = network(surface_pnts)
        nonsurface_pred = network(nonsurface_pnts)

        surface_grad = gradient(surface_pnts, surface_pred)
        nonsurface_grad = gradient(nonsurface_pnts, nonsurface_pred)

        surface_loss = torch.abs(surface_pred).mean()  # SDF should vanish on the surface
        grad_loss = torch.mean((nonsurface_grad.norm(2, dim=-1) - 1).pow(2))  # Eikonal term: unit gradient norm
        normals_loss = ((surface_grad - normals).abs()).norm(2, dim=1).mean()  # match the given normals
        latent_loss = latent.abs().mean()  # keep the latent code small
        loss = (surface_loss
                + latent_lambda * latent_loss
                + normals_lambda * normals_loss
                + grad_lambda * grad_loss)

        adjust_learning_rate(lr, optimizer, i)

        optimizer.zero_grad()

        loss.backward()
        optimizer.step()

        print(f'latent loss iter {i}: {loss.item()}')

    return latent.unsqueeze(0)
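
The `gradient` helper used above is not defined in this example. A minimal sketch of what it might look like, assuming it differentiates the network output with respect to its input via autograd, and that the final slice keeps only the 3D spatial part of a [latent, xyz] input (both assumptions, based on how the inputs are concatenated above):

    import torch
    from torch.autograd import grad

    def gradient(inputs, outputs):
        # d(outputs)/d(inputs), keeping the graph so the loss can backprop through it
        d_points = torch.ones_like(outputs, requires_grad=False)
        points_grad = grad(outputs=outputs,
                           inputs=inputs,
                           grad_outputs=d_points,
                           create_graph=True,
                           retain_graph=True,
                           only_inputs=True)[0]
        # assumption: inputs are [latent, xyz], so keep the last 3 columns
        return points_grad[:, -3:]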