예제 #1
0
    def __init__(self, tree,
            step_size : float=1e-3,
            background_brightness : float=1.0,
            ndc : NDCConfig=None,
            min_comp=0,
            max_comp=-1
        ):
        """
        Set up a volume renderer bound to the given N^3 tree.

        Besides storing the settings below, this resets
        :code:`tree._weight_accum` to None.

        :param tree: N3Tree instance for rendering
        :param step_size: float step size eps, added to each DDA step
        :param background_brightness: float background brightness, 1.0 = white
        :param ndc: NDCConfig, NDC coordinate configuration,
                    namedtuple(width, height, focal).
                    None = no NDC, use usual coordinates
        :param min_comp: minimum SH/SG component to render
        :param max_comp: maximum SH/SG component to render, -1=last

        """
        super().__init__()
        self.tree = tree

        # Plain rendering settings, stored as given.
        self.step_size = step_size
        self.background_brightness = background_brightness
        self.ndc_config = ndc
        self.min_comp = min_comp
        self.max_comp = max_comp

        if not isinstance(tree.data_format, DataFormat):
            # Legacy tree: no usable data_format, so derive one from the
            # leaf data dimension.
            warn("Legacy N3Tree (pre 0.2.18) without data_format, auto-infering SH deg")
            leaf_dim = tree.data_dim
            # A 4-channel leaf gets the empty format string; otherwise the
            # SH index is taken from the channel count with the last one
            # dropped.
            format_name = "" if leaf_dim == 4 else f"SH{(leaf_dim - 1) // 3}"
            self._data_format = DataFormat(format_name)
        else:
            # The tree carries its own DataFormat; no override is stored.
            self._data_format = None

        self.tree._weight_accum = None
예제 #2
0
    def expand(self, data_format, data_dim=None, remap=None):
        """
        Modify the size of the data stored at the octree leaves.

        A zero-filled leaf array with the new last dimension is allocated
        using the dtype and device of the current data, and the old values
        are scattered into it with :code:`new_data[..., remap] = old_data`.
        If anything fails before the new data is installed,
        :code:`self.data_format` and :code:`self.data_dim` are restored to
        their previous values so the tree metadata keeps matching its data.

        :param data_format: str new data format, RGBA | SH# | SG# | ASG#
        :param data_dim: data dimension; inferred from data_format by default
                only needed if data_format is RGBA.
        :param remap: mapping of old data to new data. For each leaf, we will do
                :code:`new_data[remap] = old_data`. By default,
                this will be inferred automatically: the last old channel
                goes to the last new channel, and the remaining channels
                keep their position within each of three basis groups
                (presumably R, G, B -- confirm against DataFormat).
        :raises AssertionError: if data_format is not a str, if the new
                data_dim is smaller than the current one, or if the basis
                size cannot be determined for the automatic remap.

        """
        assert isinstance(data_format, str), "Please specify valid data format"
        old_data_format = self.data_format
        old_data_dim = self.data_dim
        try:
            self.data_format = DataFormat(
                data_format) if data_format is not None else None
            self.data_dim = data_dim
            self._maybe_auto_data_dim()
            del data_dim
            assert self.data_dim >= old_data_dim, "Cannot expand to something smaller"

            if remap is None:
                # Index of the last channel in the new layout; the last old
                # channel is always mapped there.
                sigma_arr = torch.tensor([self.data_dim - 1])
                if old_data_format is None or self.data_format.format == DataFormat.RGBA:
                    # Leading channels keep their indices.
                    remap = torch.cat([torch.arange(old_data_dim - 1), sigma_arr])
                else:
                    assert self.data_format.basis_dim >= 1, \
                           "Please manually specify data_dim for expand()"
                    old_basis_dim = old_data_format.basis_dim
                    if old_basis_dim < 0:
                        # A negative basis_dim is treated as a single basis
                        # function per group.
                        old_basis_dim = 1
                    shift = self.data_format.basis_dim
                    arr = torch.arange(old_basis_dim)
                    # Three groups of `shift` slots each; the old
                    # coefficients fill the first old_basis_dim slots of
                    # every group.
                    remap = torch.cat(
                        [arr, shift + arr, 2 * shift + arr, sigma_arr])

            may_oom = self.data.numel() > 8e9
            if may_oom:
                # Potential OOM prevention hack: build the new array on the
                # CPU, then move it to the device of `child` afterwards.
                self.data = nn.Parameter(self.data.cpu())
            # Keep the dtype of the existing data; torch.zeros would
            # otherwise fall back to the default dtype and silently change
            # the precision of a non-float32 tree.
            tmp_data = torch.zeros((*self.data.data.shape[:-1], self.data_dim),
                                   dtype=self.data.dtype,
                                   device=self.data.device)
            tmp_data[..., remap] = self.data.data
            if may_oom:
                self.data = nn.Parameter(tmp_data.to(device=self.child.device))
            else:
                self.data = nn.Parameter(tmp_data)
        except Exception:
            # self.data still holds the old layout at this point, so put
            # the matching metadata back before propagating the error.
            self.data_format = old_data_format
            self.data_dim = old_data_dim
            raise
예제 #3
0
    def load(cls, path, device='cpu', dtype=torch.float32, map_location=None):
        """
        Load from npz file

        The archive is opened with a context manager so its file handle is
        released once everything has been read. Floating-point tensors
        (:code:`data`, :code:`offset`, :code:`invradius3`) are converted to
        the requested :code:`dtype`, matching the dtype the tree is
        constructed with.

        :param path: npz path
        :param device: str device to put data
        :param dtype: str torch.float32 (default) | torch.float64
        :param map_location: str DEPRECATED old name for device
        :return: the tree built by :code:`cls(dtype=dtype, device=device)`
                 with its fields filled from the archive
        :raises AssertionError: if dtype is not torch.float32 or torch.float64

        """
        if map_location is not None:
            warn('map_location has been renamed to device and may be removed')
            device = map_location
        assert dtype == torch.float32 or dtype == torch.float64, 'Unsupported dtype'
        # NumPy counterpart of the requested torch dtype; arrays are cast on
        # the NumPy side so any stored float width is accepted.
        np_dtype = np.float64 if dtype == torch.float64 else np.float32

        tree = cls(dtype=dtype, device=device)
        # NpzFile reads arrays lazily and keeps the archive open until it is
        # closed, so every access has to happen inside this block.
        with np.load(path) as z:
            tree.data_dim = int(z["data_dim"])
            tree.child = torch.from_numpy(z["child"]).to(device)
            # The branching factor is read off the child array's last axis.
            tree.N = tree.child.shape[-1]
            tree.parent_depth = torch.from_numpy(z["parent_depth"]).to(device)
            tree._n_internal.fill_(z["n_internal"].item())
            if "invradius3" in z.files:
                tree.invradius = torch.from_numpy(z["invradius3"].astype(
                    np_dtype)).to(device)
            else:
                # Files without "invradius3" store a single scalar instead.
                tree.invradius.fill_(z["invradius"].item())
            tree.offset = torch.from_numpy(z["offset"].astype(
                np_dtype)).to(device)
            tree.depth_limit = int(z["depth_limit"])
            tree.geom_resize_fact = float(z["geom_resize_fact"])
            tree.data.data = torch.from_numpy(z["data"].astype(
                np_dtype)).to(device)
            if 'n_free' in z.files:
                tree._n_free.fill_(z["n_free"].item())
            else:
                tree._n_free.zero_()
            # Optional entries: absent keys leave the attribute as None.
            tree.data_format = DataFormat(z['data_format'].item()) if \
                    'data_format' in z.files else None
            tree.extra_data = torch.from_numpy(z['extra_data']).to(device) if \
                              'extra_data' in z.files else None
        return tree
예제 #4
0
    def load(cls, path, map_location='cpu'):
        """
        Load from npz file

        The archive is opened with a context manager so its file handle is
        released once everything has been read, including when a required
        key is missing.

        :param path: npz path
        :param map_location: device to put data
        :return: the tree built by :code:`cls(map_location=map_location)`
                 with its fields filled from the archive

        """
        tree = cls(map_location=map_location)
        # NpzFile reads arrays lazily and keeps the archive open until it is
        # closed, so every access has to happen inside this block.
        with np.load(path) as z:
            tree.data_dim = int(z["data_dim"])
            tree.child = torch.from_numpy(z["child"]).to(map_location)
            # The branching factor is read off the child array's last axis.
            tree.N = tree.child.shape[-1]
            tree.parent_depth = torch.from_numpy(
                z["parent_depth"]).to(map_location)
            tree._n_internal.fill_(z["n_internal"].item())
            if "invradius3" in z.files:
                tree.invradius = torch.from_numpy(z["invradius3"].astype(
                    np.float32)).to(map_location)
            else:
                # Files without "invradius3" store a single scalar instead.
                tree.invradius.fill_(z["invradius"].item())
            tree.offset = torch.from_numpy(z["offset"].astype(
                np.float32)).to(map_location)
            tree.depth_limit = int(z["depth_limit"])
            tree.geom_resize_fact = float(z["geom_resize_fact"])
            tree.data.data = torch.from_numpy(z["data"].astype(
                np.float32)).to(map_location)
            if 'n_free' in z.files:
                tree._n_free.fill_(z["n_free"].item())
            else:
                tree._n_free.zero_()
            # Optional entries: absent keys leave the attribute as None.
            tree.data_format = DataFormat(z['data_format'].item()) if \
                    'data_format' in z.files else None
            tree.extra_data = torch.from_numpy(z['extra_data']).to(map_location) if \
                              'extra_data' in z.files else None
        return tree
예제 #5
0
    def __init__(self,
                 N=2,
                 data_dim=None,
                 depth_limit=10,
                 init_reserve=1,
                 init_refine=0,
                 geom_resize_fact=1.0,
                 radius=0.5,
                 center=[0.5, 0.5, 0.5],
                 data_format="RGBA",
                 extra_data=None,
                 map_location="cpu"):
        """
        Build an N^3 tree and, if requested, refine it uniformly.

        :param N: int branching factor N (must be at least 2)
        :param data_dim: int size of data stored at each leaf (NEW in 0.2.28:
                         optional if data_format other than RGBA is given).
                         If data_format = "RGBA" or empty, this defaults to 4.
        :param depth_limit: int maximum depth of tree to stop
                            branching/refining (must be non-negative).
                            Note that the root is at depth -1.
                            Size N^[-10] leaves (1/1024 for octree) for example
                            are depth 9. :code:`max_depth` applies to the same
                            depth values.
        :param init_reserve: int amount of nodes to reserve initially; when
                             init_refine is positive, room for every
                             refinement level is added on top
        :param init_refine: int number of times to refine entire tree initially
                            initial resolution will be [N^(init_refine + 1)]^3.
                            initial max_depth will be init_refine.
        :param geom_resize_fact: float geometric resizing factor
        :param radius: float or list, 1/2 side length of cube (possibly in
                       each dim); a scalar is repeated for all three axes
        :param center: list center of space
        :param data_format: a string to indicate the data format.
                            RGBA | SH# | SG# | ASG#
        :param extra_data: extra data to include with tree (a torch.Tensor,
                           or None)
        :param map_location: str device to put data

        """
        super().__init__()
        assert N >= 2
        assert depth_limit >= 0
        self.N: int = N

        self.data_format = None if data_format is None else DataFormat(data_format)
        self.data_dim: int = data_dim
        # Fills in self.data_dim when it was not given explicitly.
        self._maybe_auto_data_dim()
        # From here on only self.data_dim is meaningful.
        del data_dim

        if init_refine > 0:
            # Reserve space for all nodes created by the initial refinement.
            init_reserve += sum(
                (N**level)**3 for level in range(1, init_refine + 1))

        # Every per-node array shares this leading shape.
        node_grid = (init_reserve, N, N, N)

        leaf_values = torch.zeros(*node_grid, self.data_dim,
                                  device=map_location)
        self.register_parameter("data", nn.Parameter(leaf_values))

        child_index = torch.zeros(*node_grid, dtype=torch.int32,
                                  device=map_location)
        self.register_buffer("child", child_index)

        parent_info = torch.zeros(init_reserve, 2, dtype=torch.int32,
                                  device=map_location)
        self.register_buffer("parent_depth", parent_info)

        # Node counters are kept as 0-d tensor buffers.
        self.register_buffer("_n_internal",
                             torch.tensor(1, device=map_location))
        self.register_buffer("_n_free", torch.tensor(0, device=map_location))

        if isinstance(radius, (float, int)):
            radius = [radius] * 3
        half_extent = torch.tensor(radius, dtype=torch.float32,
                                   device=map_location)
        midpoint = torch.tensor(center, dtype=torch.float32,
                                device=map_location)

        # x * invradius + offset == (x - center) / (2 * radius) + 0.5
        self.register_buffer("invradius", 0.5 / half_extent)
        self.register_buffer("offset", 0.5 * (1.0 - midpoint / half_extent))

        self.depth_limit = depth_limit
        self.geom_resize_fact = geom_resize_fact

        if extra_data is None:
            self.extra_data = None
        else:
            assert isinstance(extra_data, torch.Tensor)
            self.register_buffer("extra_data",
                                 extra_data.to(device=map_location))

        # Bookkeeping state, initialised in the same order as before.
        self._ver = 0
        self._invalidate()
        self._lock_tree_structure = False
        self._weight_accum = None

        self.refine(repeats=init_refine)
예제 #6
0
파일: renderer.py 프로젝트: sarafridov/svox
    def __init__(
        self,
        tree,
        step_size: float = 1e-3,
        background_brightness: float = 1.0,
        ndc: NDCConfig = None,
        min_comp: int = 0,
        max_comp: int = -1,
        density_softplus: bool = False,
        rgb_padding: float = 0.0,
    ):
        """
        Set up a volume renderer bound to the given N^3 tree.

        The renderer traces rays with origins/dirs within the octree boundaries,
        detecting ray-voxel intersections. The color and density within
        each voxel are assumed constant, and no interpolation is performed.

        For each intersection point, it queries the tree, assuming the last
        data dimension is density (sigma) and the rest of the dimensions are
        color, formatted according to tree.data_format.
        It then applies SH/SG/ASG basis functions, if any, according to viewdirs.
        Sigmoid will be applied to these colors to normalize them,
        and optionally a shifted softplus is applied to the density.

        Besides storing the settings below, the constructor resets
        :code:`tree._weight_accum` to None.

        :param tree: N3Tree instance for rendering
        :param step_size: float step size eps, added to each voxel aabb
                          intersection step
        :param background_brightness: float background brightness, 1.0 = white
        :param ndc: NDCConfig, NDC coordinate configuration,
                    namedtuple(width, height, focal).
                    None = no NDC, use usual coordinates
        :param min_comp: minimum SH/SG component to render.
        :param max_comp: maximum SH/SG component to render, -1=last.
                         Set :code:`min_comp = max_comp` to render a particular
                         component. Default means all.
        :param density_softplus: if true, applies :math:`\\log(1 + \\exp(sigma - 1))`.
                                 **Mind the shift -1!** (from mip-NeRF).
                                 Please note softplus will NOT be compatible
                                 with volrend, please pre-apply it.
        :param rgb_padding: to avoid oversaturating the sigmoid,
                        applies :code:`* (1 + 2 * rgb_padding) - rgb_padding` to
                        colors after sigmoid (from mip-NeRF).
                        Please note the padding will NOT be compatible with volrend,
                        although most likely the effect is very small.
                        0.001 is a reasonable value to try.

        """
        super().__init__()
        self.tree = tree

        # Plain rendering settings, stored as given.
        self.step_size = step_size
        self.background_brightness = background_brightness
        self.ndc_config = ndc
        self.min_comp = min_comp
        self.max_comp = max_comp
        self.density_softplus = density_softplus
        self.rgb_padding = rgb_padding

        if not isinstance(tree.data_format, DataFormat):
            # Legacy tree: no usable data_format, so derive one from the
            # leaf data dimension.
            warn("Legacy N3Tree (pre 0.2.18) without data_format, auto-infering SH deg")
            leaf_dim = tree.data_dim
            # A 4-channel leaf gets the empty format string; otherwise the
            # SH index is taken from the channel count with the last one
            # dropped.
            format_name = "" if leaf_dim == 4 else f"SH{(leaf_dim - 1) // 3}"
            self._data_format = DataFormat(format_name)
        else:
            # The tree carries its own DataFormat; no override is stored.
            self._data_format = None

        self.tree._weight_accum = None