def __init__(
    self,
    tree,
    step_size: float = 1e-3,
    background_brightness: float = 1.0,
    ndc: NDCConfig = None,
    min_comp=0,
    max_comp=-1,
):
    """
    Build a volume renderer bound to the given N^3 tree.

    :param tree: N3Tree instance to render
    :param step_size: float, epsilon added to each DDA step
    :param background_brightness: float, background brightness
                                  (1.0 = white)
    :param ndc: NDCConfig, NDC coordinate configuration,
                namedtuple(width, height, focal).
                None = no NDC, use usual coordinates
    :param min_comp: first SH/SG component to render
    :param max_comp: last SH/SG component to render, -1 = last
    """
    super().__init__()
    self.tree = tree

    # Plain rendering options, stored as given.
    self.step_size = step_size
    self.background_brightness = background_brightness
    self.ndc_config = ndc
    self.min_comp = min_comp
    self.max_comp = max_comp

    if not isinstance(tree.data_format, DataFormat):
        # The tree carries no usable format descriptor, so derive one from
        # the leaf data size: 4 values means no basis, otherwise SH.
        warn("Legacy N3Tree (pre 0.2.18) without data_format, auto-infering SH deg")
        leaf_dim = tree.data_dim
        legacy_spec = "" if leaf_dim == 4 else f"SH{(leaf_dim - 1) // 3}"
        self._data_format = DataFormat(legacy_spec)
    else:
        # The tree's own data_format is authoritative; keep no override.
        self._data_format = None

    self.tree._weight_accum = None
def expand(self, data_format, data_dim=None, remap=None):
    """
    Modify the size of the data stored at the octree leaves.

    A zero-filled tensor with the enlarged last dimension is allocated with
    the same dtype as the current data, and the old values are written into
    it with :code:`new_data[..., remap] = old_data`.

    :param data_format: str, new data format, RGBA | SH# | SG# | ASG#
    :param data_dim: data dimension; inferred from data_format by default,
                     only needed if data_format is RGBA.
    :param remap: mapping of old data to new data. For each leaf, we will do
                  :code:`new_data[remap] = old_data`. By default, this is
                  inferred automatically: if there was no previous format or
                  the new format is RGBA, every channel but the last keeps
                  its index and the last one moves to the new last index;
                  otherwise the old basis coefficients are copied to offsets
                  0, B and 2B (B = new basis_dim) and the last channel moves
                  to the new last index.
    :raises AssertionError: if data_format is not a str, if the resulting
                            data_dim is smaller than the current one, or if
                            the new format has basis_dim < 1 and no remap
                            was given.
    """
    assert isinstance(data_format, str), "Please specify valid data format"
    old_data_format = self.data_format
    old_data_dim = self.data_dim
    self.data_format = DataFormat(
        data_format) if data_format is not None else None
    self.data_dim = data_dim
    self._maybe_auto_data_dim()
    del data_dim
    if self.data_dim < old_data_dim:
        # Undo the metadata change so a rejected expand does not leave the
        # tree describing a layout that its data tensor does not have.
        self.data_format = old_data_format
        self.data_dim = old_data_dim
        raise AssertionError("Cannot expand to something smaller")

    if remap is None:
        # Index of the last channel in the new layout.
        sigma_arr = torch.tensor([self.data_dim - 1])
        if old_data_format is None or self.data_format.format == DataFormat.RGBA:
            remap = torch.cat([torch.arange(old_data_dim - 1), sigma_arr])
        else:
            assert self.data_format.basis_dim >= 1, \
                "Please manually specify data_dim for expand()"
            old_basis_dim = old_data_format.basis_dim
            if old_basis_dim < 0:
                old_basis_dim = 1
            # Each of the three coefficient groups is `shift` wide in the
            # new layout; old coefficients fill the front of each group.
            shift = self.data_format.basis_dim
            arr = torch.arange(old_basis_dim)
            remap = torch.cat(
                [arr, shift + arr, 2 * shift + arr, sigma_arr])

    may_oom = self.data.numel() > 8e9
    if may_oom:
        # Potential OOM prevention hack: do the copy on the CPU
        self.data = nn.Parameter(self.data.cpu())
    # Keep the existing dtype; the torch.zeros default would silently
    # downcast a float64 tree to float32.
    tmp_data = torch.zeros(
        (*self.data.data.shape[:-1], self.data_dim),
        dtype=self.data.dtype, device=self.data.device)
    tmp_data[..., remap] = self.data.data
    if may_oom:
        # Move the result back to the device the tree structure lives on.
        self.data = nn.Parameter(tmp_data.to(device=self.child.device))
    else:
        self.data = nn.Parameter(tmp_data)
def load(cls, path, device='cpu', dtype=torch.float32, map_location=None):
    """
    Load from npz file

    Floating point arrays (data, offset, invradius3) are converted to the
    requested dtype; the npz archive is closed before returning.

    :param path: npz path
    :param device: str device to put data
    :param dtype: torch.float32 (default) | torch.float64
    :param map_location: str DEPRECATED old name for device; overrides
                         device when given
    :return: the tree created by :code:`cls(dtype=dtype, device=device)`
             with its fields filled from the file
    :raises AssertionError: if dtype is not float32 or float64
    """
    if map_location is not None:
        warn('map_location has been renamed to device and may be removed')
        device = map_location
    assert dtype == torch.float32 or dtype == torch.float64, 'Unsupported dtype'
    # NumPy counterpart of the requested torch dtype, so the cast happens
    # once on the host before the transfer to the device.
    np_dtype = np.float64 if dtype == torch.float64 else np.float32

    def _float_tensor(arr):
        # Convert a stored array to the requested precision and device.
        return torch.from_numpy(arr.astype(np_dtype)).to(device)

    tree = cls(dtype=dtype, device=device)
    # NpzFile keeps the archive open until closed; every array is read
    # inside this block, so nothing references the file afterwards.
    with np.load(path) as z:
        tree.data_dim = int(z["data_dim"])
        tree.child = torch.from_numpy(z["child"]).to(device)
        tree.N = tree.child.shape[-1]
        tree.parent_depth = torch.from_numpy(z["parent_depth"]).to(device)
        tree._n_internal.fill_(z["n_internal"].item())
        if "invradius3" in z.files:
            tree.invradius = _float_tensor(z["invradius3"])
        else:
            # Older files store one scalar inverse radius.
            tree.invradius.fill_(z["invradius"].item())
        tree.offset = _float_tensor(z["offset"])
        tree.depth_limit = int(z["depth_limit"])
        tree.geom_resize_fact = float(z["geom_resize_fact"])
        tree.data.data = _float_tensor(z["data"])
        if 'n_free' in z.files:
            tree._n_free.fill_(z["n_free"].item())
        else:
            tree._n_free.zero_()
        tree.data_format = DataFormat(z['data_format'].item()) if \
            'data_format' in z.files else None
        tree.extra_data = torch.from_numpy(z['extra_data']).to(device) if \
            'extra_data' in z.files else None
    return tree
def load(cls, path, map_location='cpu'):
    """
    Load from npz file

    The npz archive is closed before returning.

    :param path: npz path
    :param map_location: device to put data
    :return: the tree created by :code:`cls(map_location=map_location)`
             with its fields filled from the file
    """
    tree = cls(map_location=map_location)
    # NpzFile keeps the archive open until closed; every array is read
    # inside this block, so nothing references the file afterwards.
    with np.load(path) as z:
        tree.data_dim = int(z["data_dim"])
        tree.child = torch.from_numpy(z["child"]).to(map_location)
        tree.N = tree.child.shape[-1]
        tree.parent_depth = torch.from_numpy(
            z["parent_depth"]).to(map_location)
        tree._n_internal.fill_(z["n_internal"].item())
        if "invradius3" in z.files:
            tree.invradius = torch.from_numpy(z["invradius3"].astype(
                np.float32)).to(map_location)
        else:
            # Older files store one scalar inverse radius.
            tree.invradius.fill_(z["invradius"].item())
        tree.offset = torch.from_numpy(z["offset"].astype(
            np.float32)).to(map_location)
        tree.depth_limit = int(z["depth_limit"])
        tree.geom_resize_fact = float(z["geom_resize_fact"])
        tree.data.data = torch.from_numpy(z["data"].astype(
            np.float32)).to(map_location)
        if 'n_free' in z.files:
            tree._n_free.fill_(z["n_free"].item())
        else:
            tree._n_free.zero_()
        tree.data_format = DataFormat(z['data_format'].item()) if \
            'data_format' in z.files else None
        tree.extra_data = torch.from_numpy(z['extra_data']).to(map_location) if \
            'extra_data' in z.files else None
    return tree
def __init__(self, N=2, data_dim=None, depth_limit=10,
             init_reserve=1, init_refine=0, geom_resize_fact=1.0,
             radius=0.5, center=[0.5, 0.5, 0.5],
             data_format="RGBA", extra_data=None,
             map_location="cpu"):
    """
    Create an N^3 tree.

    :param N: int, branching factor
    :param data_dim: int, number of values stored at each leaf
                     (NEW in 0.2.28: optional if a data_format other than
                     RGBA is given). If data_format = "RGBA" or empty,
                     this defaults to 4.
    :param depth_limit: int, maximum depth at which the tree stops
                        branching/refining. Note that the root is at
                        depth -1. Size N^[-10] leaves (1/1024 for octree)
                        for example are depth 9. :code:`max_depth` applies
                        to the same depth values.
    :param init_reserve: int, number of nodes to reserve initially
    :param init_refine: int, number of times to refine the whole tree at
                        construction; initial resolution will be
                        [N^(init_refine + 1)]^3 and initial max_depth will
                        be init_refine.
    :param geom_resize_fact: float, geometric resizing factor
    :param radius: float or list, 1/2 side length of cube
                   (possibly in each dim)
    :param center: list, center of space
    :param data_format: a string to indicate the data format.
                        RGBA | SH# | SG# | ASG#
    :param extra_data: extra data to include with tree
    :param map_location: str, device to put data
    """
    super().__init__()
    assert N >= 2
    assert depth_limit >= 0
    self.N: int = N

    if data_format is None:
        self.data_format = None
    else:
        self.data_format = DataFormat(data_format)
    self.data_dim: int = data_dim
    self._maybe_auto_data_dim()
    del data_dim

    # Reserve room for every node the initial refinement passes will add.
    init_reserve += sum((N ** level) ** 3
                        for level in range(1, init_refine + 1))

    node_grid = (init_reserve, N, N, N)
    leaf_values = torch.zeros(*node_grid, self.data_dim,
                              device=map_location)
    self.register_parameter("data", nn.Parameter(leaf_values))
    self.register_buffer(
        "child",
        torch.zeros(*node_grid, dtype=torch.int32, device=map_location))
    self.register_buffer(
        "parent_depth",
        torch.zeros(init_reserve, 2, dtype=torch.int32,
                    device=map_location))

    # Counters are kept as 0-dim tensor buffers.
    self.register_buffer("_n_internal",
                         torch.tensor(1, device=map_location))
    self.register_buffer("_n_free",
                         torch.tensor(0, device=map_location))

    # A scalar radius is applied to all three axes.
    if isinstance(radius, (float, int)):
        radius = [radius] * 3
    radius_t = torch.tensor(radius, dtype=torch.float32,
                            device=map_location)
    center_t = torch.tensor(center, dtype=torch.float32,
                            device=map_location)
    self.register_buffer("invradius", 0.5 / radius_t)
    self.register_buffer("offset", 0.5 * (1.0 - center_t / radius_t))

    self.depth_limit = depth_limit
    self.geom_resize_fact = geom_resize_fact

    if extra_data is None:
        self.extra_data = None
    else:
        assert isinstance(extra_data, torch.Tensor)
        self.register_buffer("extra_data",
                             extra_data.to(device=map_location))

    self._ver = 0
    self._invalidate()
    self._lock_tree_structure = False
    self._weight_accum = None

    self.refine(repeats=init_refine)
def __init__(
    self,
    tree,
    step_size: float = 1e-3,
    background_brightness: float = 1.0,
    ndc: NDCConfig = None,
    min_comp: int = 0,
    max_comp: int = -1,
    density_softplus: bool = False,
    rgb_padding: float = 0.0,
):
    """
    Construct volume renderer associated with given N^3 tree.

    The renderer traces rays with origins/dirs within the octree
    boundaries, detecting ray-voxel intersections. The color and density
    within each voxel is assumed constant, and no interpolation is
    performed.

    For each intersection point, it queries the tree, assuming the last
    data dimension is density (sigma) and the rest of the dimensions are
    color, formatted according to tree.data_format. It then applies
    SH/SG/ASG basis functions, if any, according to viewdirs. Sigmoid is
    applied to these colors to normalize them, and optionally a shifted
    softplus is applied to the density.

    :param tree: N3Tree instance for rendering
    :param step_size: float step size eps, added to each voxel aabb
                      intersection step
    :param background_brightness: float background brightness,
                                  1.0 = white
    :param ndc: NDCConfig, NDC coordinate configuration,
                namedtuple(width, height, focal).
                None = no NDC, use usual coordinates
    :param min_comp: minimum SH/SG component to render.
    :param max_comp: maximum SH/SG component to render, -1=last.
                     Set :code:`min_comp = max_comp` to render a
                     particular component. Default means all.
    :param density_softplus: if true, applies
                     :math:`\\log(1 + \\exp(sigma - 1))`.
                     **Mind the shift -1!** (from mip-NeRF).
                     Please note softplus will NOT be compatible with
                     volrend, please pre-apply it.
    :param rgb_padding: to avoid oversaturating the sigmoid, applies
                     :code:`* (1 + 2 * rgb_padding) - rgb_padding` to
                     colors after sigmoid (from mip-NeRF). Please note
                     the padding will NOT be compatible with volrend,
                     although most likely the effect is very small.
                     0.001 is a reasonable value to try.
    """
    super().__init__()
    self.tree = tree

    # Ray marching and compositing options.
    self.step_size = step_size
    self.background_brightness = background_brightness
    self.ndc_config = ndc

    # Basis component range and output post-processing options.
    self.min_comp = min_comp
    self.max_comp = max_comp
    self.density_softplus = density_softplus
    self.rgb_padding = rgb_padding

    has_format = isinstance(tree.data_format, DataFormat)
    if has_format:
        # The tree's own data_format is used; no override is stored.
        self._data_format = None
    else:
        warn(
            "Legacy N3Tree (pre 0.2.18) without data_format, auto-infering SH deg"
        )
        # Infer the format from the leaf size: 4 values means no basis
        # functions, anything else is treated as spherical harmonics.
        n_leaf_values = tree.data_dim
        if n_leaf_values == 4:
            inferred = DataFormat("")
        else:
            inferred = DataFormat(f"SH{(n_leaf_values - 1) // 3}")
        self._data_format = inferred

    self.tree._weight_accum = None