Example #1
0
    def test_simple_det(self):
        """
        Deterministic sampling must agree between the Python reference, the
        C++ implementation (multi- and single-threaded) and, when a GPU is
        available, the CUDA implementation, over a grid of bin counts,
        sample counts and batch shapes.
        """
        # Only exercise the CUDA kernel when a GPU exists, so that the CPU
        # comparisons still run (instead of erroring) on CPU-only machines.
        use_cuda = torch.cuda.is_available()
        device = torch.device("cuda:0")
        nthreads = torch.get_num_threads()
        for n_bins, n_samples, batch in product([7, 20], [2, 7, 31, 32, 33],
                                                [(), (1, 4), (31, ), (32, ),
                                                 (33, )]):
            weights = torch.rand(size=(batch + (n_bins, )))
            # Cumulative sum of positive values gives increasing bin edges.
            bins = torch.cumsum(torch.rand(size=(batch + (n_bins + 1, ))),
                                dim=-1)
            python = sample_pdf_python(bins, weights, n_samples, det=True)

            cpp = sample_pdf(bins, weights, n_samples, det=True)
            self.assertClose(cpp, python, atol=2e-3)

            # The thread count is process-global: restore it even when the
            # assertion fails, otherwise every later test runs single-threaded.
            torch.set_num_threads(1)
            try:
                cpp_singlethread = sample_pdf(bins, weights, n_samples,
                                              det=True)
                self.assertClose(cpp_singlethread, python, atol=2e-3)
            finally:
                torch.set_num_threads(nthreads)

            if use_cuda:
                cuda = sample_pdf(bins.to(device),
                                  weights.to(device),
                                  n_samples,
                                  det=True).cpu()
                self.assertClose(cuda, python, atol=2e-3)
Example #2
0
    def test_rand_nogap(self):
        """
        Random sampling that is effectively deterministic: all of the weight
        sits on the zero-width bin [10, 10], so every sample must be 10.
        """
        weights = torch.FloatTensor([0, 10, 0])
        bins = torch.FloatTensor([0, 10, 10, 25])
        n_samples = 8
        predicted = torch.full((n_samples, ), 10.0)
        python = sample_pdf_python(bins, weights, n_samples)
        self.assertClose(python, predicted)
        cpp = sample_pdf(bins, weights, n_samples)
        self.assertClose(cpp, predicted)

        # The CUDA kernel can only be checked when a GPU is present; the CPU
        # checks above still run on CPU-only machines.
        if torch.cuda.is_available():
            device = torch.device("cuda:0")
            cuda = sample_pdf(bins.to(device), weights.to(device),
                              n_samples).cpu()
            self.assertClose(cuda, predicted)
Example #3
0
    def test_rand_cpu(self):
        """
        With the same RNG seed, random (non-deterministic) sampling from the
        C++ implementation must match the Python reference.
        """
        batch_size, n_bins, n_samples = 9, 11, 17
        weights = torch.rand(size=(batch_size, n_bins))
        edges = torch.rand(size=(batch_size, n_bins + 1)).cumsum(dim=-1)

        def _seeded(sampler):
            # Reseed immediately before each call so both implementations
            # consume the identical random stream.
            torch.manual_seed(1)
            return sampler(edges, weights, n_samples)

        reference = _seeded(sample_pdf_python)
        result = _seeded(sample_pdf)

        self.assertClose(result, reference, atol=2e-3)
Example #4
0
    def forward(
        self,
        input_ray_bundle: RayBundle,
        ray_weights: torch.Tensor,
        **kwargs,
    ) -> RayBundle:
        """
        Importance-sample additional depths along each ray from the
        distribution defined by `ray_weights`.

        Args:
            input_ray_bundle: An instance of `RayBundle` specifying the
                source rays for sampling of the probability distribution.
            ray_weights: A tensor of shape
                `(..., input_ray_bundle.lengths.shape[-1])` with non-negative
                elements defining the probability distribution to sample
                ray points from.

        Returns:
            ray_bundle: A new `RayBundle` instance whose lengths contain
                `n_pts_per_ray` newly sampled points per ray, plus the input
                ray points when `self._add_input_samples` is set. For each
                ray, the lengths are sorted in increasing order.
        """
        z_vals = input_ray_bundle.lengths
        # All dimensions except the per-ray depth axis; the samples are
        # reshaped back to these so lengths of any rank are supported.
        batch_shape = z_vals.shape[:-1]

        # Sample deterministically unless stratified sampling is enabled for
        # the current mode (training vs. evaluation).
        det = not (
            (self._stratified and self.training)
            or (self._stratified_test and not self.training)
        )

        # Carry out the importance sampling; no gradients flow through it.
        with torch.no_grad():
            # The mid-points between consecutive depths act as bin edges.
            z_vals_mid = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1])
            z_samples = sample_pdf(
                z_vals_mid.reshape(-1, z_vals_mid.shape[-1]),
                # Drop the first and last weight so the weight count matches
                # the number of bins between the mid-point edges.
                ray_weights.reshape(-1, ray_weights.shape[-1])[..., 1:-1],
                self._n_pts_per_ray,
                det=det,
            ).reshape(*batch_shape, self._n_pts_per_ray)

        if self._add_input_samples:
            # Add the new samples to the input ones.
            z_vals = torch.cat((z_vals, z_samples), dim=-1)
        else:
            z_vals = z_samples
        # Resort by depth.
        z_vals, _ = torch.sort(z_vals, dim=-1)

        return RayBundle(
            origins=input_ray_bundle.origins,
            directions=input_ray_bundle.directions,
            lengths=z_vals,
            xys=input_ray_bundle.xys,
        )
Example #5
0
    def forward(
        self,
        input_ray_bundle: RayBundle,
        ray_weights: torch.Tensor,
        **kwargs,
    ) -> RayBundle:
        """
        Importance-sample additional depths along each ray from the
        distribution defined by `ray_weights`.

        Args:
            input_ray_bundle: An instance of `RayBundle` specifying the
                source rays for sampling of the probability distribution.
            ray_weights: A tensor of shape
                `(..., input_ray_bundle.lengths.shape[-1])` with non-negative
                elements defining the probability distribution to sample
                ray points from.

        Returns:
            ray_bundle: A new `RayBundle` instance whose lengths contain
                `n_pts_per_ray` newly sampled points per ray, plus the input
                ray points when `self.add_input_samples` is set. For each
                ray, the lengths are sorted.
        """
        z_vals = input_ray_bundle.lengths
        with torch.no_grad():
            # Mid-points between consecutive depths serve as the bin edges.
            z_vals_mid = torch.lerp(z_vals[..., 1:], z_vals[..., :-1], 0.5)
            z_samples = sample_pdf(
                z_vals_mid.reshape(-1, z_vals_mid.shape[-1]),
                # `reshape` (not `view`) so non-contiguous weights are
                # accepted; the first and last weights are dropped so the
                # weight count matches the number of bins.
                ray_weights.reshape(-1, ray_weights.shape[-1])[..., 1:-1],
                self.n_pts_per_ray,
                det=not self.random_sampling,
            ).reshape(*z_vals.shape[:-1], self.n_pts_per_ray)

        if self.add_input_samples:
            # Add the new samples to the input ones.
            z_vals = torch.cat((z_vals, z_samples), dim=-1)
        else:
            z_vals = z_samples
        # Resort by depth.
        z_vals, _ = torch.sort(z_vals, dim=-1)

        return RayBundle(
            origins=input_ray_bundle.origins,
            directions=input_ray_bundle.directions,
            lengths=z_vals,
            xys=input_ray_bundle.xys,
        )