Пример #1
0
def _phi(r, order):
    """
    Evaluate a radial kernel term of the given order.

    The input is clamped from below at 1e-6 before use, so the logarithm and
    the fractional power never see zero or negative values.

    :param r: Input values. Presumably squared distances - confirm against the caller.
    :type r: array
    :param order: Kernel order. Even orders include a log factor, odd orders do not.
    :type order: int
    :return: ``0.5 * r**(order/2) * log(r)`` for even orders, ``r**(order/2)`` otherwise.
    """
    clamped = _ivy.maximum(r, _ivy.array([1e-6], 'float32'))
    powered = clamped**(0.5 * order)
    if order % 2 != 0:
        return powered
    return 0.5 * powered * _ivy.log(clamped)
Пример #2
0
def train_step(loss_fn_in, optimizer, ntm, total_seq, target_seq, seq_len, mw,
               vw, step, max_grad_norm):
    """
    Run one optimisation step with global-norm gradient clipping.

    :param loss_fn_in: Callable invoked as ``loss_fn_in(v, total_seq, target_seq, seq_len)``.
    :param optimizer: Object exposing ``step(variables, gradients)``, which returns the updated variables.
    :param ntm: Model whose ``v`` attribute holds the variables; it is overwritten with the updated ones.
    :param total_seq: Forwarded unchanged to ``loss_fn_in``.
    :param target_seq: Forwarded unchanged to ``loss_fn_in``.
    :param seq_len: Forwarded unchanged to ``loss_fn_in``.
    :param mw: Not used by this function.
    :param vw: Not used by this function.
    :param step: Not used by this function.
    :param max_grad_norm: Upper bound applied to the global gradient norm.
    :return: The loss and the extra values returned by ``ivy.execute_with_gradients``.
    """

    def _loss_of(v_):
        return loss_fn_in(v_, total_seq, target_seq, seq_len)

    # forward and backward pass
    loss, dldv, pred_vals = ivy.execute_with_gradients(_loss_of, ntm.v)

    # L2 norm taken jointly over every gradient leaf
    per_leaf_sq_sums = [
        ivy.reduce_sum(grad**2) for grad in dldv.to_flat_list()
    ]
    global_norm = ivy.reduce_sum(ivy.stack(per_leaf_sq_sums, 0))**0.5

    def _clip(x, _):
        # scale factor is 1 while the norm is within bounds, < 1 otherwise
        return x * max_grad_norm / ivy.maximum(global_norm, max_grad_norm)

    dldv = dldv.map(_clip)

    # apply the clipped gradients
    ntm.v = optimizer.step(ntm.v, dldv)
    return loss, pred_vals
Пример #3
0
 def _group_tensor_into_windowed_tensor_simple(self, x, seq_info):
     """
     Gather sliding windows of ``x`` along its leading axis.

     With a fixed sequence length the gather indices stored on the instance
     are used. Otherwise they are built from ``seq_info.length[0]`` after
     ``seq_info`` has been passed through ``_update_seq_info_for_window``.

     :param x: Array to window; the windows are taken over axis 0.
     :param seq_info: Sequence info object; its ``length[0]`` is read in the variable-length case.
     :return: Array of shape ``(num_windows, window_size) + x.shape[1:]``.
     """
     seq_info = self._update_seq_info_for_window(seq_info)
     win_size = self._window_size
     tail_shape = x.shape[1:]

     if self._fixed_sequence_length:
         # indices were prepared ahead of time on the instance
         picked = ivy.gather_nd(x, ivy.array(self._gather_idxs))
         return ivy.reshape(picked,
                            (self._windows_per_seq, win_size) + tail_shape)

     # always at least one window, even if the sequence is shorter than it
     n_windows = int(
         ivy.to_numpy(ivy.maximum(seq_info.length[0] - win_size + 1, 1)))

     # row i of the grid is window i's indices: start offset i plus 0..win_size-1
     # (assuming ivy.arange takes (stop, start, step) - confirm)
     window_starts = ivy.expand_dims(ivy.arange(n_windows, 0, 1), -1)
     offsets = ivy.reshape(ivy.arange(win_size, 0, 1), (1, win_size))
     idx_grid = ivy.tile(offsets, (n_windows, 1)) + window_starts

     flat_idxs = ivy.reshape(idx_grid, (win_size * n_windows, 1))
     return ivy.reshape(ivy.gather_nd(x, flat_idxs),
                        (n_windows, win_size) + tail_shape)
Пример #4
0
def compute_length(query_vals):
    """
    Mean path length of the poly-lines stored in ``query_vals``.

    Consecutive entries along axis 1 are differenced, the squared differences
    are clamped from below at 1e-12 (keeps the square root away from exactly
    zero), and the resulting segment lengths are summed along axis 1 and then
    averaged over everything that remains.

    :param query_vals: Array indexed as ``[item, point, coordinate]`` by the slicing used here.
    :type query_vals: array
    :return: Scalar mean of the per-item summed segment lengths.
    """
    deltas = query_vals[:, 1:] - query_vals[:, 0:-1]
    clamped_sq = ivy.maximum(deltas**2, 1e-12)
    segment_lengths = ivy.reduce_sum(clamped_sq, -1)**0.5
    per_item_length = ivy.reduce_sum(segment_lengths, 1)
    return ivy.reduce_mean(per_item_length)
Пример #5
0
def cuboid_signed_distances(cuboid_ext_mats,
                            cuboid_dims,
                            query_positions,
                            batch_shape=None):
    """
    Signed distance from each query point to the nearest cuboid surface.

    `[reference] <https://www.iquilezles.org/www/articles/distfunctions/distfunctions.htm>`_

    :param cuboid_ext_mats: Extrinsic matrices of the cuboids *[batch_shape,num_cuboids,3,4]*
    :type cuboid_ext_mats: array
    :param cuboid_dims: Dimensions of the cuboids, in the order x, y, z *[batch_shape,num_cuboids,3]*
    :type cuboid_dims: array
    :param query_positions: Points for which to query the signed distances *[batch_shape,num_points,3]*
    :type query_positions: array
    :param batch_shape: Shape of batch. Taken from the leading dims of cuboid_ext_mats if None.
    :type batch_shape: sequence of ints, optional
    :return: The distances of the query points from the closest cuboid surface *[batch_shape,num_points,1]*
    """

    if batch_shape is None:
        batch_shape = cuboid_ext_mats.shape[:-3]

    # shape bookkeeping
    batch_shape = list(batch_shape)
    n_batch = len(batch_shape)
    batch_axes = list(range(n_batch))
    n_cuboids = cuboid_ext_mats.shape[-3]
    n_points = query_positions.shape[-2]

    # BS x 3 x NP
    points_t = _ivy.transpose(query_positions,
                              batch_axes + [n_batch + 1, n_batch])

    # BS x 4 x NP, homogeneous row of ones appended
    points_t_homo = _ivy.concatenate(
        (points_t, _ivy.ones_like(points_t[..., 0:1, :])), -2)

    # BS x NCx3 x 4, all cuboid matrices stacked so one matmul handles them
    ext_mats_stacked = _ivy.reshape(cuboid_ext_mats, batch_shape + [-1, 4])

    # BS x NCx3 x NP
    local_t_stacked = _ivy.matmul(ext_mats_stacked, points_t_homo)

    # BS x NC x 3 x NP
    local_t = _ivy.reshape(local_t_stacked,
                           batch_shape + [n_cuboids, 3, n_points])

    # BS x NC x NP x 3, points expressed in each cuboid's frame
    local_points = _ivy.transpose(
        local_t, batch_axes + [n_batch, n_batch + 2, n_batch + 1])

    # per-axis offset from the cuboid faces (negative when inside)
    q = _ivy.abs(local_points) - _ivy.expand_dims(cuboid_dims / 2, -2)

    # BS x NC x NP x 1
    # outside part: length of the (clamped) positive components
    outside_dist = _ivy.reduce_sum(_ivy.maximum(q, 1e-12)**2, -1,
                                   keepdims=True)**0.5
    # inside part: largest component, capped at zero
    inside_dist = _ivy.minimum(_ivy.reduce_max(q, -1, keepdims=True), 0.)

    # BS x NP x 1, closest cuboid wins
    return _ivy.reduce_min(outside_dist + inside_dist, -3)
Пример #6
0
def quantize_to_image(pixel_coords,
                      final_image_dims,
                      feat=None,
                      feat_prior=None,
                      with_db=False,
                      pixel_coords_var=1e-3,
                      feat_var=1e-3,
                      pixel_coords_prior_var=1e12,
                      feat_prior_var=1e12,
                      var_threshold=(1e-3, 1e12),
                      uniform_pixel_coords=None,
                      batch_shape=None,
                      dev_str=None):
    r"""
    Quantize pixel co-ordinates with d feature channels (for depth, rgb, normals etc.), from
    images :math:`\mathbf{X}\in\mathbb{R}^{input\_images\_shape×(2+d)}`, which may have been reprojected from a host of
    different cameras (leading to non-integer pixel values), to a new quantized pixel co-ordinate image with the same
    feature channels :math:`\mathbf{X}\in\mathbb{R}^{h×w×(2+d)}`, and with integer pixel co-ordinates.
    Duplicates during the quantization are either probabilistically fused based on variance, or the minimum depth is
    chosen when using depth buffer mode.

    The work is done in these stages: broadcast scalar variances/priors to arrays, round the pixel co-ordinates,
    discard entries that are out of bounds or whose variance is not below the upper threshold, scatter the remaining
    means/variances/counter into the output image, then fill pixels that received nothing with the prior.

    :param pixel_coords: Pixel co-ordinates *[batch_shape,input_size,2]*
    :type pixel_coords: array
    :param final_image_dims: Image dimensions of the final image, indexed as (height, width) by the bounds check.
    :type final_image_dims: sequence of ints
    :param feat: Features (i.e. depth, rgb, encoded), default is None. *[batch_shape,input_size,d]*
                 In depth buffer mode channel 0 of feat is the one treated as depth.
    :type feat: array, optional
    :param feat_prior: Prior feature image mean, default is None. *[batch_shape,input_size,d]*
    :type feat_prior: array or float to fill with
    :param with_db: Whether or not to use depth buffer in rendering, default is false
    :type with_db: bool, optional
    :param pixel_coords_var: Pixel coords variance *[batch_shape,input_size,2]*
    :type pixel_coords_var: array or float to fill with
    :param feat_var: Feature variance *[batch_shape,input_size,d]*
    :type feat_var: array or float to fill with
    :param pixel_coords_prior_var: Pixel coords prior variance *[batch_shape,h,w,2]*
    :type pixel_coords_prior_var: array or float to fill with
    :param feat_prior_var: Features prior variance *[batch_shape,h,w,3]*
    :type feat_prior_var: array or float to fill with
    :param var_threshold: Variance threshold, for projecting valid coords and clipping *[batch_shape,2+d,2]*
                          Index 0 of the last axis is the lower clip, index 1 the upper validity bound.
    :type var_threshold: array or sequence of floats to fill with
    :param uniform_pixel_coords: Homogeneous uniform (integer) pixel co-ordinate images, inferred from final_image_dims
                                    if None *[batch_shape,h,w,3]*
    :type uniform_pixel_coords: array, optional
    :param batch_shape: Shape of batch. Inferred from pixel_coords if None.
    :type batch_shape: sequence of ints, optional
    :param dev_str: device on which to create the array 'cuda:0', 'cuda:1', 'cpu' etc. Same as pixel_coords if None.
    :type dev_str: str, optional
    :return: Always a 3-tuple: quantized pixel co-ordinates image with d feature channels (for depth, rgb, normals etc.)
             *[batch_shape,h,w,2+d]*, the quantized variance *[batch_shape,h,w,2+d]*, and the scatter counter image
             *[batch_shape,h,w,1]*
    """

    # ToDo: make variance fully optional. If not specified,
    #  then do not compute and scatter during function call for better efficiency.
    # config
    if batch_shape is None:
        batch_shape = pixel_coords.shape[:-2]

    if dev_str is None:
        dev_str = _ivy.dev_str(pixel_coords)

    # NOTE(review): with feat=None, d becomes 0 but feat is still passed to
    # ones_like/concatenate further down; whether the backend tolerates None
    # there is not visible from this function - confirm before relying on it.
    if feat is None:
        d = 0
    else:
        d = feat.shape[-1]
    # MIN_DEPTH_DIFF is a module-level constant defined outside this function
    min_depth_diff = _ivy.array([MIN_DEPTH_DIFF], dev_str=dev_str)
    # scatter reduction: depth buffer keeps the minimum, fusion mode accumulates
    red = 'min' if with_db else 'sum'

    # shapes as list
    batch_shape = list(batch_shape)
    final_image_dims = list(final_image_dims)
    num_batch_dims = len(batch_shape)

    # variance threshold
    # a (lower, upper) pair is broadcast to BS x 1 x (2+D) x 2
    if isinstance(var_threshold, tuple) or isinstance(var_threshold, list):
        ones = _ivy.ones(batch_shape + [1, 2 + d, 1])
        var_threshold = _ivy.concatenate(
            (ones * var_threshold[0], ones * var_threshold[1]), -1)
    else:
        var_threshold = _ivy.reshape(var_threshold,
                                     batch_shape + [1, 2 + d, 2])

    # uniform pixel coords
    if uniform_pixel_coords is None:
        uniform_pixel_coords =\
            _ivy_svg.create_uniform_pixel_coords_image(final_image_dims, batch_shape, dev_str=dev_str)
    # drop the homogeneous channel, keeping only the first two
    uniform_pixel_coords = uniform_pixel_coords[..., 0:2]

    # Extract Values #

    # float arguments are expanded into constant-filled arrays; anything else
    # (including an int, which fails the isinstance(float) check) is used as-is
    feat_prior = _ivy.ones_like(feat) * feat_prior if isinstance(
        feat_prior, float) else feat_prior
    pixel_coords_var = _ivy.ones_like(pixel_coords) * pixel_coords_var\
        if isinstance(pixel_coords_var, float) else pixel_coords_var
    feat_var = _ivy.ones_like(feat) * feat_var if isinstance(
        feat_var, float) else feat_var
    pixel_coords_prior_var = _ivy.ones(batch_shape + final_image_dims + [2]) * pixel_coords_prior_var\
        if isinstance(pixel_coords_prior_var, float) else pixel_coords_prior_var
    feat_prior_var = _ivy.ones(batch_shape + final_image_dims + [d]) * feat_prior_var\
        if isinstance(feat_prior_var, float) else feat_prior_var

    # Quantize #

    # BS x N x 2
    quantized_pixel_coords = _ivy.reshape(
        _ivy.cast(_ivy.round(pixel_coords), 'int32'), batch_shape + [-1, 2])

    # Combine #

    # BS x N x (2+D)
    pc_n_feat = _ivy.reshape(_ivy.concatenate((pixel_coords, feat), -1),
                             batch_shape + [-1, 2 + d])
    pc_n_feat_var = _ivy.reshape(
        _ivy.concatenate((pixel_coords_var, feat_var), -1),
        batch_shape + [-1, 2 + d])

    # BS x H x W x (2+D)
    prior = _ivy.concatenate((uniform_pixel_coords, feat_prior), -1)
    prior_var = _ivy.concatenate((pixel_coords_prior_var, feat_prior_var), -1)

    # Validity Mask #

    # BS x N x 1
    # an entry is valid only if every one of its 2+D variances is below the upper threshold
    var_validity_mask = \
        _ivy.reduce_sum(_ivy.cast(pc_n_feat_var < var_threshold[..., 1], 'int32'), -1, keepdims=True) == 2+d
    # channel 0 is checked against dims[1] and channel 1 against dims[0]
    bounds_validity_mask = _ivy.logical_and(
        _ivy.logical_and(quantized_pixel_coords[..., 0:1] >= 0,
                         quantized_pixel_coords[..., 1:2] >= 0),
        _ivy.logical_and(
            quantized_pixel_coords[..., 0:1] <= final_image_dims[1] - 1,
            quantized_pixel_coords[..., 1:2] <= final_image_dims[0] - 1))
    validity_mask = _ivy.logical_and(var_validity_mask, bounds_validity_mask)

    # num_valid_indices x len(BS)+2
    validity_indices = _ivy.reshape(
        _ivy.cast(_ivy.indices_where(validity_mask), 'int32'),
        [-1, num_batch_dims + 2])
    num_valid_indices = validity_indices.shape[-2]

    # nothing to scatter: hand back the priors untouched
    # NOTE(review): the counter returned here is shaped like feat[..., 0:1],
    # not built from final_image_dims as in the main path - confirm this is
    # intended when feat is not already image-shaped.
    if num_valid_indices == 0:
        return _ivy.concatenate((uniform_pixel_coords[..., 0:2], feat_prior), -1), \
               _ivy.concatenate((pixel_coords_prior_var, feat_prior_var), -1),\
               _ivy.zeros_like(feat[..., 0:1], dev_str=dev_str)

    # Depth Based Scaling #

    # placeholders, only populated in depth buffer mode and needed again
    # after the scatter to undo the scaling
    mean_depth_min = None
    mean_depth_range = None
    pc_n_feat_wo_depth_range = None
    pc_n_feat_wo_depth_min = None
    var_vals_range = None
    var_vals_min = None

    if with_db:

        # In this mode the scatter uses a 'min' reduction per channel. Every
        # non-depth channel is normalised and then offset by a scaled depth
        # term, so that the per-channel minimum follows the depth ordering;
        # the offset is subtracted again after the scatter.

        # BS x N x 1
        mean_depth = pc_n_feat[..., 2:3]

        # BS x 1 x 1
        mean_depth_min = _ivy.reduce_min(mean_depth, -2, keepdims=True)
        mean_depth_max = _ivy.reduce_max(mean_depth, -2, keepdims=True)
        mean_depth_range = mean_depth_max - mean_depth_min

        # BS x N x 1
        # MIN_DENOMINATOR is a module-level constant guarding against division by zero
        scaled_depth = (mean_depth - mean_depth_min) / (
            mean_depth_range * min_depth_diff + MIN_DENOMINATOR)

        if d == 1:

            # BS x 1 x 1+D
            # NOTE(review): these are created with a trailing dimension of 0
            # and pc_n_feat is left unscaled in this branch - confirm intended.
            pc_n_feat_wo_depth_min = _ivy.zeros(batch_shape + [1, 0],
                                                dev_str=dev_str)
            pc_n_feat_wo_depth_range = _ivy.ones(batch_shape + [1, 0],
                                                 dev_str=dev_str)

        else:
            # feat without depth

            # BS x N x 1+D
            pc_n_feat_wo_depth = _ivy.concatenate(
                (pc_n_feat[..., 0:2], pc_n_feat[..., 3:]), -1)

            # find the min and max of each value

            # BS x 1 x 1+D
            # the +1 / -1 margins widen the range beyond the observed values
            pc_n_feat_wo_depth_max = _ivy.reduce_max(
                pc_n_feat_wo_depth, -2, keepdims=True) + 1
            pc_n_feat_wo_depth_min = _ivy.reduce_min(
                pc_n_feat_wo_depth, -2, keepdims=True) - 1
            pc_n_feat_wo_depth_range = pc_n_feat_wo_depth_max - pc_n_feat_wo_depth_min

            # BS x N x 1+D
            normed_pc_n_feat_wo_depth = (pc_n_feat_wo_depth - pc_n_feat_wo_depth_min) / \
                                        (pc_n_feat_wo_depth_range + MIN_DENOMINATOR)

            # combine with scaled depth

            # BS x N x 1+D
            pc_n_feat_wo_depth_scaled = normed_pc_n_feat_wo_depth + scaled_depth

            # BS x N x (2+D)
            # depth itself is re-inserted unscaled at channel 2
            pc_n_feat = _ivy.concatenate(
                (pc_n_feat_wo_depth_scaled[..., 0:2], mean_depth,
                 pc_n_feat_wo_depth_scaled[..., 2:]), -1)

        # scale variance

        # BS x 1 x (2+D)
        var_vals_max = _ivy.reduce_max(pc_n_feat_var, -2, keepdims=True) + 1
        var_vals_min = _ivy.reduce_min(pc_n_feat_var, -2, keepdims=True) - 1
        var_vals_range = var_vals_max - var_vals_min

        # BS x N x (2+D)
        normed_var_vals = (pc_n_feat_var - var_vals_min) / (var_vals_range +
                                                            MIN_DENOMINATOR)
        pc_n_feat_var = normed_var_vals + scaled_depth

        # ready for later reversal with full image dimensions

        # BS x 1 x 1 x D
        var_vals_min = _ivy.expand_dims(var_vals_min, -2)
        var_vals_range = _ivy.expand_dims(var_vals_range, -2)

    # Validity Pruning #

    # only the batch and N components of each index are used for the gather;
    # the trailing component indexes the size-1 mask channel

    # num_valid_indices x (2+D)
    pc_n_feat = _ivy.gather_nd(pc_n_feat,
                               validity_indices[..., 0:num_batch_dims + 1])
    pc_n_feat_var = _ivy.gather_nd(pc_n_feat_var,
                                   validity_indices[..., 0:num_batch_dims + 1])

    # num_valid_indices x 2
    quantized_pixel_coords = _ivy.gather_nd(
        quantized_pixel_coords, validity_indices[..., 0:num_batch_dims + 1])

    if with_db:
        means_to_scatter = pc_n_feat
        vars_to_scatter = pc_n_feat_var
    else:
        # fusion mode scatters reciprocal variances and variance-weighted
        # means, so that summing them per pixel gives the fused estimate
        # num_valid_indices x (2+D)
        vars_to_scatter = 1 / (pc_n_feat_var + MIN_DENOMINATOR)
        means_to_scatter = pc_n_feat * vars_to_scatter

    # Scatter #

    # num_valid_indices x 1
    counter = _ivy.ones_like(pc_n_feat[..., 0:1], dev_str=dev_str)
    # under the 'min' reduction a hit pixel ends up with -1 in the counter channel
    if with_db:
        counter *= -1

    # num_valid_indices x 2(2+D)+1
    values_to_scatter = _ivy.concatenate(
        (means_to_scatter, vars_to_scatter, counter), -1)

    # num_valid_indices x (num_batch_dims + 2)
    # the two co-ordinate channels are swapped to match the image index order
    all_indices = _ivy.flip(quantized_pixel_coords, -1)
    if num_batch_dims > 0:
        all_indices = _ivy.concatenate(
            (validity_indices[..., :-2], all_indices), -1)

    # BS x H x W x (2(2+D) + 1)
    # the 'mxnd' backend is given 'replace' instead of the chosen reduction
    # (presumably because it lacks the others - confirm)
    quantized_img = _ivy.scatter_nd(
        _ivy.reshape(all_indices, [-1, num_batch_dims + 2]),
        _ivy.reshape(values_to_scatter, [-1, 2 * (2 + d) + 1]),
        batch_shape + final_image_dims + [2 * (2 + d) + 1],
        reduction='replace' if _ivy.backend == 'mxnd' else red)

    # BS x H x W x 1
    quantized_counter = quantized_img[..., -1:]
    # pixels that received no scattered value fall back to the prior below
    if with_db:
        invalidity_mask = quantized_counter != -1
    else:
        invalidity_mask = quantized_counter == 0

    if with_db:
        # BS x H x W x (2+D)
        quantized_mean_scaled = quantized_img[..., 0:2 + d]
        quantized_var_scaled = quantized_img[..., 2 + d:2 * (2 + d)]

        # BS x H x W x 1
        quantized_depth_mean = quantized_mean_scaled[..., 2:3]

        # BS x 1 x 1 x 1
        mean_depth_min = _ivy.expand_dims(mean_depth_min, -2)
        mean_depth_range = _ivy.expand_dims(mean_depth_range, -2)

        # BS x 1 x 1 x (1+D)
        pc_n_feat_wo_depth_min = _ivy.expand_dims(pc_n_feat_wo_depth_min, -2)
        pc_n_feat_wo_depth_range = _ivy.expand_dims(pc_n_feat_wo_depth_range,
                                                    -2)

        # BS x 1 x 1 x (2+D) x 2
        var_threshold = _ivy.expand_dims(var_threshold, -3)

        # undo the depth offset and the normalisation applied before the scatter

        # BS x H x W x (1+D)
        quantized_mean_wo_depth_scaled = _ivy.concatenate(
            (quantized_mean_scaled[..., 0:2], quantized_mean_scaled[..., 3:]),
            -1)
        quantized_mean_wo_depth_normed = quantized_mean_wo_depth_scaled - (quantized_depth_mean - mean_depth_min) / \
                                         (mean_depth_range * min_depth_diff + MIN_DENOMINATOR)
        quantized_mean_wo_depth = quantized_mean_wo_depth_normed * pc_n_feat_wo_depth_range + pc_n_feat_wo_depth_min
        prior_wo_depth = _ivy.concatenate((prior[..., 0:2], prior[..., 3:]),
                                          -1)
        quantized_mean_wo_depth = _ivy.where(invalidity_mask, prior_wo_depth,
                                             quantized_mean_wo_depth)

        # BS x H x W x (2+D)
        # NOTE(review): the depth channel is taken straight from the scatter
        # output and is not replaced by the prior on invalid pixels - confirm.
        quantized_mean = _ivy.concatenate(
            (quantized_mean_wo_depth[..., 0:2], quantized_depth_mean,
             quantized_mean_wo_depth[..., 2:]), -1)

        # BS x H x W x (2+D)
        quantized_var_normed = quantized_var_scaled - (quantized_depth_mean - mean_depth_min) / \
                               (mean_depth_range * min_depth_diff + MIN_DENOMINATOR)
        # clip from below at the lower variance threshold
        quantized_var = _ivy.maximum(
            quantized_var_normed * var_vals_range + var_vals_min,
            var_threshold[..., 0])
        quantized_var = _ivy.where(invalidity_mask, prior_var, quantized_var)
    else:
        # BS x H x W x (2+D)
        quantized_sum_mean_x_recip_var = quantized_img[..., 0:2 + d]
        # reciprocal of the summed reciprocal variances gives the fused variance
        quantized_var_wo_increase = _ivy.where(
            invalidity_mask, prior_var,
            (1 / (quantized_img[..., 2 + d:2 * (2 + d)] + MIN_DENOMINATOR)))
        # the fused variance is multiplied by the hit count and clipped from
        # below at the lower variance threshold
        quantized_var = _ivy.maximum(
            quantized_var_wo_increase * quantized_counter,
            _ivy.expand_dims(var_threshold[..., 0], -2))
        quantized_var = _ivy.where(invalidity_mask, prior_var, quantized_var)
        # fused mean = fused variance * sum(mean / variance)
        quantized_mean = _ivy.where(
            invalidity_mask, prior,
            quantized_var_wo_increase * quantized_sum_mean_x_recip_var)

    # BS x H x W x (2+D)    BS x H x W x (2+D)     BS x H x W x 1
    return quantized_mean, quantized_var, quantized_counter
Пример #7
0
def coords_to_voxel_grid(coords,
                         voxel_shape_spec,
                         mode='DIMS',
                         coord_bounds=None,
                         features=None,
                         batch_shape=None,
                         dev_str=None):
    r"""
    Create voxel grid :math:`\mathbf{X}_v\in\mathbb{R}^{x×y×z×(3+N+1)}` from homogeneous co-ordinates
    :math:`\mathbf{X}_w\in\mathbb{R}^{num\_coords×4}`. Each voxel contains 3+N+1 values: the mean normalized
    co-ordinate inside the voxel for the projected pixels with :math:`0 < x, y, z < 1`, N coordinate features (optional),
    and also the number of projected pixels inside the voxel.
    Grid resolutions and dimensions are also returned separately for each entry in the batch.
    Note that the final batched voxel grid returned uses the maximum grid dimensions across the batch, this means
    some returned grids may contain redundant space, with all but the single largest batched grid occupying a subset
    of the grid space, originating from the corner of minimum :math:`x,y,z` values.

    `[reference] <https://en.wikipedia.org/wiki/Voxel>`_

    :param coords: Homogeneous co-ordinates *[batch_shape,c,4]*
    :type coords: array
    :param voxel_shape_spec: Either the number of voxels in x,y,z directions, or the resolutions (metres) in x,y,z
                                directions, depending on mode. Batched or unbatched. *[batch_shape,3]* or *[3]*
    :type voxel_shape_spec: array
    :param mode: Shape specification mode, either "DIMS" or "RES"
    :type mode: str
    :param coord_bounds: Co-ordinate x, y, z boundaries, ordered (x_min, y_min, z_min, x_max, y_max, z_max)
                            *[batch_shape,6]* or *[6]*
    :type coord_bounds: array
    :param features: Co-ordinate features *[batch_shape,c,feature_size]*.
                              E.g. RGB values, low-dimensional features, etc.
                              Features mapping to the same voxel are averaged.
    :type features: array
    :param batch_shape: Shape of batch. Inferred from inputs if None.
    :type batch_shape: sequence of ints, optional
    :param dev_str: device on which to create the array 'cuda:0', 'cuda:1', 'cpu' etc. Same as coords if None.
    :type dev_str: str, optional
    :return: Voxel grid *[batch_shape,x_max,y_max,z_max,3+feature_size+1]*, dimensions *[batch_shape,3]*, resolutions *[batch_shape,3]*, voxel_grid_lower_corners *[batch_shape,3]*
    :raises ValueError: If mode is neither "DIMS" nor "RES".
    """

    if batch_shape is None:
        batch_shape = coords.shape[:-2]

    if dev_str is None:
        dev_str = _ivy.dev_str(coords)

    # shapes as list
    batch_shape = list(batch_shape)
    num_batch_dims = len(batch_shape)
    num_coords_per_batch = coords.shape[-2]

    # voxel shape spec as array
    # Compared with == rather than `is`: identity of ints is an interpreter
    # detail and a SyntaxWarning on recent Pythons.
    # NOTE(review): a batched spec whose leading dimension happens to be 3
    # also takes this branch - confirm callers never pass that.
    if len(voxel_shape_spec) == 3:

        # BS x 1 x 3
        voxel_shape_spec = _ivy.expand_dims(
            _ivy.tile(
                _ivy.reshape(_ivy.array(voxel_shape_spec),
                             [1] * num_batch_dims + [3]), batch_shape + [1]),
            -2)

    # coord bounds spec as array
    if coord_bounds is not None and len(coord_bounds) == 6:

        # BS x 1 x 6
        coord_bounds = _ivy.expand_dims(
            _ivy.tile(
                _ivy.reshape(_ivy.array(coord_bounds, dtype_str='float32'),
                             [1] * num_batch_dims + [6]), batch_shape + [1]),
            -2)

    # BS x N x 3
    # drop the homogeneous component
    coords = coords[..., 0:3]

    if coord_bounds is not None:

        # BS x 1
        x_min = coord_bounds[..., 0:1]
        y_min = coord_bounds[..., 1:2]
        z_min = coord_bounds[..., 2:3]
        x_max = coord_bounds[..., 3:4]
        y_max = coord_bounds[..., 4:5]
        z_max = coord_bounds[..., 5:6]

        # BS x N x 1
        x_coords = coords[..., 0:1]
        y_coords = coords[..., 1:2]
        z_coords = coords[..., 2:3]

        # points exactly on a bound are treated as outside (strict inequalities)
        x_validity_mask = _ivy.logical_and(x_coords > x_min, x_coords < x_max)
        y_validity_mask = _ivy.logical_and(y_coords > y_min, y_coords < y_max)
        z_validity_mask = _ivy.logical_and(z_coords > z_min, z_coords < z_max)

        # BS x N
        full_validity_mask = _ivy.logical_and(
            _ivy.logical_and(x_validity_mask, y_validity_mask),
            z_validity_mask)[..., 0]

        # BS x 1 x 3
        bb_mins = coord_bounds[..., 0:3]
        bb_maxs = coord_bounds[..., 3:6]
        bb_ranges = bb_maxs - bb_mins
    else:

        # no bounds given: every point is valid and the bounding box is
        # taken from the data itself

        # BS x N
        full_validity_mask = _ivy.cast(
            _ivy.ones(batch_shape + [num_coords_per_batch]), 'bool')

        # BS x 1 x 3
        bb_mins = _ivy.reduce_min(coords, axis=-2, keepdims=True)
        bb_maxs = _ivy.reduce_max(coords, axis=-2, keepdims=True)
        bb_ranges = bb_maxs - bb_mins

    # get voxel dimensions
    # string equality, not identity: `is` only works when both strings happen
    # to be interned, which is not guaranteed for values built at runtime
    if mode == 'DIMS':
        # BS x 1 x 3
        dims = _ivy.cast(voxel_shape_spec, 'int32')
    elif mode == 'RES':
        # BS x 1 x 3
        res = _ivy.cast(voxel_shape_spec, 'float32')
        dims = _ivy.cast(_ivy.ceil(bb_ranges / (res + MIN_DENOMINATOR)),
                         'int32')
    else:
        raise ValueError(
            'Invalid mode selection. Must be either "DIMS" or "RES"')
    dims_m_one = _ivy.cast(dims - 1, 'int32')

    # BS x 1 x 3
    # resolution is always re-derived from the integer dims, so in "RES" mode
    # the returned resolution can differ from the requested one
    res = bb_ranges / (_ivy.cast(dims, 'float32') + MIN_DENOMINATOR)

    # BS x NC x 3
    # clamp so points on the upper boundary land in the last voxel
    voxel_indices = _ivy.minimum(
        _ivy.cast(_ivy.floor((coords - bb_mins) / (res + MIN_DENOMINATOR)),
                  'int32'), dims_m_one)

    # BS x NC x 3
    # position of each point inside its voxel, normalised by the voxel size
    voxel_values = ((coords - bb_mins) % res) / (res + MIN_DENOMINATOR)

    feature_size = 0
    if features is not None:
        feature_size = features.shape[-1]
        voxel_values = _ivy.concatenate([voxel_values, features], axis=-1)

    # TNVC x len(BS)+1
    valid_coord_indices = _ivy.cast(_ivy.indices_where(full_validity_mask),
                                    'int32')

    # scalar
    total_num_valid_coords = valid_coord_indices.shape[0]

    # TNVC x 3
    voxel_values_pruned_flat = _ivy.gather_nd(voxel_values,
                                              valid_coord_indices)
    voxel_indices_pruned_flat = _ivy.gather_nd(voxel_indices,
                                               valid_coord_indices)

    # TNVC x len(BS)+2
    if num_batch_dims == 0:
        all_indices_pruned_flat = voxel_indices_pruned_flat
    else:
        # prepend the batch part of each index to its voxel index
        batch_indices = valid_coord_indices[..., :-1]
        all_indices_pruned_flat = _ivy.concatenate(
            [batch_indices] + [voxel_indices_pruned_flat], -1)

    # TNVC x 4
    # trailing column of ones accumulates into the per-voxel point count
    voxel_values_pruned_flat =\
        _ivy.concatenate((voxel_values_pruned_flat, _ivy.ones([total_num_valid_coords, 1], dev_str=dev_str)), -1)

    # get max dims list for scatter
    if num_batch_dims > 0:
        max_dims = _ivy.reduce_max(_ivy.reshape(dims, batch_shape + [3]),
                                   axis=list(range(num_batch_dims)))
    else:
        max_dims = _ivy.reshape(dims, batch_shape + [3])
    batch_shape_array_list = [_ivy.array(batch_shape, 'int32')
                              ] if num_batch_dims != 0 else []
    total_dims_list = _ivy.to_list(
        _ivy.concatenate(
            batch_shape_array_list +
            [max_dims, _ivy.array([4 + feature_size], 'int32')], -1))

    # BS x x_max x y_max x z_max x 4
    # the 'mxnd' backend is given 'replace' instead of 'sum'
    # (presumably because it lacks a summing scatter - confirm)
    scattered = _ivy.scatter_nd(
        all_indices_pruned_flat,
        voxel_values_pruned_flat,
        total_dims_list,
        reduction='replace' if _ivy.backend == 'mxnd' else 'sum')

    # BS x x_max x y_max x z_max x 4 + feature_size, BS x 3, BS x 3, BS x 3
    # divide the summed values by the point count (at least 1, so empty
    # voxels stay zero) and keep the raw count as the last channel
    return _ivy.concatenate(
        (scattered[..., :-1] /
         (_ivy.maximum(scattered[..., -1:], 1.) + MIN_DENOMINATOR),
         scattered[..., -1:]),
        -1), dims[..., 0, :], res[..., 0, :], bb_mins[..., 0, :]