Пример #1
0
    def test_assert_all_in_range_exception_raised(self, dtype):
        """Verifies assert_all_in_range rejects values outside the range.

        A random vector is squared and divided by its maximum, so its entries
        are non-negative with a largest entry of exactly 1. Adding the
        dtype-specific addition epsilon pushes that largest entry above 1,
        while a tensor of ones sits exactly on the upper bound. Every case
        below must raise InvalidArgumentError.
        """
        squared = tf.convert_to_tensor(value=_pick_random_vector(), dtype=dtype)
        squared = squared * squared
        normalized = squared / tf.reduce_max(
            input_tensor=squared, axis=-1, keepdims=True)
        eps = asserts.select_eps_for_addition(dtype)
        just_outside = normalized + eps
        on_boundary = tf.ones_like(normalized)

        # (subtest name, tensor under test, open_bounds flag).
        invalid_cases = (
            ("outside_and_open_bounds", just_outside, True),
            ("outside_and_close_bounds", just_outside, False),
            ("exact_and_open_bounds", on_boundary, True),
        )
        for case_name, candidate, use_open_bounds in invalid_cases:
            with self.subTest(name=case_name):
                with self.assertRaises(tf.errors.InvalidArgumentError):
                    self.evaluate(
                        asserts.assert_all_in_range(
                            candidate, -1.0, 1.0,
                            open_bounds=use_open_bounds))
Пример #2
0
def from_srgb(srgb, name=None):
  """Converts sRGB colors to linear colors.

  Note:
      In the following, A1 to An are optional batch dimensions.

  Args:
    srgb: A tensor of shape `[A_1, ..., A_n, 3]`, where the last dimension
      represents sRGB values in the range [0, 1].
    name: A name for this op that defaults to "linear_rgb_from_srgb".

  Raises:
    ValueError: If `srgb` has rank < 1 or has its last dimension not equal to 3.
    InvalidArgumentError: If the tf-graphics debug flag is set and a value of
      `srgb` lies outside [0, 1].

  Returns:
    A tensor of shape `[A_1, ..., A_n, 3]`, where the last dimension represents
    RGB values in linear color space.
  """
  with tf.compat.v1.name_scope(name, "linear_rgb_from_srgb", [srgb]):
    srgb = tf.convert_to_tensor(value=srgb)

    shape.check_static(
        tensor=srgb,
        tensor_name="srgb",
        has_rank_greater_than=0,
        has_dim_equals=(-1, 3))

    # Use the tensor returned by the assert (as the other callers do) so any
    # range check it attaches is part of the computation producing the output;
    # discarding it lets the check be skipped in graph mode.
    srgb = asserts.assert_all_in_range(srgb, 0., 1.)
    # Piecewise sRGB decoding: linear segment below _K0, power curve above.
    return tf.compat.v1.where(srgb <= _K0, srgb / _PHI,
                              ((srgb + _A) / (1 + _A))**_GAMMA)
Пример #3
0
def from_linear_rgb(linear_rgb, name=None):
  """Converts linear RGB to sRGB colors.

  Note:
      In the following, A1 to An are optional batch dimensions.

  Args:
    linear_rgb: A Tensor of shape `[A_1, ..., A_n, 3]`, where the last dimension
      represents RGB values in the range [0, 1] in linear color space.
    name: A name for this op that defaults to "srgb_from_linear_rgb".

  Raises:
    ValueError: If `linear_rgb` has rank < 1 or has its last dimension not
      equal to 3.
    InvalidArgumentError: If the tf-graphics debug flag is set and a value of
      `linear_rgb` lies outside [0, 1].

  Returns:
    A tensor of shape `[A_1, ..., A_n, 3]`, where the last dimension represents
    sRGB values.
  """
  with tf.compat.v1.name_scope(name, "srgb_from_linear_rgb", [linear_rgb]):
    linear_rgb = tf.convert_to_tensor(value=linear_rgb)

    shape.check_static(
        tensor=linear_rgb,
        tensor_name="linear_rgb",
        has_rank_greater_than=0,
        has_dim_equals=(-1, 3))
    # Use the tensor returned by the assert (as the other callers do) so any
    # range check it attaches is part of the computation producing the output.
    linear_rgb = asserts.assert_all_in_range(linear_rgb, 0., 1.)

    is_linear_segment = linear_rgb <= _K0 / _PHI
    # tf.where back-propagates through both branches, so a non-finite gradient
    # in the unselected power branch still turns the result's gradient into
    # NaN. The derivative of x**(1 / _GAMMA) is unbounded at x = 0 (assuming
    # _GAMMA > 1, as for sRGB), so the power branch is only fed inputs from its
    # own segment and a harmless 1.0 elsewhere. Unlike adding a float64
    # epsilon to every input, this does not bias the output and still works
    # for low-precision dtypes (e.g. float16) where such an epsilon rounds
    # to 0.
    power_input = tf.compat.v1.where(is_linear_segment,
                                     tf.ones_like(linear_rgb), linear_rgb)
    return tf.compat.v1.where(is_linear_segment, linear_rgb * _PHI,
                              (1 + _A) * (power_input**(1 / _GAMMA)) - _A)
Пример #4
0
    def test_assert_all_in_range_passthrough(self):
        """Verifies the assert hands back its input object when disabled.

        With the debug flag False, assert_all_in_range must return the very
        same object it was given rather than a new tensor.
        """
        candidate = _pick_random_vector()
        returned = asserts.assert_all_in_range(candidate, -1.0, 1.0)
        self.assertIs(candidate, returned)
Пример #5
0
def square_to_spherical_coordinates(point_2d, name=None):
  """Maps points from a unit square to a unit sphere.

  Note:
    In the following, A1 to An are optional batch dimensions.

  Args:
    point_2d: A tensor of shape `[A1, ..., An, 2]` with values in [0,1].
    name: A name for this op. Defaults to
      "math_square_to_spherical_coordinates".

  Returns:
    A tensor of shape `[A1, ..., An, 3]` holding spherical coordinates, where
    [..., 0] is the radius (always 1.0), [..., 1] is theta with values in
    [0.0, pi], and [..., 2] is phi with values in [0.0, 2pi].

  Raises:
    ValueError: if the shape of `point_2d` is not supported.
    InvalidArgumentError: if at least an element of `point_2d` is outside of
      [0,1].
  """
  with tf.compat.v1.name_scope(name, "math_square_to_spherical_coordinates",
                               [point_2d]):
    point_2d = tf.convert_to_tensor(value=point_2d)

    shape.check_static(
        tensor=point_2d, tensor_name="point_2d", has_dim_equals=(-1, 2))
    point_2d = asserts.assert_all_in_range(
        point_2d, 0.0, 1.0, open_bounds=False)

    x, y = tf.unstack(point_2d, axis=-1)
    # For x in [0, 1]: sqrt(1 - x) is in [0, 1], acos of it is in [0, pi/2],
    # so theta spans [0, pi].
    theta = 2.0 * tf.acos(tf.sqrt(1.0 - x))
    phi = 2.0 * np.pi * y
    return tf.stack((tf.ones_like(theta), theta, phi), axis=-1)
Пример #6
0
def safe_shrink(vector,
                minval=None,
                maxval=None,
                open_bounds=False,
                eps=None,
                name=None):
    """Scales `vector` by (1.0 - eps), with eps chosen from its dtype by default.

    Floating point round-off can leave a value a hair outside its expected
    range; for instance, the dot product of a normalized vector with itself may
    come out slightly above `1.0`. Multiplying by `(1.0 - eps)` pulls such
    values back in while keeping the operation differentiable everywhere
    (unlike `tf.clip_by_value`), which makes the result safe to feed to
    functions such as `acos(x)`. When both `minval` and `maxval` are given, the
    shrunk vector is passed through `asserts.assert_all_in_range`; if the
    tf-graphics debug flag is set, this checks that the values lie in
    `[minval, maxval]` when `open_bounds` is `False`, or in `]minval, maxval[`
    when it is `True`.

    Note:
      In the following, A1 to An are optional batch dimensions, which must be
      broadcast compatible.

    Args:
      vector: A tensor of shape `[A1, ..., An]`.
      minval: A `float` or a tensor of shape `[A1, ..., An]` holding the lower
        bounds that the shrunk values are checked against. Only used when both
        `minval` and `maxval` are not `None`.
      maxval: A `float` or a tensor of shape `[A1, ..., An]` holding the upper
        bounds that the shrunk values are checked against. Only used when both
        `minval` and `maxval` are not `None`.
      open_bounds: A `bool` selecting an open (`True`) or closed (`False`)
        range for the check; only relevant when both bounds are given.
      eps: A `float` used to shrink `vector`. When `None`, it is derived from
        the `dtype` of `vector`.
      name: A name for this op. Defaults to 'safe_shrink'.

    Raises:
      InvalidArgumentError: If the tf-graphics debug flag is set and the
        shrunk vector is not inside the expected range.

    Returns:
      A tensor of shape `[A1, ..., An]` containing the shrunk values.
    """
    with tf.compat.v1.name_scope(name, 'safe_shrink',
                                 [vector, minval, maxval]):
        vector = tf.convert_to_tensor(value=vector)
        if eps is None:
            eps = asserts.select_eps_for_addition(vector.dtype)
        scale = 1.0 - tf.convert_to_tensor(value=eps, dtype=vector.dtype)
        shrunk = vector * scale
        # The range check needs both bounds; with either missing, skip it.
        if minval is None or maxval is None:
            return shrunk
        return asserts.assert_all_in_range(shrunk,
                                           minval,
                                           maxval,
                                           open_bounds=open_bounds)
Пример #7
0
def evaluate_legendre_polynomial(degree_l, order_m, x):
  """Evaluates the associated Legendre polynomial of degree l, order m at x.

  Note:
    The algorithm follows p. 10 of `Spherical Harmonic Lighting: The Gritty
    Details`.

  Note:
    In the following, A1 to An are optional batch dimensions, which must be
    broadcast compatible.

  Args:
    degree_l: An integer tensor of shape `[A1, ..., An]` holding the degree of
      the associated Legendre polynomial. Values must be non-negative.
    order_m: An integer tensor of shape `[A1, ..., An]` holding the order of
      the associated Legendre polynomial. Values must satisfy
      `0 <= order_m <= degree_l`.
    x: A tensor of shape `[A1, ..., An]` with values in [-1,1].

  Returns:
    A tensor of shape `[A1, ..., An]` containing the evaluation of the Legendre
    polynomial.

  Raises:
    ValueError: If `degree_l` or `order_m` does not have an integer dtype.
  """
  degree_l, order_m, x = (
      tf.convert_to_tensor(value=tensor) for tensor in (degree_l, order_m, x))

  integer_checks = (
      (degree_l, "`degree_l` must be of an integer type."),
      (order_m, "`order_m` must be of an integer type."),
  )
  for tensor, message in integer_checks:
    if not tensor.dtype.is_integer:
      raise ValueError(message)

  shape.compare_batch_dimensions(
      tensors=(degree_l, order_m, x),
      last_axes=-1,
      tensor_names=("degree_l", "order_m", "x"),
      broadcast_compatible=True)

  # Value-range checks; keep the returned tensors so the checks stay attached.
  degree_l = asserts.assert_all_above(degree_l, 0)
  order_m = asserts.assert_all_in_range(order_m, 0, degree_l)
  x = asserts.assert_all_in_range(x, -1.0, 1.0)

  pmm = _evaluate_legendre_polynomial_pmm_eval(order_m, x)
  # When l == m the answer is P_m^m itself; otherwise recurse upward from it.
  on_diagonal = tf.equal(degree_l, order_m)
  higher_degree = _evaluate_legendre_polynomial_branch(degree_l, order_m, x,
                                                       pmm)
  return tf.compat.v1.where(on_diagonal, pmm, higher_degree)
Пример #8
0
def perspective_right_handed(vertical_field_of_view,
                             aspect_ratio,
                             near,
                             far,
                             name=None):
  """Generates the matrix for a right handed perspective projection.

  Note:
    In the following, A1 to An are optional batch dimensions. All four inputs
    must have the same shape.

  Args:
    vertical_field_of_view: A tensor of shape `[A1, ..., An]` holding the
      vertical field of view of the frustum, expressed in radians. Values must
      be in the open range (0, pi).
    aspect_ratio: A tensor of shape `[A1, ..., An]` holding the width over
      height ratio of the frustum. Values must be strictly positive.
    near: A tensor of shape `[A1, ..., An]` holding the distance between the
      viewer and the near clipping plane. Values must be strictly positive.
    far: A tensor of shape `[A1, ..., An]` holding the distance between the
      viewer and the far clipping plane. Values must be strictly greater than
      those of `near`.
    name: A name for this op. Defaults to 'perspective_rh'.

  Raises:
    InvalidArgumentError: if any input contains data not in the specified range
      of valid values.
    ValueError: if the inputs are not all of the same shape.

  Returns:
    A tensor of shape `[A1, ..., An, 4, 4]`, containing matrices of right
    handed perspective-view frustum.
  """
  with tf.compat.v1.name_scope(
      name, "perspective_rh",
      [vertical_field_of_view, aspect_ratio, near, far]):
    vertical_field_of_view = tf.convert_to_tensor(value=vertical_field_of_view)
    aspect_ratio = tf.convert_to_tensor(value=aspect_ratio)
    near = tf.convert_to_tensor(value=near)
    far = tf.convert_to_tensor(value=far)

    shape.compare_batch_dimensions(
        tensors=(vertical_field_of_view, aspect_ratio, near, far),
        last_axes=-1,
        tensor_names=("vertical_field_of_view", "aspect_ratio", "near", "far"),
        broadcast_compatible=False)

    vertical_field_of_view = asserts.assert_all_in_range(
        vertical_field_of_view, 0.0, math.pi, open_bounds=True)
    aspect_ratio = asserts.assert_all_above(aspect_ratio, 0.0, open_bound=True)
    near = asserts.assert_all_above(near, 0.0, open_bound=True)
    far = asserts.assert_all_above(far, near, open_bound=True)

    # Focal scale along y: cot(fov / 2).
    inverse_tan_half_vertical_field_of_view = 1.0 / tf.tan(
        vertical_field_of_view * 0.5)
    zero = tf.zeros_like(inverse_tan_half_vertical_field_of_view)
    one = tf.ones_like(inverse_tan_half_vertical_field_of_view)

    # Each row is stacked on the last axis; rows are then stacked on axis -2.
    x = tf.stack((inverse_tan_half_vertical_field_of_view / aspect_ratio, zero,
                  zero, zero),
                 axis=-1)
    y = tf.stack((zero, inverse_tan_half_vertical_field_of_view, zero, zero),
                 axis=-1)
    near_minus_far = near - far
    z = tf.stack(
        (zero, zero,
         (far + near) / near_minus_far, 2.0 * far * near / near_minus_far),
        axis=-1)
    w = tf.stack((zero, zero, -one, zero), axis=-1)

    return tf.stack((x, y, z, w), axis=-2)
Пример #9
0
def brdf(direction_incoming_light: type_alias.TensorLike,
         direction_outgoing_light: type_alias.TensorLike,
         surface_normal: type_alias.TensorLike,
         albedo: type_alias.TensorLike,
         name: str = "lambertian_brdf") -> tf.Tensor:
    """Evaluates the brdf of a Lambertian surface.

    Note:
      In the following, A1 to An are optional batch dimensions, which must be
      broadcast compatible.

    Note:
      The gradient of this function is not smooth when the dot product of the
      normal with any light is 0.0.

    Args:
      direction_incoming_light: A tensor of shape `[A1, ..., An, 3]`, where the
        last dimension represents a normalized incoming light vector.
      direction_outgoing_light: A tensor of shape `[A1, ..., An, 3]`, where the
        last dimension represents a normalized outgoing light vector.
      surface_normal: A tensor of shape `[A1, ..., An, 3]`, where the last
        dimension represents a normalized surface normal.
      albedo: A tensor of shape `[A1, ..., An, 3]`, where the last dimension
        represents albedo with values in [0,1].
      name: A name for this op. Defaults to "lambertian_brdf".

    Returns:
      A tensor of shape `[A1, ..., An, 3]`, where the last dimension represents
      the amount of reflected light in any outgoing direction: `albedo / pi`
      where both the reversed incoming light and the outgoing light have a
      non-negative dot product with the normal, and 0 elsewhere.

    Raises:
      ValueError: if the shape of `direction_incoming_light`,
        `direction_outgoing_light`, `surface_normal`, or `albedo` is not
        supported.
      InvalidArgumentError: if at least one element of `albedo` is outside of
        [0,1].
    """
    with tf.name_scope(name):
        direction_incoming_light = tf.convert_to_tensor(
            value=direction_incoming_light)
        direction_outgoing_light = tf.convert_to_tensor(
            value=direction_outgoing_light)
        surface_normal = tf.convert_to_tensor(value=surface_normal)
        albedo = tf.convert_to_tensor(value=albedo)

        shape.check_static(tensor=direction_incoming_light,
                           tensor_name="direction_incoming_light",
                           has_dim_equals=(-1, 3))
        shape.check_static(tensor=direction_outgoing_light,
                           tensor_name="direction_outgoing_light",
                           has_dim_equals=(-1, 3))
        shape.check_static(tensor=surface_normal,
                           tensor_name="surface_normal",
                           has_dim_equals=(-1, 3))
        shape.check_static(tensor=albedo,
                           tensor_name="albedo",
                           has_dim_equals=(-1, 3))
        shape.compare_batch_dimensions(
            tensors=(direction_incoming_light, direction_outgoing_light,
                     surface_normal, albedo),
            tensor_names=("direction_incoming_light",
                          "direction_outgoing_light", "surface_normal",
                          "albedo"),
            last_axes=-2,
            broadcast_compatible=True)
        direction_incoming_light = asserts.assert_normalized(
            direction_incoming_light)
        direction_outgoing_light = asserts.assert_normalized(
            direction_outgoing_light)
        surface_normal = asserts.assert_normalized(surface_normal)
        albedo = asserts.assert_all_in_range(albedo,
                                             0.0,
                                             1.0,
                                             open_bounds=False)

        # Checks whether the incoming or outgoing light point behind the surface.
        dot_incoming_light_surface_normal = vector.dot(
            -direction_incoming_light, surface_normal)
        dot_outgoing_light_surface_normal = vector.dot(
            direction_outgoing_light, surface_normal)
        min_dot = tf.minimum(dot_incoming_light_surface_normal,
                             dot_outgoing_light_surface_normal)
        # tf.where broadcasts the condition and both branches against each
        # other, so the result already has the broadcast shape of `min_dot`
        # and `albedo`. Building that shape from static shapes instead breaks
        # for statically unknown dimensions (e.g. a `None` batch size), which
        # cannot be resolved before run time.
        is_front_facing = tf.greater_equal(min_dot, 0.0)
        return tf.where(is_front_facing, albedo / math.pi,
                        tf.zeros_like(albedo))
Пример #10
0
def brdf(direction_incoming_light: type_alias.TensorLike,
         direction_outgoing_light: type_alias.TensorLike,
         surface_normal: type_alias.TensorLike,
         shininess: type_alias.TensorLike,
         albedo: type_alias.TensorLike,
         brdf_normalization: bool = True,
         name: str = "phong_brdf") -> tf.Tensor:
    """Evaluates the specular brdf of the Phong model.

    Note:
      In the following, A1 to An are optional batch dimensions, which must be
      broadcast compatible.

    Note:
      The gradient of this function is not smooth when the dot product of the
      normal with any light is 0.0.

    Args:
      direction_incoming_light: A tensor of shape `[A1, ..., An, 3]`, where the
        last dimension represents a normalized incoming light vector.
      direction_outgoing_light: A tensor of shape `[A1, ..., An, 3]`, where the
        last dimension represents a normalized outgoing light vector.
      surface_normal: A tensor of shape `[A1, ..., An, 3]`, where the last
        dimension represents a normalized surface normal.
      shininess: A tensor of shape `[A1, ..., An, 1]`, where the last dimension
        represents a non-negative shininess coefficient.
      albedo: A tensor of shape `[A1, ..., An, 3]`, where the last dimension
        represents albedo with values in [0,1].
      brdf_normalization: A `bool` indicating whether normalization should be
        applied to enforce the energy conservation property of BRDFs. Note that
        `brdf_normalization` must be set to False in order to use the original
        Phong specular model.
      name: A name for this op. Defaults to "phong_brdf".

    Returns:
      A tensor of shape `[A1, ..., An, 3]`, where the last dimension represents
      the amount of light reflected in the outgoing light direction. It is 0
      wherever either light points behind the surface.

    Raises:
      ValueError: if the shape of `direction_incoming_light`,
        `direction_outgoing_light`, `surface_normal`, `shininess` or `albedo`
        is not supported.
      InvalidArgumentError: if not all of shininess values are non-negative, or
        if at least one element of `albedo` is outside of [0,1].
    """
    with tf.name_scope(name):
        direction_incoming_light = tf.convert_to_tensor(
            value=direction_incoming_light)
        direction_outgoing_light = tf.convert_to_tensor(
            value=direction_outgoing_light)
        surface_normal = tf.convert_to_tensor(value=surface_normal)
        shininess = tf.convert_to_tensor(value=shininess)
        albedo = tf.convert_to_tensor(value=albedo)

        shape.check_static(tensor=direction_incoming_light,
                           tensor_name="direction_incoming_light",
                           has_dim_equals=(-1, 3))
        shape.check_static(tensor=direction_outgoing_light,
                           tensor_name="direction_outgoing_light",
                           has_dim_equals=(-1, 3))
        shape.check_static(tensor=surface_normal,
                           tensor_name="surface_normal",
                           has_dim_equals=(-1, 3))
        shape.check_static(tensor=shininess,
                           tensor_name="shininess",
                           has_dim_equals=(-1, 1))
        shape.check_static(tensor=albedo,
                           tensor_name="albedo",
                           has_dim_equals=(-1, 3))
        shape.compare_batch_dimensions(
            tensors=(direction_incoming_light, direction_outgoing_light,
                     surface_normal, shininess, albedo),
            tensor_names=("direction_incoming_light",
                          "direction_outgoing_light", "surface_normal",
                          "shininess", "albedo"),
            last_axes=-2,
            broadcast_compatible=True)
        direction_incoming_light = asserts.assert_normalized(
            direction_incoming_light)
        direction_outgoing_light = asserts.assert_normalized(
            direction_outgoing_light)
        surface_normal = asserts.assert_normalized(surface_normal)
        albedo = asserts.assert_all_in_range(albedo,
                                             0.0,
                                             1.0,
                                             open_bounds=False)
        shininess = asserts.assert_all_above(shininess, 0.0, open_bound=False)

        # Checks whether the incoming or outgoing light point behind the surface.
        dot_incoming_light_surface_normal = vector.dot(
            -direction_incoming_light, surface_normal)
        dot_outgoing_light_surface_normal = vector.dot(
            direction_outgoing_light, surface_normal)
        min_dot = tf.minimum(dot_incoming_light_surface_normal,
                             dot_outgoing_light_surface_normal)
        perfect_reflection_direction = vector.reflect(direction_incoming_light,
                                                      surface_normal)
        perfect_reflection_direction = tf.math.l2_normalize(
            perfect_reflection_direction, axis=-1)
        cos_alpha = vector.dot(perfect_reflection_direction,
                               direction_outgoing_light,
                               axis=-1)
        # Clamp so that pow() never sees a negative base.
        cos_alpha = tf.maximum(cos_alpha, tf.zeros_like(cos_alpha))
        phong_model = albedo * tf.pow(cos_alpha, shininess)
        if brdf_normalization:
            phong_model *= _brdf_normalization_factor(shininess)
        # tf.where broadcasts the condition and both branches against each
        # other, so the result already has the broadcast shape of `min_dot`
        # and `phong_model`. Building that shape from static shapes instead
        # breaks for statically unknown dimensions (e.g. a `None` batch size),
        # which cannot be resolved before run time.
        is_front_facing = tf.greater_equal(min_dot, 0.0)
        return tf.where(is_front_facing, phong_model,
                        tf.zeros_like(phong_model))
Пример #11
0
def evaluate_spherical_harmonics(degree_l, order_m, theta, phi, name=None):
    """Evaluates a point sample of a Spherical Harmonic basis function.

    Note:
      This function implements the algorithm and variable names described on
      p. 12 of `Spherical Harmonic Lighting: The Gritty Details`.

    Note:
      In the following, A1 to An are optional batch dimensions.

    Args:
      degree_l: An integer tensor of shape `[A1, ..., An, C]`, where the last
        dimension represents the band of the spherical harmonics. Note that
        `degree_l` must be non-negative.
      order_m: An integer tensor of shape `[A1, ..., An, C]`, where the last
        dimension represents the index of the spherical harmonics in the band
        `degree_l`. Note that `order_m` must satisfy
        `-degree_l <= order_m <= degree_l`.
      theta: A tensor of shape `[A1, ..., An, 1]`. This variable stores the
        polar angle of the sample. Values of theta must be in [0, pi].
      phi: A tensor of shape `[A1, ..., An, 1]`. This variable stores the
        azimuthal angle of the sample. Values of phi must be in [0, 2pi].
      name: A name for this op. Defaults to
        'spherical_harmonics_evaluate_spherical_harmonics'.

    Returns:
      A tensor of shape `[A1, ..., An, C]` containing the evaluation of each
      basis of the spherical harmonics.

    Raises:
      ValueError: if the shape of `theta` or `phi` is not supported, or if
        `degree_l` or `order_m` does not have an integer dtype.
      InvalidArgumentError: if at least an element of `degree_l`, `order_m`,
        `theta` or `phi` is outside the expected range.
    """
    with tf.compat.v1.name_scope(
            name, "spherical_harmonics_evaluate_spherical_harmonics",
            [degree_l, order_m, theta, phi]):
        degree_l = tf.convert_to_tensor(value=degree_l)
        order_m = tf.convert_to_tensor(value=order_m)
        theta = tf.convert_to_tensor(value=theta)
        phi = tf.convert_to_tensor(value=phi)

        if not degree_l.dtype.is_integer:
            raise ValueError("`degree_l` must be of an integer type.")
        if not order_m.dtype.is_integer:
            raise ValueError("`order_m` must be of an integer type.")

        shape.compare_dimensions(tensors=(degree_l, order_m),
                                 axes=-1,
                                 tensor_names=("degree_l", "order_m"))
        shape.check_static(tensor=phi,
                           tensor_name="phi",
                           has_dim_equals=(-1, 1))
        shape.check_static(tensor=theta,
                           tensor_name="theta",
                           has_dim_equals=(-1, 1))
        shape.compare_batch_dimensions(tensors=(degree_l, order_m, theta, phi),
                                       last_axes=-2,
                                       tensor_names=("degree_l", "order_m",
                                                     "theta", "phi"),
                                       broadcast_compatible=False)
        # Checks that tensors contain appropriate data.
        degree_l = asserts.assert_all_above(degree_l, 0)
        order_m = asserts.assert_all_in_range(order_m, -degree_l, degree_l)
        theta = asserts.assert_all_in_range(theta, 0.0, np.pi)
        phi = asserts.assert_all_in_range(phi, 0.0, 2.0 * np.pi)

        var_type = theta.dtype
        # Negative orders are handled by the branch helper via their sign; the
        # remaining computations work on |m|.
        sign_m = tf.math.sign(order_m)
        order_m = tf.abs(order_m)
        zeros = tf.zeros_like(order_m)
        result_m_zero = _spherical_harmonics_normalization(
            degree_l, zeros, var_type) * evaluate_legendre_polynomial(
                degree_l, zeros, tf.cos(theta))
        result_branch = _evaluate_spherical_harmonics_branch(
            degree_l, order_m, theta, phi, sign_m, var_type)
        return tf.where(tf.equal(order_m, zeros), result_m_zero, result_branch)
Пример #12
0
def knot_weights(positions,
                 num_knots,
                 degree,
                 cyclical,
                 sparse_mode=False,
                 name=None):
  """Function that converts cardinal B-spline positions to knot weights.

  Note:
    In the following, A1 to An are optional batch dimensions.

  Args:
    positions: A tensor with shape `[A1, ..., An]`. Positions must be in
      `[0, C - D]` for non-cyclical and `[0, C]` for cyclical splines, where `C`
      is the number of knots and `D` is the spline degree.
    num_knots: A strictly positive `int` describing the number of knots in the
      spline.
    degree: An `int` describing the degree of the spline, which must be smaller
      than `num_knots`.
    cyclical: A `bool` describing whether the spline is cyclical.
    sparse_mode: A `bool` describing whether to return a result only for the
      knots with nonzero weights. If set to True, the function returns the
      weights of only the `degree` + 1 knots that are non-zero, as well as the
      index of the first of those knots.
    name: A name for this op. Defaults to "bspline_knot_weights".

  Returns:
    A tensor with dense weights for each control point, with the shape
    `[A1, ... An, C]` if `sparse_mode` is False.
    Otherwise, returns a tensor of shape `[A1, ... An, D + 1]` that contains the
    non-zero weights, and a `tf.int32` tensor of shape `[A1, ..., An]` holding
    the index `shift` of the first of those knots; the weights belong to knots
    `shift, ..., shift + D` (to be wrapped modulo `C` for cyclical splines, as
    the dense path does).

  Raises:
    ValueError: If degree is greater than 4 or num_knots - 1, or less than 0.
    InvalidArgumentError: If positions are not in the right range.
  """
  with tf.compat.v1.name_scope(name, "bspline_knot_weights", [positions]):
    positions = tf.convert_to_tensor(value=positions)

    if degree > 4 or degree < 0:
      raise ValueError("Degree should be between 0 and 4.")
    if degree > num_knots - 1:
      raise ValueError("Degree cannot be >= number of knots.")
    if cyclical:
      positions = asserts.assert_all_in_range(positions, 0.0, float(num_knots))
    else:
      positions = asserts.assert_all_in_range(positions, 0.0,
                                              float(num_knots - degree))

    all_basis_functions = {
        # Maps valid degrees to functions.
        Degree.CONSTANT: _constant,
        Degree.LINEAR: _linear,
        Degree.QUADRATIC: _quadratic,
        Degree.CUBIC: _cubic,
        Degree.QUARTIC: _quartic
    }
    basis_functions = all_basis_functions[degree]

    if not cyclical and num_knots - degree == 1:
      # In this case all weights are non-zero and we can just return them.
      if not sparse_mode:
        return basis_functions(positions)
      else:
        shift = tf.zeros_like(positions, dtype=tf.int32)
        return basis_functions(positions), shift

    shape_batch = tf.shape(input=positions)
    positions = tf.reshape(positions, shape=(-1,))

    # Calculate the nonzero weights from the decimal parts of positions.
    shift = tf.floor(positions)
    if not cyclical:
      # The range check accepts a position exactly equal to C - D, but flooring
      # it would select knots C - D, ..., C and index C does not exist (the
      # scatter below has only C columns). Evaluate such positions at local
      # coordinate 1 of the last valid segment instead. This assumes the basis
      # functions are continuous at segment boundaries (true of B-spline bases
      # of degree >= 1) -- TODO confirm the intended value for Degree.CONSTANT.
      shift = tf.minimum(shift, float(num_knots - degree - 1))
    sparse_weights = basis_functions(positions - shift)
    shift = tf.cast(shift, tf.int32)

    if sparse_mode:
      # Returns just the weights and the shift amounts, so that tf.gather_nd on
      # the knots can be used to sparsely activate knots if needed.
      shape_weights = tf.concat(
          (shape_batch, tf.constant((degree + 1,), dtype=tf.int32)), axis=0)
      sparse_weights = tf.reshape(sparse_weights, shape=shape_weights)
      shift = tf.reshape(shift, shape=shape_batch)
      return sparse_weights, shift

    # Scatter the D + 1 weights of every position into a dense
    # [num_positions, num_knots] matrix, at columns shift .. shift + D.
    num_positions = tf.size(input=positions)
    ind_row, ind_col = tf.meshgrid(
        tf.range(num_positions, dtype=tf.int32),
        tf.range(degree + 1, dtype=tf.int32),
        indexing="ij")

    tiled_shifts = tf.reshape(
        tf.tile(tf.expand_dims(shift, axis=-1), multiples=(1, degree + 1)),
        shape=(-1,))
    ind_col = tf.reshape(ind_col, shape=(-1,)) + tiled_shifts
    if cyclical:
      ind_col = tf.math.mod(ind_col, num_knots)
    indices = tf.stack((tf.reshape(ind_row, shape=(-1,)), ind_col), axis=-1)
    shape_indices = tf.concat((tf.reshape(
        num_positions, shape=(1,)), tf.constant(
            (degree + 1, 2), dtype=tf.int32)),
                              axis=0)
    indices = tf.reshape(indices, shape=shape_indices)
    shape_scatter = tf.concat((tf.reshape(
        num_positions, shape=(1,)), tf.constant((num_knots,), dtype=tf.int32)),
                              axis=0)
    weights = tf.scatter_nd(indices, sparse_weights, shape_scatter)
    shape_weights = tf.concat(
        (shape_batch, tf.constant((num_knots,), dtype=tf.int32)), axis=0)
    return tf.reshape(weights, shape=shape_weights)