def test_assert_all_in_range_exception_raised(self, dtype):
  """Verifies that assert_all_in_range rejects values outside the range.

  A random vector is squared and divided by its per-row maximum, so its
  largest entry becomes exactly 1 (assuming that maximum is non-zero).
  Adding the dtype-specific addition eps pushes that entry just above the
  upper bound, while a vector of exact ones sits on the bound itself and
  must be rejected when the bounds are open.
  """
  squared = tf.convert_to_tensor(value=_pick_random_vector(), dtype=dtype)
  squared = squared * squared
  normalized = squared / tf.reduce_max(
      input_tensor=squared, axis=-1, keepdims=True)
  eps = asserts.select_eps_for_addition(dtype)
  just_outside = normalized + eps
  on_the_bound = tf.ones_like(normalized)

  # (subtest name, tensor under test, open_bounds flag); every case must fail.
  failing_cases = (
      ("outside_and_open_bounds", just_outside, True),
      ("outside_and_close_bounds", just_outside, False),
      ("exact_and_open_bounds", on_the_bound, True),
  )
  for case_name, tensor, open_bounds in failing_cases:
    with self.subTest(name=case_name):
      with self.assertRaises(tf.errors.InvalidArgumentError):
        self.evaluate(
            asserts.assert_all_in_range(
                tensor, -1.0, 1.0, open_bounds=open_bounds))
def from_srgb(srgb, name=None):
  """Converts sRGB colors to linear colors.

  Note:
    In the following, A1 to An are optional batch dimensions.

  Args:
    srgb: A tensor of shape `[A_1, ..., A_n, 3]`, where the last dimension
      represents sRGB values in the range [0, 1].
    name: A name for this op that defaults to "linear_rgb_from_srgb".

  Raises:
    ValueError: If `srgb` has rank < 1 or has its last dimension not equal to
      3.
    InvalidArgumentError: If the tf-graphics debug flag is set and `srgb`
      contains values outside of [0, 1].

  Returns:
    A tensor of shape `[A_1, ..., A_n, 3]`, where the last dimension represents
    RGB values in linear color space.
  """
  with tf.compat.v1.name_scope(name, "linear_rgb_from_srgb", [srgb]):
    srgb = tf.convert_to_tensor(value=srgb)

    shape.check_static(
        tensor=srgb,
        tensor_name="srgb",
        has_rank_greater_than=0,
        has_dim_equals=(-1, 3))
    # Keep the tensor returned by the assert, as the other call sites do, so
    # the range check (when enabled) is part of the data flow. An assertion
    # whose output is discarded may never run in graph mode.
    srgb = asserts.assert_all_in_range(srgb, 0., 1.)

    # Piecewise sRGB decoding: a linear segment up to the threshold _K0 and a
    # power curve above it.
    return tf.compat.v1.where(srgb <= _K0, srgb / _PHI,
                              ((srgb + _A) / (1 + _A))**_GAMMA)
def from_linear_rgb(linear_rgb, name=None):
  """Converts linear RGB to sRGB colors.

  Note:
    In the following, A1 to An are optional batch dimensions.

  Args:
    linear_rgb: A Tensor of shape `[A_1, ..., A_n, 3]`, where the last dimension
      represents RGB values in the range [0, 1] in linear color space.
    name: A name for this op that defaults to "srgb_from_linear_rgb".

  Raises:
    ValueError: If `linear_rgb` has rank < 1 or has its last dimension not
      equal to 3.
    InvalidArgumentError: If the tf-graphics debug flag is set and
      `linear_rgb` contains values outside of [0, 1].

  Returns:
    A tensor of shape `[A_1, ..., A_n, 3]`, where the last dimension represents
    sRGB values.
  """
  with tf.compat.v1.name_scope(name, "srgb_from_linear_rgb", [linear_rgb]):
    linear_rgb = tf.convert_to_tensor(value=linear_rgb)

    shape.check_static(
        tensor=linear_rgb,
        tensor_name="linear_rgb",
        has_rank_greater_than=0,
        has_dim_equals=(-1, 3))
    # Keep the tensor returned by the assert, as the other call sites do, so
    # the range check (when enabled) is part of the data flow. An assertion
    # whose output is discarded may never run in graph mode.
    linear_rgb = asserts.assert_all_in_range(linear_rgb, 0., 1.)

    # Adds a small eps to avoid nan gradients from the second branch of
    # tf.where: the gradient of x**(1 / _GAMMA) is unbounded at x == 0.
    linear_rgb += sys.float_info.epsilon
    return tf.compat.v1.where(linear_rgb <= _K0 / _PHI, linear_rgb * _PHI,
                              (1 + _A) * (linear_rgb**(1 / _GAMMA)) - _A)
def test_assert_all_in_range_passthrough(self):
  """Verifies the assert hands back its input object when it is disabled.

  With the tf-graphics debug flag off, assert_all_in_range must return the
  very same object it received (identity, not merely equality).
  """
  original = _pick_random_vector()
  self.assertIs(original, asserts.assert_all_in_range(original, -1.0, 1.0))
def square_to_spherical_coordinates(point_2d, name=None):
  """Maps points from a unit square to a unit sphere.

  Note:
    In the following, A1 to An are optional batch dimensions.

  Args:
    point_2d: A tensor of shape `[A1, ..., An, 2]` with values in [0,1].
    name: A name for this op. Defaults to
      "math_square_to_spherical_coordinates".

  Returns:
    A tensor of shape `[A1, ..., An, 3]` holding spherical coordinates
    (r, theta, phi): [..., 0] is the radius and is always 1.0, [..., 1] is
    theta with values in [0.0, pi], and [..., 2] is phi with values in
    [0.0, 2pi].

  Raises:
    ValueError: if the shape of `point_2d` is not supported.
    InvalidArgumentError: if any element of `point_2d` is outside of [0,1].
  """
  with tf.compat.v1.name_scope(name, "math_square_to_spherical_coordinates",
                               [point_2d]):
    point_2d = tf.convert_to_tensor(value=point_2d)

    shape.check_static(
        tensor=point_2d, tensor_name="point_2d", has_dim_equals=(-1, 2))
    point_2d = asserts.assert_all_in_range(
        point_2d, 0.0, 1.0, open_bounds=False)

    x, y = tf.unstack(point_2d, axis=-1)
    # For x in [0, 1], sqrt(1 - x) is in [0, 1], so acos(...) is in
    # [0, pi / 2] and theta is in [0, pi].
    theta = 2.0 * tf.acos(tf.sqrt(1.0 - x))
    phi = 2.0 * np.pi * y
    return tf.stack((tf.ones_like(theta), theta, phi), axis=-1)
def safe_shrink(vector,
                minval=None,
                maxval=None,
                open_bounds=False,
                eps=None,
                name=None):
  """Scales `vector` by `(1.0 - eps)` to absorb floating point overshoot.

  Results such as the dot product of a normalized vector with itself can land
  slightly outside their mathematical range (e.g. a little above `1.0`)
  because of limited floating point precision. The amount depends on the
  `dtype` of the vector. Multiplying by `(1.0 - eps)` pulls such values back
  in while keeping gradients intact, unlike `tf.clip_by_value`. The result is
  then safe to feed to operations like `acos(x)`.

  When both `minval` and `maxval` are given and the tf-graphics debug flag is
  `True`, assertions are added to the graph that check that the shrunk vector
  lies in `[minval, maxval]` (when `open_bounds` is `False`) or in
  `]minval, maxval[` (when `open_bounds` is `True`).

  Note:
    In the following, A1 to An are optional batch dimensions, which must be
    broadcast compatible.

  Args:
    vector: A tensor of shape `[A1, ..., An]`.
    minval: A `float` or a tensor of shape `[A1, ..., An]` with the lower
      bounds that the shrunk values are tested against. Ignored unless
      `maxval` is also provided.
    maxval: A `float` or a tensor of shape `[A1, ..., An]` with the upper
      bounds that the shrunk values are tested against. Ignored unless
      `minval` is also provided.
    open_bounds: A `bool` selecting an open (`True`) or closed (`False`) range
      for the check. Only meaningful when both bounds are provided.
    eps: A `float` used to shrink `vector`. When `None`, it is derived from
      the `dtype` of `vector`.
    name: A name for this op. Defaults to 'safe_shrink'.

  Raises:
    InvalidArgumentError: If the tf-graphics debug flag is set and the shrunk
      vector is not inside the expected range.

  Returns:
    A tensor of shape `[A1, ..., An]` containing the shrunk values.
  """
  with tf.compat.v1.name_scope(name, 'safe_shrink', [vector, minval, maxval]):
    vector = tf.convert_to_tensor(value=vector)
    if eps is None:
      eps = asserts.select_eps_for_addition(vector.dtype)
    shrink_factor = 1.0 - tf.convert_to_tensor(value=eps, dtype=vector.dtype)
    shrunk = vector * shrink_factor

    # The range check only applies when both bounds are available.
    if minval is None or maxval is None:
      return shrunk
    return asserts.assert_all_in_range(
        shrunk, minval, maxval, open_bounds=open_bounds)
def evaluate_legendre_polynomial(degree_l, order_m, x):
  """Evaluates the associated Legendre polynomial of degree l and order m.

  Note:
    Implements the algorithm described on p. 10 of
    `Spherical Harmonic Lighting: The Gritty Details`.

  Note:
    In the following, A1 to An are optional batch dimensions, which must be
    broadcast compatible across the three inputs.

  Args:
    degree_l: An integer tensor of shape `[A1, ..., An]` holding the degree of
      the associated Legendre polynomial. Must be non-negative.
    order_m: An integer tensor of shape `[A1, ..., An]` holding the order of
      the associated Legendre polynomial. Must satisfy
      `0 <= order_m <= degree_l`.
    x: A tensor of shape `[A1, ..., An]` with values in [-1,1].

  Raises:
    ValueError: If `degree_l` or `order_m` is not of an integer type.

  Returns:
    A tensor of shape `[A1, ..., An]` containing the evaluation of the
    Legendre polynomial.
  """
  degree_l, order_m, x = (
      tf.convert_to_tensor(value=t) for t in (degree_l, order_m, x))

  if not degree_l.dtype.is_integer:
    raise ValueError("`degree_l` must be of an integer type.")
  if not order_m.dtype.is_integer:
    raise ValueError("`order_m` must be of an integer type.")

  shape.compare_batch_dimensions(
      tensors=(degree_l, order_m, x),
      last_axes=-1,
      tensor_names=("degree_l", "order_m", "x"),
      broadcast_compatible=True)

  # Value-range checks; these are passthroughs unless the tf-graphics debug
  # flag is enabled.
  degree_l = asserts.assert_all_above(degree_l, 0)
  order_m = asserts.assert_all_in_range(order_m, 0, degree_l)
  x = asserts.assert_all_in_range(x, -1.0, 1.0)

  pmm = _evaluate_legendre_polynomial_pmm_eval(order_m, x)
  is_diagonal = tf.equal(degree_l, order_m)
  off_diagonal = _evaluate_legendre_polynomial_branch(degree_l, order_m, x,
                                                      pmm)
  # When l == m the value is P_m^m itself; otherwise take the helper branch.
  return tf.compat.v1.where(is_diagonal, pmm, off_diagonal)
def perspective_right_handed(vertical_field_of_view,
                             aspect_ratio,
                             near,
                             far,
                             name=None):
  """Generates the matrix for a right handed perspective projection.

  Note:
    In the following, A1 to An are optional batch dimensions. All four inputs
    must have exactly the same shape.

  Args:
    vertical_field_of_view: A tensor of shape `[A1, ..., An]` containing the
      vertical field of view of the frustum expressed in radians. Values must
      be in the open range (0, pi).
    aspect_ratio: A tensor of shape `[A1, ..., An]` containing the width over
      height ratio of the frustum. Values must be strictly positive.
    near: A tensor of shape `[A1, ..., An]` containing the distance between
      the viewer and the near clipping plane. Values must be strictly
      positive.
    far: A tensor of shape `[A1, ..., An]` containing the distance between the
      viewer and the far clipping plane. Values must be strictly greater than
      the corresponding values of `near`.
    name: A name for this op. Defaults to 'perspective_rh'.

  Raises:
    InvalidArgumentError: if any input contains data not in the specified range
      of valid values.
    ValueError: if the inputs do not all have the same shape.

  Returns:
    A tensor of shape `[A1, ..., An, 4, 4]`, containing matrices of right
    handed perspective-view frustum.
  """
  with tf.compat.v1.name_scope(
      name, "perspective_rh",
      [vertical_field_of_view, aspect_ratio, near, far]):
    vertical_field_of_view = tf.convert_to_tensor(value=vertical_field_of_view)
    aspect_ratio = tf.convert_to_tensor(value=aspect_ratio)
    near = tf.convert_to_tensor(value=near)
    far = tf.convert_to_tensor(value=far)

    shape.compare_batch_dimensions(
        tensors=(vertical_field_of_view, aspect_ratio, near, far),
        last_axes=-1,
        tensor_names=("vertical_field_of_view", "aspect_ratio", "near", "far"),
        broadcast_compatible=False)

    vertical_field_of_view = asserts.assert_all_in_range(
        vertical_field_of_view, 0.0, math.pi, open_bounds=True)
    aspect_ratio = asserts.assert_all_above(aspect_ratio, 0.0, open_bound=True)
    near = asserts.assert_all_above(near, 0.0, open_bound=True)
    far = asserts.assert_all_above(far, near, open_bound=True)

    inverse_tan_half_vertical_field_of_view = 1.0 / tf.tan(
        vertical_field_of_view * 0.5)
    zero = tf.zeros_like(inverse_tan_half_vertical_field_of_view)
    one = tf.ones_like(inverse_tan_half_vertical_field_of_view)

    # Rows of the projection matrix, each of shape `[A1, ..., An, 4]`.
    x = tf.stack((inverse_tan_half_vertical_field_of_view / aspect_ratio, zero,
                  zero, zero),
                 axis=-1)
    y = tf.stack((zero, inverse_tan_half_vertical_field_of_view, zero, zero),
                 axis=-1)
    near_minus_far = near - far
    z = tf.stack(
        (zero, zero,
         (far + near) / near_minus_far, 2.0 * far * near / near_minus_far),
        axis=-1)
    w = tf.stack((zero, zero, -one, zero), axis=-1)

    return tf.stack((x, y, z, w), axis=-2)
def brdf(direction_incoming_light: type_alias.TensorLike,
         direction_outgoing_light: type_alias.TensorLike,
         surface_normal: type_alias.TensorLike,
         albedo: type_alias.TensorLike,
         name: str = "lambertian_brdf") -> tf.Tensor:
  """Evaluates the brdf of a Lambertian surface.

  Note:
    In the following, A1 to An are optional batch dimensions, which must be
    broadcast compatible.

  Note:
    The gradient of this function is not smooth when the dot product of the
    normal with any light is 0.0.

  Args:
    direction_incoming_light: A tensor of shape `[A1, ..., An, 3]`, where the
      last dimension represents a normalized incoming light vector.
    direction_outgoing_light: A tensor of shape `[A1, ..., An, 3]`, where the
      last dimension represents a normalized outgoing light vector.
    surface_normal: A tensor of shape `[A1, ..., An, 3]`, where the last
      dimension represents a normalized surface normal.
    albedo: A tensor of shape `[A1, ..., An, 3]`, where the last dimension
      represents albedo with values in [0,1].
    name: A name for this op. Defaults to "lambertian_brdf".

  Returns:
    A tensor of shape `[A1, ..., An, 3]`, where the last dimension represents
      the amount of reflected light in any outgoing direction.

  Raises:
    ValueError: if the shape of `direction_incoming_light`,
    `direction_outgoing_light`, `surface_normal` or `albedo` is not supported.
    InvalidArgumentError: if at least one element of `albedo` is outside of
    [0,1].
  """
  with tf.name_scope(name):
    direction_incoming_light = tf.convert_to_tensor(
        value=direction_incoming_light)
    direction_outgoing_light = tf.convert_to_tensor(
        value=direction_outgoing_light)
    surface_normal = tf.convert_to_tensor(value=surface_normal)
    albedo = tf.convert_to_tensor(value=albedo)

    shape.check_static(
        tensor=direction_incoming_light,
        tensor_name="direction_incoming_light",
        has_dim_equals=(-1, 3))
    shape.check_static(
        tensor=direction_outgoing_light,
        tensor_name="direction_outgoing_light",
        has_dim_equals=(-1, 3))
    shape.check_static(
        tensor=surface_normal,
        tensor_name="surface_normal",
        has_dim_equals=(-1, 3))
    shape.check_static(
        tensor=albedo, tensor_name="albedo", has_dim_equals=(-1, 3))
    shape.compare_batch_dimensions(
        tensors=(direction_incoming_light, direction_outgoing_light,
                 surface_normal, albedo),
        tensor_names=("direction_incoming_light", "direction_outgoing_light",
                      "surface_normal", "albedo"),
        last_axes=-2,
        broadcast_compatible=True)

    direction_incoming_light = asserts.assert_normalized(
        direction_incoming_light)
    direction_outgoing_light = asserts.assert_normalized(
        direction_outgoing_light)
    surface_normal = asserts.assert_normalized(surface_normal)
    albedo = asserts.assert_all_in_range(albedo, 0.0, 1.0, open_bounds=False)

    # Checks whether the incoming or outgoing light point behind the surface.
    dot_incoming_light_surface_normal = vector.dot(-direction_incoming_light,
                                                   surface_normal)
    dot_outgoing_light_surface_normal = vector.dot(direction_outgoing_light,
                                                   surface_normal)
    min_dot = tf.minimum(dot_incoming_light_surface_normal,
                         dot_outgoing_light_surface_normal)

    # Broadcast the visibility mask and the albedo against each other with
    # element-wise ops. The previous approach built an explicit shape from
    # static dimensions and replaced unknown ones with 1. That made
    # tf.broadcast_to fail at run time whenever a batch dimension was only
    # known dynamically and larger than 1. Element-wise broadcasting handles
    # both static and dynamic shapes.
    condition = tf.logical_and(
        tf.greater_equal(min_dot, 0.0), tf.ones_like(albedo, dtype=tf.bool))
    albedo = albedo + tf.zeros_like(min_dot, dtype=albedo.dtype)
    return tf.where(condition, albedo / math.pi, tf.zeros_like(albedo))
def brdf(direction_incoming_light: type_alias.TensorLike,
         direction_outgoing_light: type_alias.TensorLike,
         surface_normal: type_alias.TensorLike,
         shininess: type_alias.TensorLike,
         albedo: type_alias.TensorLike,
         brdf_normalization: bool = True,
         name: str = "phong_brdf") -> tf.Tensor:
  """Evaluates the specular brdf of the Phong model.

  Note:
    In the following, A1 to An are optional batch dimensions, which must be
    broadcast compatible.

  Note:
    The gradient of this function is not smooth when the dot product of the
    normal with any light is 0.0.

  Args:
    direction_incoming_light: A tensor of shape `[A1, ..., An, 3]`, where the
      last dimension represents a normalized incoming light vector.
    direction_outgoing_light: A tensor of shape `[A1, ..., An, 3]`, where the
      last dimension represents a normalized outgoing light vector.
    surface_normal: A tensor of shape `[A1, ..., An, 3]`, where the last
      dimension represents a normalized surface normal.
    shininess: A tensor of shape `[A1, ..., An, 1]`, where the last dimension
      represents a non-negative shininess coefficient.
    albedo: A tensor of shape `[A1, ..., An, 3]`, where the last dimension
      represents albedo with values in [0,1].
    brdf_normalization: A `bool` indicating whether normalization should be
      applied to enforce the energy conservation property of BRDFs. Note that
      `brdf_normalization` must be set to False in order to use the original
      Phong specular model.
    name: A name for this op. Defaults to "phong_brdf".

  Returns:
    A tensor of shape `[A1, ..., An, 3]`, where the last dimension represents
      the amount of light reflected in the outgoing light direction.

  Raises:
    ValueError: if the shape of `direction_incoming_light`,
    `direction_outgoing_light`, `surface_normal`, `shininess` or `albedo` is not
    supported.
    InvalidArgumentError: if not all of shininess values are non-negative, or if
    at least one element of `albedo` is outside of [0,1].
  """
  with tf.name_scope(name):
    direction_incoming_light = tf.convert_to_tensor(
        value=direction_incoming_light)
    direction_outgoing_light = tf.convert_to_tensor(
        value=direction_outgoing_light)
    surface_normal = tf.convert_to_tensor(value=surface_normal)
    shininess = tf.convert_to_tensor(value=shininess)
    albedo = tf.convert_to_tensor(value=albedo)

    shape.check_static(
        tensor=direction_incoming_light,
        tensor_name="direction_incoming_light",
        has_dim_equals=(-1, 3))
    shape.check_static(
        tensor=direction_outgoing_light,
        tensor_name="direction_outgoing_light",
        has_dim_equals=(-1, 3))
    shape.check_static(
        tensor=surface_normal,
        tensor_name="surface_normal",
        has_dim_equals=(-1, 3))
    shape.check_static(
        tensor=shininess, tensor_name="shininess", has_dim_equals=(-1, 1))
    shape.check_static(
        tensor=albedo, tensor_name="albedo", has_dim_equals=(-1, 3))
    shape.compare_batch_dimensions(
        tensors=(direction_incoming_light, direction_outgoing_light,
                 surface_normal, shininess, albedo),
        tensor_names=("direction_incoming_light", "direction_outgoing_light",
                      "surface_normal", "shininess", "albedo"),
        last_axes=-2,
        broadcast_compatible=True)

    direction_incoming_light = asserts.assert_normalized(
        direction_incoming_light)
    direction_outgoing_light = asserts.assert_normalized(
        direction_outgoing_light)
    surface_normal = asserts.assert_normalized(surface_normal)
    albedo = asserts.assert_all_in_range(albedo, 0.0, 1.0, open_bounds=False)
    shininess = asserts.assert_all_above(shininess, 0.0, open_bound=False)

    # Checks whether the incoming or outgoing light point behind the surface.
    dot_incoming_light_surface_normal = vector.dot(-direction_incoming_light,
                                                   surface_normal)
    dot_outgoing_light_surface_normal = vector.dot(direction_outgoing_light,
                                                   surface_normal)
    min_dot = tf.minimum(dot_incoming_light_surface_normal,
                         dot_outgoing_light_surface_normal)

    perfect_reflection_direction = vector.reflect(direction_incoming_light,
                                                  surface_normal)
    perfect_reflection_direction = tf.math.l2_normalize(
        perfect_reflection_direction, axis=-1)
    cos_alpha = vector.dot(
        perfect_reflection_direction, direction_outgoing_light, axis=-1)
    # Negative cosines mean the outgoing direction is more than 90 degrees
    # away from the perfect reflection; clamp them to zero.
    cos_alpha = tf.maximum(cos_alpha, tf.zeros_like(cos_alpha))
    phong_model = albedo * tf.pow(cos_alpha, shininess)
    if brdf_normalization:
      phong_model *= _brdf_normalization_factor(shininess)

    # Broadcast the visibility mask and the model against each other with
    # element-wise ops. The previous approach built an explicit shape from
    # static dimensions and replaced unknown ones with 1. That made
    # tf.broadcast_to fail at run time whenever a batch dimension was only
    # known dynamically and larger than 1. Element-wise broadcasting handles
    # both static and dynamic shapes.
    condition = tf.logical_and(
        tf.greater_equal(min_dot, 0.0),
        tf.ones_like(phong_model, dtype=tf.bool))
    phong_model = phong_model + tf.zeros_like(min_dot, dtype=phong_model.dtype)
    return tf.where(condition, phong_model, tf.zeros_like(phong_model))
def evaluate_spherical_harmonics(degree_l, order_m, theta, phi, name=None):
  """Evaluates a point sample of a Spherical Harmonic basis function.

  Note:
    This function implements the algorithm and variable names described on
    p. 12 of 'Spherical Harmonic Lighting: The Gritty Details'.

  Note:
    In the following, A1 to An are optional batch dimensions.

  Args:
    degree_l: An integer tensor of shape `[A1, ..., An, C]`, where the last
      dimension represents the band of the spherical harmonics. Note that
      `degree_l` must be non-negative.
    order_m: An integer tensor of shape `[A1, ..., An, C]`, where the last
      dimension represents the index of the spherical harmonics in the band
      `degree_l`. Note that `order_m` must satisfy
      `-degree_l <= order_m <= degree_l`.
    theta: A tensor of shape `[A1, ..., An, 1]`. This variable stores the polar
      angle of the sample. Values of theta must be in [0, pi].
    phi: A tensor of shape `[A1, ..., An, 1]`. This variable stores the
      azimuthal angle of the sample. Values of phi must be in [0, 2pi].
    name: A name for this op. Defaults to
      'spherical_harmonics_evaluate_spherical_harmonics'.

  Returns:
    A tensor of shape `[A1, ..., An, C]` containing the evaluation of each basis
    of the spherical harmonics.

  Raises:
    ValueError: if `degree_l` or `order_m` is not of an integer type, or if the
      shape of `degree_l`, `order_m`, `theta` or `phi` is not supported.
    InvalidArgumentError: if at least one element of `degree_l`, `order_m`,
      `theta` or `phi` is outside the expected range.
  """
  with tf.compat.v1.name_scope(
      name, "spherical_harmonics_evaluate_spherical_harmonics",
      [degree_l, order_m, theta, phi]):
    degree_l = tf.convert_to_tensor(value=degree_l)
    order_m = tf.convert_to_tensor(value=order_m)
    theta = tf.convert_to_tensor(value=theta)
    phi = tf.convert_to_tensor(value=phi)

    if not degree_l.dtype.is_integer:
      raise ValueError("`degree_l` must be of an integer type.")
    if not order_m.dtype.is_integer:
      raise ValueError("`order_m` must be of an integer type.")

    shape.compare_dimensions(
        tensors=(degree_l, order_m),
        axes=-1,
        tensor_names=("degree_l", "order_m"))
    shape.check_static(tensor=phi, tensor_name="phi", has_dim_equals=(-1, 1))
    shape.check_static(
        tensor=theta, tensor_name="theta", has_dim_equals=(-1, 1))
    shape.compare_batch_dimensions(
        tensors=(degree_l, order_m, theta, phi),
        last_axes=-2,
        tensor_names=("degree_l", "order_m", "theta", "phi"),
        broadcast_compatible=False)
    # Checks that tensors contain appropriate data.
    degree_l = asserts.assert_all_above(degree_l, 0)
    order_m = asserts.assert_all_in_range(order_m, -degree_l, degree_l)
    theta = asserts.assert_all_in_range(theta, 0.0, np.pi)
    phi = asserts.assert_all_in_range(phi, 0.0, 2.0 * np.pi)

    var_type = theta.dtype
    # Negative orders are evaluated through |m|, with the sign passed to the
    # branch helper separately.
    sign_m = tf.math.sign(order_m)
    order_m = tf.abs(order_m)
    zeros = tf.zeros_like(order_m)
    # For m == 0 the basis is the normalization times P_l^0(cos(theta)).
    result_m_zero = _spherical_harmonics_normalization(
        degree_l, zeros, var_type) * evaluate_legendre_polynomial(
            degree_l, zeros, tf.cos(theta))
    result_branch = _evaluate_spherical_harmonics_branch(
        degree_l, order_m, theta, phi, sign_m, var_type)
    return tf.where(tf.equal(order_m, zeros), result_m_zero, result_branch)
def knot_weights(positions,
                 num_knots,
                 degree,
                 cyclical,
                 sparse_mode=False,
                 name=None):
  """Function that converts cardinal B-spline positions to knot weights.

  Note:
    In the following, A1 to An are optional batch dimensions.

  Args:
    positions: A tensor with shape `[A1, .. An]`. Positions must be in
      `[0, C - D]` for non-cyclical and `[0, C]` for cyclical splines, where
      `C` is the number of knots and `D` is the spline degree. A non-cyclical
      position equal to `C - D` is evaluated at the end of the last segment.
      In dense mode, a cyclical position equal to `C` wraps around like `0`.
    num_knots: A strictly positive `int` describing the number of knots in the
      spline.
    degree: An `int` describing the degree of the spline, which must be smaller
      than `num_knots`.
    cyclical: A `bool` describing whether the spline is cyclical.
    sparse_mode: A `bool` describing whether to return a result only for the
      knots with nonzero weights. If set to True, the function returns the
      weights of only the `degree` + 1 knots that are non-zero, as well as the
      indices of the knots.
    name: A name for this op. Defaults to "bspline_knot_weights".

  Returns:
    A tensor with dense weights for each control point, with the shape
    `[A1, ... An, C]` if `sparse_mode` is False.
    Otherwise, returns a tensor of shape `[A1, ... An, D + 1]` that contains the
    non-zero weights, and a tensor with the indices of the knots, with the type
    tf.int32.

  Raises:
    ValueError: If degree is greater than 4 or num_knots - 1, or less than 0.
    InvalidArgumentError: If positions are not in the right range.
  """
  with tf.compat.v1.name_scope(name, "bspline_knot_weights", [positions]):
    positions = tf.convert_to_tensor(value=positions)

    if degree > 4 or degree < 0:
      raise ValueError("Degree should be between 0 and 4.")
    if degree > num_knots - 1:
      raise ValueError("Degree cannot be >= number of knots.")
    if cyclical:
      positions = asserts.assert_all_in_range(positions, 0.0, float(num_knots))
    else:
      positions = asserts.assert_all_in_range(positions, 0.0,
                                              float(num_knots - degree))

    all_basis_functions = {
        # Maps valid degrees to functions.
        Degree.CONSTANT: _constant,
        Degree.LINEAR: _linear,
        Degree.QUADRATIC: _quadratic,
        Degree.CUBIC: _cubic,
        Degree.QUARTIC: _quartic
    }
    basis_functions = all_basis_functions[degree]

    if not cyclical and num_knots - degree == 1:
      # In this case all weights are non-zero and we can just return them.
      if not sparse_mode:
        return basis_functions(positions)
      else:
        shift = tf.zeros_like(positions, dtype=tf.int32)
        return basis_functions(positions), shift

    # Work on a flat list of positions; the batch shape is restored at the end.
    shape_batch = tf.shape(input=positions)
    positions = tf.reshape(positions, shape=(-1,))

    # Calculate the nonzero weights from the decimal parts of positions.
    shift = tf.floor(positions)
    if not cyclical:
      # A position exactly at the upper bound num_knots - degree passes the
      # closed-range assertion above, but its floor would make the last knot
      # index below equal num_knots, which is out of range for scatter_nd.
      # Clamping evaluates it at the end (fraction 1.0) of the last segment.
      shift = tf.minimum(
          shift, tf.constant(num_knots - degree - 1, dtype=shift.dtype))
    sparse_weights = basis_functions(positions - shift)
    shift = tf.cast(shift, tf.int32)

    if sparse_mode:
      # Returns just the weights and the shift amounts, so that tf.gather_nd on
      # the knots can be used to sparsely activate knots if needed.
      shape_weights = tf.concat(
          (shape_batch, tf.constant((degree + 1,), dtype=tf.int32)), axis=0)
      sparse_weights = tf.reshape(sparse_weights, shape=shape_weights)
      shift = tf.reshape(shift, shape=shape_batch)
      return sparse_weights, shift

    # Build (row, knot) scatter indices: for position i, knots
    # shift[i], ..., shift[i] + degree receive the sparse weights.
    num_positions = tf.size(input=positions)
    ind_row, ind_col = tf.meshgrid(
        tf.range(num_positions, dtype=tf.int32),
        tf.range(degree + 1, dtype=tf.int32),
        indexing="ij")

    tiled_shifts = tf.reshape(
        tf.tile(tf.expand_dims(shift, axis=-1), multiples=(1, degree + 1)),
        shape=(-1,))
    ind_col = tf.reshape(ind_col, shape=(-1,)) + tiled_shifts
    if cyclical:
      # Knot indices past the end wrap around to the start of the spline.
      ind_col = tf.math.mod(ind_col, num_knots)
    indices = tf.stack((tf.reshape(ind_row, shape=(-1,)), ind_col), axis=-1)
    shape_indices = tf.concat((tf.reshape(
        num_positions, shape=(1,)), tf.constant(
            (degree + 1, 2), dtype=tf.int32)),
                              axis=0)
    indices = tf.reshape(indices, shape=shape_indices)
    shape_scatter = tf.concat((tf.reshape(
        num_positions, shape=(1,)), tf.constant((num_knots,), dtype=tf.int32)),
                              axis=0)
    weights = tf.scatter_nd(indices, sparse_weights, shape_scatter)
    shape_weights = tf.concat(
        (shape_batch, tf.constant((num_knots,), dtype=tf.int32)), axis=0)
    return tf.reshape(weights, shape=shape_weights)