def __init__(self):
        """Initialize the distribution.

        Reads the spline that approximates the partition function from the
        packaged resource `robust_loss/data/partition_spline.npz` (generated
        by the script in fit_partition_spline.py) and keeps its x-coordinate
        scale, knot values, and knot tangents on the instance so they are
        loaded only once.
        """
        spline_resource = 'robust_loss/data/partition_spline.npz'
        with util.get_resource_as_file(spline_resource) as spline_path:
            # allow_pickle=False: the archive only holds plain numeric arrays.
            with np.load(spline_path, allow_pickle=False) as archive:
                (self._spline_x_scale,
                 self._spline_values,
                 self._spline_tangents) = (archive['x_scale'],
                                           archive['values'],
                                           archive['tangents'])
  def _load_golden_data(self):
    """Loads golden data: an RGB image and its CDF 9/7 wavelet decomposition.

    This golden data was produced by running the code from
    https://www.getreuer.info/projects/wavelet-cdf-97-implementation
    on a test image.

    Returns:
      A tuple `(im, pyr_true, wavelet_type)`:
        im: the test image (`I_color` in the .mat file) converted to float32.
        pyr_true: a tuple holding the decomposition (`pyr_color`). Every level
          except the last is converted to a tuple of its flattened entries;
          the last level is kept exactly as loaded.
        wavelet_type: the string 'CDF9/7'.
    """
    with util.get_resource_as_file(
        'robust_loss/data/wavelet_golden.mat') as golden_filename:
      data = scipy.io.loadmat(golden_filename)
    im = np.float32(data['I_color'])
    levels = data['pyr_color'][0, :].tolist()
    # Flatten every level but the last into a tuple; the final entry is kept
    # as loaded. Slicing (rather than indexing) keeps an empty or single-level
    # pyramid well defined, matching the original index-based loop.
    pyr_true = (
        *(tuple(level.flatten()) for level in levels[:-1]),
        *levels[-1:],
    )
    wavelet_type = 'CDF9/7'
    return im, pyr_true, wavelet_type
# Lazily populated cache of the raw (x_scale, values, tangents) spline arrays.
_PARTITION_SPLINE_CACHE = None


def _load_partition_spline():
    """Returns the raw (x_scale, values, tangents) arrays of the spline.

    The arrays are read from `robust_loss/data/partition_spline.npz` (produced
    by running the script in fit_partition_spline.py) on the first call and
    reused on later calls, so the resource is not re-read from disk every
    time the log-partition function is evaluated.
    """
    global _PARTITION_SPLINE_CACHE
    if _PARTITION_SPLINE_CACHE is None:
        with util.get_resource_as_file(
                'robust_loss/data/partition_spline.npz') as spline_file:
            with np.load(spline_file, allow_pickle=False) as f:
                # Indexing an NpzFile reads the array into memory, so the
                # arrays remain valid after the archive is closed.
                _PARTITION_SPLINE_CACHE = (
                    f['x_scale'], f['values'], f['tangents'])
    return _PARTITION_SPLINE_CACHE


def log_base_partition_function(alpha):
    r"""Approximate the distribution's log-partition function with a 1D spline.

    Because the partition function (Z(\alpha) in the paper) of the distribution
    is difficult to model analytically, we approximate it with a (transformed)
    cubic hermite spline: Each alpha is pushed through a nonlinearity before
    being used to interpolate into a spline, which allows us to use a
    relatively small spline to accurately model the log partition function
    over the range of all non-negative input values.

    Args:
      alpha: A tensor or scalar of single or double precision floats containing
        the set of alphas for which we would like an approximate log partition
        function. Must be non-negative, as the partition function is undefined
        when alpha < 0.

    Returns:
      An approximation of log(Z(alpha)) accurate to within 1e-6
    """
    # Convert first so that plain Python scalars (which have no `.dtype`) are
    # accepted, as the docstring promises; tensors pass through unchanged.
    alpha = tf.convert_to_tensor(alpha)
    float_dtype = alpha.dtype

    # The spline arrays are loaded once and cached; only the cast to the
    # caller's float precision happens per call.
    spline_x_scale, spline_values, spline_tangents = _load_partition_spline()
    x_scale = tf.cast(spline_x_scale, float_dtype)
    values = tf.cast(spline_values, float_dtype)
    tangents = tf.cast(spline_tangents, float_dtype)

    # The partition function is undefined when `alpha` < 0.
    assert_ops = [tf.Assert(tf.reduce_all(alpha >= 0.), [alpha])]
    with tf.control_dependencies(assert_ops):
        # Transform `alpha` to the form expected by the spline.
        x = partition_spline_curve(alpha)
        # Interpolate into the spline.
        return cubic_spline.interpolate1d(x * x_scale, values, tangents)