def _expand_axes(data, points, collapse_dimensions=True):
    """Give `data` as many spatial axes as `points`, optionally tiling them to match.

    Axis 0 and the last axis are excluded from the spatial axes (see the `[1:-1]`
    slices below) and are never tiled.

    :param data: tensor to expand
    :param points: reference tensor whose spatial rank / static spatial sizes are matched
    :param collapse_dimensions: if True, return `data` right after the missing axes were
        inserted; if False, additionally tile each spatial axis by
        `points_size // data_size` so the static sizes agree
    :return: the expanded (and, if requested, tiled) tensor
    """
    assert math.spatial_rank(data) >= 0
    # Insert the missing spatial axes at axis 1.
    # NOTE(review): assumes math.expand_dims(x, axis, number) semantics -- confirm.
    data = math.expand_dims(data, 1, math.spatial_rank(points) - math.spatial_rank(data))
    if collapse_dimensions:
        return data
    # Static spatial sizes, computed once and reused for the check and the tile factors.
    points_axes = math.staticshape(points)[1:-1]
    data_axes = math.staticshape(data)[1:-1]
    tilings = [1]  # never tile axis 0
    for d_points, d_data in zip(points_axes, data_axes):
        # Tiling can only reach the target size if it is an exact multiple.
        assert d_points % d_data == 0, (
            "spatial size %s of points is not a multiple of data size %s" % (d_points, d_data))
        tilings.append(d_points // d_data)
    tilings.append(1)  # never tile the last axis
    return math.tile(data, tilings)
def batch_indices(indices):
    """
    Reshapes the indices such that, aside from indices, they also contain batch number.

    For example the entry (32, 40) as coordinates of batch 2 will become (2, 32, 40).
    More generally, shape (b, *spatial, d) becomes (b, *spatial, d+1); for particles this is
    (b, p, d) -> (b, p, d+1) where b is the batch size, p the number of particles and
    d the number of dimensions.

    :param indices: tensor whose first axis is the batch axis and whose last axis holds
        the index coordinates
    :return: `indices` with the batch number prepended along the last axis
    """
    # Read the batch size from the runtime shape, consistent with `out_spatial_size`
    # below. The static `indices.shape[0]` can be None for an unknown batch dimension
    # (e.g. a TensorFlow graph tensor), which would break `range` and `reshape`.
    batch_size = math.shape(indices)[0]
    out_spatial_rank = len(indices.shape) - 2
    out_spatial_size = math.shape(indices)[1:-1]
    batch_range = math.DYNAMIC_BACKEND.choose_backend(indices).range(batch_size)
    # One id per batch entry, with singleton axes in place of the spatial axes.
    batch_ids = math.reshape(batch_range, [batch_size] + [1] * out_spatial_rank)
    # Tile factors: 1 along the batch axis, the full size along every spatial axis.
    tile_shape = math.pad(out_spatial_size, [[1, 0]], constant_values=1)
    batch_ids = math.expand_dims(math.tile(batch_ids, tile_shape), axis=-1)
    # NOTE(review): batch ids carry the backend's default `range` dtype; confirm it matches
    # the dtype of `indices`, as some backends reject mixed-dtype concatenation.
    return math.concat((batch_ids, indices), axis=-1)
def _expand_axes(data, points, batch_size=1):
    """Give `data` as many spatial axes as `points` and tile it to matching sizes.

    Axis 0 and the last axis are excluded from the spatial axes (see the `[1:-1]`
    slices below). Axis 0 is tiled by `batch_size`; the last axis is never tiled.

    :param data: tensor to expand
    :param points: reference tensor whose spatial rank / static spatial sizes are matched
    :param batch_size: tile factor for axis 0; falsy values (None, 0) are treated as 1
    :return: the expanded and tiled tensor
    """
    assert math.spatial_rank(data) >= 0
    # Insert the missing spatial axes at axis 1.
    # NOTE(review): assumes math.expand_dims(x, axis, number) semantics -- confirm.
    data = math.expand_dims(
        data, 1, math.spatial_rank(points) - math.spatial_rank(data))
    # Static spatial sizes, computed once and reused for the check and the tile factors.
    points_axes = math.staticshape(points)[1:-1]
    data_axes = math.staticshape(data)[1:-1]
    tilings = [batch_size or 1]
    for d_points, d_data in zip(points_axes, data_axes):
        # Tiling can only reach the target size if it is an exact multiple.
        assert d_points % d_data == 0, (
            "spatial size %s of points is not a multiple of data size %s" % (d_points, d_data))
        tilings.append(d_points // d_data)
    tilings.append(1)  # never tile the last axis
    return math.tile(data, tilings)
def approximate_fraction_inside(self, location, cell_size):
    """Report a fraction of zero for every location.

    :param location: tensor of positions; its last axis is dropped from the result shape
    :param cell_size: not used
    :return: `math.to_float(0)` tiled by `shape(location)[:-1] + [1]`
    """
    multiples = list(math.shape(location)[:-1])
    multiples.append(1)
    zero = math.to_float(0)
    return math.tile(zero, multiples)
def lies_inside(self, location):
    """Report `False` for every location.

    :param location: tensor of positions; its last axis is dropped from the result shape
    :return: `False` tiled by `shape(location)[:-1] + [1]`
    """
    leading_dims = math.shape(location)[:-1]
    multiples = [*leading_dims, 1]
    return math.tile(False, multiples)
def approximate_signed_distance(self, location):
    """Report an infinite signed distance for every location.

    :param location: tensor of positions; its last axis is dropped from the result shape
    :return: `np.inf` tiled by `shape(location)[:-1] + [1]`
    """
    multiples = list(math.shape(location)[:-1]) + [1]
    distance = np.inf
    return math.tile(distance, multiples)