Example #1
  def _parameter_control_dependencies(self, is_init):
    if not self.validate_args:
      return []
    assertions = []

    if self._num_steps is not None:
      if is_init != tensor_util.is_ref(self._num_steps):
        assertions.append(assert_util.assert_rank(
            self._num_steps, 0,
            message='Argument `num_steps` must be a scalar'))
        assertions.append(assert_util.assert_positive(
            self._num_steps, message='Argument `num_steps` must be positive'))

    return assertions
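
The recurring idiom in these examples is the check `is_init != tensor_util.is_ref(param)`: at construction time (`is_init=True`) assertions are emitted for plain-tensor parameters, which only need to be validated once, while at call time (`is_init=False`) assertions are emitted for reference-type parameters (`tf.Variable`, `tfp.util.DeferredTensor`), whose values may have changed since `__init__`. Below is a minimal sketch of what `tensor_util.is_ref` distinguishes; it assumes the TFP-internal `tensor_util` module these examples import, and the variable names are illustrative only.

import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import tensor_util

v = tf.Variable(2., name='rate')             # mutable; re-read on every use
d = tfp.util.DeferredTensor(v, tf.math.exp)  # lazily recomputed from `v`
t = tf.constant(2.)                          # a fixed value

print(tensor_util.is_ref(v))  # True  -> validated on each call, not in __init__
print(tensor_util.is_ref(d))  # True
print(tensor_util.is_ref(t))  # False -> validated once, at construction
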
Example #2
    def _parameter_control_dependencies(self, is_init):
        if not self.validate_args:
            return []

        assertions = []

        if is_init != tensor_util.is_ref(self.total_count):
            total_count = tf.convert_to_tensor(self.total_count)
            msg1 = 'Argument `total_count` must be non-negative.'
            msg2 = 'Argument `total_count` cannot contain fractional components.'
            assertions += [
                assert_util.assert_non_negative(total_count, message=msg1),
                distribution_util.assert_integer_form(total_count,
                                                      message=msg2),
            ]

        for concentration in [self.concentration1, self.concentration0]:
            if is_init != tensor_util.is_ref(concentration):
                assertions.append(
                    assert_util.assert_positive(
                        concentration,
                        message='Concentration parameter must be positive.'))
        return assertions
Example #3
 def _parameter_control_dependencies(self, is_init):
     if not self.validate_args:
         return []
     assertions = []
     low = None
     high = None
     if is_init != tensor_util.is_ref(self.low):
         low = tf.convert_to_tensor(self.low)
         high = tf.convert_to_tensor(self.high)
         assertions.append(
             assert_util.assert_less(
                 low,
                 high,
                 message='uniform not defined when `low` >= `high`.'))
     if is_init != tensor_util.is_ref(self.high):
         low = tf.convert_to_tensor(self.low) if low is None else low
         high = tf.convert_to_tensor(self.high) if high is None else high
         assertions.append(
             assert_util.assert_less(
                 low,
                 high,
                 message='uniform not defined when `low` >= `high`.'))
     return assertions
Example #4
  def _parameter_control_dependencies(self, is_init):
    if not self.validate_args:
      return []

    assertions = []
    if is_init != tensor_util.is_ref(self._temperature):
      msg1 = 'Argument `temperature` must be positive.'
      temperature = tf.convert_to_tensor(self._temperature)
      assertions.append(assert_util.assert_positive(temperature, message=msg1))

    if self._probs is not None:
      if is_init != tensor_util.is_ref(self._probs):
        probs = tf.convert_to_tensor(self._probs)
        one = tf.constant(1., probs.dtype)
        assertions.extend([
            assert_util.assert_non_negative(
                probs, message='Argument `probs` has components less than 0.'),
            assert_util.assert_less_equal(
                probs, one,
                message='Argument `probs` has components greater than 1.')
        ])

    return assertions
Example #5
    def testVariableParametersArePreserved(self, process_name, data):
        # Check that the process passes Variables through to the accessor
        # properties (without converting them to Tensor or anything like that).
        process = data.draw(
            stochastic_processes(process_name, enable_vars=True))
        self.evaluate([var.initializer for var in process.variables])

        for k, v in six.iteritems(process.parameters):
            if not tensor_util.is_ref(v):
                continue
            self.assertIs(
                getattr(process, k), v,
                'Parameter equivalence assertion failed for parameter `{}`'.
                format(k))
Example #6
  def _parameter_control_dependencies(self, is_init):
    if not self.validate_args:
      return []
    assertions = []
    for param_name, param in dict(
        concentration=self.concentration,
        mixing_concentration=self.mixing_concentration,
        mixing_rate=self.mixing_rate).items():

      if is_init != tensor_util.is_ref(param):
        assertions.append(assert_util.assert_positive(
            param,
            message='Argument `{}` must be positive.'.format(param_name)))
    return assertions
Example #7
 def _parameter_control_dependencies(self, is_init):
   if not self.validate_args:
     return []
   assertions = []
   for c in [
       self.concentration0_numerator,
       self.concentration1_numerator,
       self.concentration0_denominator,
       self.concentration1_denominator]:
     if is_init != tensor_util.is_ref(c):
       assertions.append(assert_util.assert_positive(
           c,
           message='`concentration` must be positive.'))
   return assertions
Example #8
  def _parameter_control_dependencies(self, is_init):
    assertions = []

    if is_init:
      try:
        self._batch_shape()
      except ValueError:
        raise ValueError(
            'Arguments `loc`, `scale`, and `rate` must have compatible shapes; '
            'loc.shape={}, scale.shape={}, rate.shape={}.'.format(
                self.loc.shape, self.scale.shape, self.rate.shape))
      # We don't bother checking the shapes in the dynamic case because
      # all member functions access both arguments anyway.

    if is_init != tensor_util.is_ref(self.scale):
      assertions.append(assert_util.assert_positive(
          self.scale, message='Argument `scale` must be positive.'))

    if is_init != tensor_util.is_ref(self.rate):
      assertions.append(assert_util.assert_positive(
          self.rate, message='Argument `rate` must be positive.'))

    return assertions
Example #9
 def _parameter_control_dependencies(self, is_init):
     if not self.validate_args:
         return []
     assertions = []
     if is_init != tensor_util.is_ref(self.mass):
         assertions.append(
             assert_util.assert_positive(
                 self.mass, message='Argument `mass` must be positive.'))
     if is_init != tensor_util.is_ref(self.width):
         assertions.append(
             assert_util.assert_positive(
                 self.width, message='Argument `width` must be positive.'))
     if is_init != tensor_util.is_ref(self.smin):
         assertions.append(
             assert_util.assert_non_negative(
                 self.smin,
                 message='Argument `smin` must be positive or zero.'))
     if is_init != tensor_util.is_ref(self.smax):
         assertions.append(
             assert_util.assert_greater(
                 self.smax,
                 self.smin,
                 message='Argument `smax` must be larger than `smin`.'))
     return assertions
Example #10
 def _parameter_control_dependencies(self, is_init):
     if not self.validate_args:
         return []
     assertions = []
     if is_init != tensor_util.is_ref(self._batch_shape_parameter):
         assertions.append(
             assert_util.assert_rank(
                 self._batch_shape_parameter,
                 1,
                 message='Batch shape must be a vector.'))
         assertions.append(
             assert_util.assert_non_negative(
                 self._batch_shape_parameter,
                 message='Shape elements must be >-1.'))
     return assertions
Example #11
  def _parameter_control_dependencies(self, is_init):

    assertions = super(Wishart, self)._parameter_control_dependencies(is_init)

    if not self.validate_args:
      assert not assertions
      return []

    if self._scale_full is None:
      if is_init != tensor_util.is_ref(self._scale_tril):
        shape = prefer_static.shape(self._scale_tril)
        assertions.extend(
            [assert_util.assert_positive(
                tf.linalg.diag_part(self._scale_tril),
                message='`scale_tril` must be positive definite.'),
             assert_util.assert_equal(
                 shape[-1],
                 shape[-2],
                 message='`scale_tril` must be square.')]
            )
    else:
      if is_init != tensor_util.is_ref(self._scale_full):
        assertions.append(distribution_util.assert_symmetric(self._scale_full))
    return assertions
Example #12
    def _parameter_control_dependencies(self, is_init):
        if not self.validate_args:
            return []
        assertions = []
        tailweight_is_ref = tensor_util.is_ref(self.tailweight)
        tailweight = tf.convert_to_tensor(self.tailweight)
        if (is_init != tailweight_is_ref
                and is_init != tensor_util.is_ref(self.skewness)):
            assertions.append(
                assert_util.assert_less(
                    tf.math.abs(self.skewness),
                    tailweight,
                    message='Expect `tailweight > |skewness|`'))
        if is_init != tensor_util.is_ref(self.scale):
            assertions.append(
                assert_util.assert_positive(
                    self.scale, message='Argument `scale` must be positive.'))
        if is_init != tailweight_is_ref:
            assertions.append(
                assert_util.assert_positive(
                    tailweight,
                    message='Argument `tailweight` must be positive.'))

        return assertions
Example #13
def assert_no_none_grad(bijector, method, wrt_vars, grads):
    for var, grad in zip(wrt_vars, grads):
        if 'log_det_jacobian' in method:
            if tensor_util.is_ref(var):
                # We check tensor_util.is_ref to account for xs/ys being in vars.
                var_name = var.name.rstrip('_0123456789:')
            else:
                var_name = '[arg]'
            to_check = bijector.bijector if is_invert(bijector) else bijector
            if var_name in NO_LDJ_GRADS_EXPECTED.get(
                    type(to_check).__name__, ()):
                continue
        if grad is None:
            raise AssertionError(
                'Missing `{}` -> {} grad for bijector {}'.format(
                    method, var, bijector))
Example #14
 def _parameter_control_dependencies(self, is_init):
     if not self.validate_args:
         return []
     assertions = []
     if self._probs is not None:
         if is_init != tensor_util.is_ref(self._probs):
             probs = tf.convert_to_tensor(self._probs)
             assertions.append(
                 assert_util.assert_positive(
                     probs, message='Argument `probs` must be positive.'))
             assertions.append(
                 assert_util.assert_less_equal(
                     probs,
                     dtype_util.as_numpy_dtype(self.dtype)(1.),
                     message=
                     'Argument `probs` must be less than or equal to 1.'))
     return assertions
Example #15
    def _parameter_control_dependencies(self, is_init):
        assertions = []
        sample_shape = None  # Memoize concretization.

        # Check valid shape.
        ndims_ = tensorshape_util.rank(self.sample_shape.shape)
        if is_init != (ndims_ is None):
            msg = 'Argument `sample_shape` must be either a scalar or a vector.'
            if ndims_ is not None:
                if ndims_ > 1:
                    raise ValueError(msg)
            elif self.validate_args:
                if sample_shape is None:
                    sample_shape = tf.convert_to_tensor(self.sample_shape)
                assertions.append(
                    assert_util.assert_less(tf.rank(sample_shape),
                                            2,
                                            message=msg))

        # Check valid dtype.
        if is_init:  # No xor check because `dtype` cannot change.
            dtype_ = self.sample_shape.dtype
            if dtype_ is None:
                if sample_shape is None:
                    sample_shape = tf.convert_to_tensor(self.sample_shape)
                dtype_ = sample_shape.dtype
            if dtype_util.base_dtype(dtype_) not in {tf.int32, tf.int64}:
                raise TypeError(
                    'Argument `sample_shape` must be integer type; '
                    'saw {}.'.format(dtype_util.name(dtype_)))

        # Check valid "value".
        if is_init != tensor_util.is_ref(self.sample_shape):
            sample_shape_ = tf.get_static_value(self.sample_shape)
            msg = 'Argument `sample_shape` must have non-negative values.'
            if sample_shape_ is not None:
                if np.any(np.array(sample_shape_) < 0):
                    raise ValueError('{} Saw: {}'.format(msg, sample_shape_))
            elif self.validate_args:
                if sample_shape is None:
                    sample_shape = tf.convert_to_tensor(self.sample_shape)
                assertions.append(
                    assert_util.assert_greater(sample_shape, -1, message=msg))

        return assertions
Example #16
    def _parameter_control_dependencies(self, is_init):
        """Checks the validity of the concentration parameter."""
        assertions = []

        # In init, we can always build shape and dtype checks because
        # we assume shape doesn't change for Variable backed args.
        if is_init:
            if not dtype_util.is_floating(self.concentration.dtype):
                raise TypeError('Argument `concentration` must be float type.')

            msg = 'Argument `concentration` must have rank at least 1.'
            ndims = tensorshape_util.rank(self.concentration.shape)
            if ndims is not None:
                if ndims < 1:
                    raise ValueError(msg)
            elif self.validate_args:
                assertions.append(
                    assert_util.assert_rank_at_least(self.concentration,
                                                     1,
                                                     message=msg))

            msg = 'Argument `concentration` must have `event_size` at least 2.'
            event_size = tf.compat.dimension_value(
                self.concentration.shape[-1])
            if event_size is not None:
                if event_size < 2:
                    raise ValueError(msg)
            elif self.validate_args:
                assertions.append(
                    assert_util.assert_less(1,
                                            tf.shape(self.concentration)[-1],
                                            message=msg))

        if not self.validate_args:
            assert not assertions  # Should never happen.
            return []

        if is_init != tensor_util.is_ref(self.concentration):
            assertions.append(
                assert_util.assert_positive(
                    self.concentration,
                    message='Argument `concentration` must be positive.'))

        return assertions
Example #17
    def _parameter_control_dependencies(self, is_init):
        assertions = super(WishartTriL,
                           self)._parameter_control_dependencies(is_init)

        if not self.validate_args:
            assert not assertions
            return []

        if is_init != tensor_util.is_ref(self._scale_tril):
            shape = ps.shape(self._scale_tril)
            assertions.extend([
                assert_util.assert_positive(
                    tf.linalg.diag_part(self._scale_tril),
                    message='`scale_tril` must be positive definite.'),
                assert_util.assert_equal(
                    shape[-1],
                    shape[-2],
                    message='`scale_tril` must be square.')
            ])
        return assertions
Example #18
def assert_no_none_grad(bijector, method, wrt_vars, grads):
  for var, grad in zip(wrt_vars, grads):
    expect_grad = var.dtype not in (tf.int32, tf.int64)
    if 'log_det_jacobian' in method:
      if tensor_util.is_ref(var):
        # We check tensor_util.is_ref to account for xs/ys being in vars.
        var_name = var.name.rstrip('_0123456789:').split('/')[-1]
      else:
        var_name = '[arg]'
      to_check = bijector.bijector if is_invert(bijector) else bijector
      to_check_method = INVERT_LDJ[method] if is_invert(bijector) else method
      if var_name == '[arg]' and bijector.is_constant_jacobian:
        expect_grad = False
      exempt_var_method = NO_LDJ_GRADS_EXPECTED.get(type(to_check).__name__, {})
      if to_check_method in exempt_var_method.get(var_name, ()):
        expect_grad = False

    if expect_grad != (grad is not None):
      raise AssertionError('{} `{}` -> {} grad for bijector {}'.format(
          'Missing' if expect_grad else 'Unexpected', method, var, bijector))
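
The gradient checks above derive a parameter name from `var.name`, which TensorFlow suffixes with an output index such as ':0'; the code strips that suffix (plus any numeric uniquifier), and Example #18 additionally drops a leading name-scope path with `split('/')`. A small sketch of that name handling, using only standard TensorFlow; the variable name is hypothetical.

import tensorflow as tf

v = tf.Variable(1., name='concentration')
print(v.name)                                        # typically 'concentration:0'
print(v.name.rstrip('_0123456789:'))                 # 'concentration'
print(v.name.rstrip('_0123456789:').split('/')[-1])  # also drops any 'scope/' prefix
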
Example #19
    def _parameter_control_dependencies(self, is_init):
        if not self.validate_args:
            return []
        assertions = []
        ok_to_check = lambda x: (  # pylint:disable=g-long-lambda
            x is not None) and (is_init != tensor_util.is_ref(x))

        bias_variance = self.bias_variance
        slope_variance = self.slope_variance

        if ok_to_check(self.exponent):
            exponent = tf.convert_to_tensor(self.exponent)
            assertions.append(
                assert_util.assert_positive(
                    exponent, message='`exponent` must be positive.'))
            from tensorflow_probability.python.internal import distribution_util  # pylint: disable=g-import-not-at-top
            assertions.append(
                distribution_util.assert_integer_form(
                    exponent, message='`exponent` must be an integer.'))
        if ok_to_check(self.bias_variance):
            bias_variance = tf.convert_to_tensor(self.bias_variance)
            assertions.append(
                assert_util.assert_non_negative(
                    bias_variance,
                    message='`bias_variance` must be non-negative.'))
        if ok_to_check(self.slope_variance):
            slope_variance = tf.convert_to_tensor(self.slope_variance)
            assertions.append(
                assert_util.assert_non_negative(
                    slope_variance,
                    message='`slope_variance` must be non-negative.'))

        if (ok_to_check(self.bias_variance)
                and ok_to_check(self.slope_variance)):
            assertions.append(
                assert_util.assert_positive(
                    tf.math.abs(slope_variance) + tf.math.abs(bias_variance),
                    message=('`slope_variance` and `bias_variance` '
                             'can not both be zero.')))

        return assertions
Example #20
def base_kernels(draw,
                 kernel_name=None,
                 batch_shape=None,
                 event_dim=None,
                 feature_dim=None,
                 feature_ndims=None,
                 enable_vars=False):
    if kernel_name is None:
        kernel_name = draw(hps.sampled_from(sorted(INSTANTIABLE_BASE_KERNELS)))
    if batch_shape is None:
        batch_shape = draw(tfp_hps.shapes())
    if event_dim is None:
        event_dim = draw(hps.integers(min_value=2, max_value=6))
    if feature_dim is None:
        feature_dim = draw(hps.integers(min_value=2, max_value=6))
    if feature_ndims is None:
        feature_ndims = draw(hps.integers(min_value=2, max_value=6))

    kernel_params = draw(
        broadcasting_params(kernel_name,
                            batch_shape,
                            event_dim=event_dim,
                            enable_vars=enable_vars))
    kernel_variable_names = [
        k for k in kernel_params if tensor_util.is_ref(kernel_params[k])
    ]
    hp.note('Forming kernel {} with constrained parameters {}'.format(
        kernel_name, kernel_params))
    ctor = getattr(tfpk, kernel_name)
    result_kernel = ctor(validate_args=True,
                         feature_ndims=feature_ndims,
                         **kernel_params)
    if batch_shape != result_kernel.batch_shape:
        msg = ('Kernel strategy generated a bad batch shape '
               'for {}, should have been {}.').format(result_kernel,
                                                      batch_shape)
        raise AssertionError(msg)
    return result_kernel, kernel_variable_names
Example #21
 def _parameter_control_dependencies(self, is_init):
   if not self.validate_args:
     # Avoid computing intermediates needed to construct the assertions.
     return []
   assertions = []
   if is_init != tensor_util.is_ref(self._batch_shape_unexpanded):
     implicit_dim_mask = prefer_static.equal(self._batch_shape_unexpanded, -1)
     assertions.append(assert_util.assert_rank(
         self._batch_shape_unexpanded, 1,
         message='New shape must be a vector.'))
     assertions.append(assert_util.assert_less_equal(
         tf.math.count_nonzero(implicit_dim_mask, dtype=tf.int32), 1,
         message='At most one dimension can be unknown.'))
     assertions.append(assert_util.assert_non_negative(
         self._batch_shape_unexpanded + 1,
         message='Shape elements must be >=-1.'))
     # Check that the old and new shapes are the same size.
     expanded_new_shape, original_size = self._calculate_new_shape()
     new_size = prefer_static.reduce_prod(expanded_new_shape)
     assertions.append(assert_util.assert_equal(
         new_size, tf.cast(original_size, new_size.dtype),
         message='Shape sizes do not match.'))
   return assertions
Example #22
    def _parameter_control_dependencies(self, is_init):
        assertions = []

        # In init, we can always build shape and dtype checks because
        # we assume shape doesn't change for Variable backed args.
        if is_init:

            if not dtype_util.is_floating(self.cutpoints.dtype):
                raise TypeError(
                    'Argument `cutpoints` must have floating type.')

            if not dtype_util.is_floating(self.loc.dtype):
                raise TypeError('Argument `loc` must have floating type.')

            cutpoint_dims = tensorshape_util.rank(self.cutpoints.shape)
            msg = 'Argument `cutpoints` must have rank at least 1.'
            if cutpoint_dims is not None:
                if cutpoint_dims < 1:
                    raise ValueError(msg)
            elif self.validate_args:
                cutpoints = tf.convert_to_tensor(self.cutpoints)
                assertions.append(
                    assert_util.assert_rank_at_least(cutpoints, 1,
                                                     message=msg))

        if not self.validate_args:
            return []

        if is_init != tensor_util.is_ref(self.cutpoints):
            cutpoints = tf.convert_to_tensor(self.cutpoints)
            assertions.append(
                distribution_util.assert_nondecreasing(
                    cutpoints,
                    message='Argument `cutpoints` must be non-decreasing.'))

        return assertions
Example #23
  def _parameter_control_dependencies(self, is_init):
    assertions = []

    if is_init != tensor_util.is_ref(self.permutation):
      if not dtype_util.is_integer(self.permutation.dtype):
        raise TypeError('permutation.dtype ({}) should be `int`-like.'.format(
            dtype_util.name(self.permutation.dtype)))

      p = tf.get_static_value(self.permutation)
      if p is not None:
        if set(p) != set(np.arange(p.size)):
          raise ValueError('Permutation over `d` must contain exactly one of '
                           'each of `{0, 1, ..., d}`.')

      if self.validate_args:
        p = tf.sort(self.permutation, axis=-1)
        assertions.append(
            assert_util.assert_equal(
                p,
                tf.range(tf.shape(p)[-1]),
                message=('Permutation over `d` must contain exactly one of '
                         'each of `{0, 1, ..., d}`.')))

    return assertions
Example #24
  def _parameter_control_dependencies(self, is_init):
    if not self.validate_args:
      return []
    low = tf.convert_to_tensor(self.low)
    high = tf.convert_to_tensor(self.high)
    peak = tf.convert_to_tensor(self.peak)
    assertions = []
    if (is_init != tensor_util.is_ref(self.low) and
        is_init != tensor_util.is_ref(self.high)):
      assertions.append(assert_util.assert_less(
          low, high, message='triangular not defined when low >= high.'))
    if (is_init != tensor_util.is_ref(self.low) and
        is_init != tensor_util.is_ref(self.peak)):
      assertions.append(
          assert_util.assert_less_equal(
              low, peak, message='triangular not defined when low > peak.'))
    if (is_init != tensor_util.is_ref(self.high) and
        is_init != tensor_util.is_ref(self.peak)):
      assertions.append(
          assert_util.assert_less_equal(
              peak, high, message='triangular not defined when peak > high.'))

    return assertions
Example #25
 def _parameter_control_dependencies(self, is_init):
     if not self.validate_args:
         return []
     assertions = []
     low = None
     high = None
     if is_init != tensor_util.is_ref(self.low):
         low = tf.convert_to_tensor(self.low)
         assertions.append(
             assert_util.assert_finite(low, message='`low` is not finite'))
     if is_init != tensor_util.is_ref(self.high):
         high = tf.convert_to_tensor(self.high)
         assertions.append(
             assert_util.assert_finite(high,
                                       message='`high` is not finite'))
     if is_init != tensor_util.is_ref(self.loc):
         assertions.append(
             assert_util.assert_finite(self.loc,
                                       message='`loc` is not finite'))
     if is_init != tensor_util.is_ref(self.scale):
         scale = tf.convert_to_tensor(self.scale)
         assertions.extend([
             assert_util.assert_positive(
                 scale, message='`scale` must be positive'),
             assert_util.assert_finite(scale,
                                       message='`scale` is not finite'),
         ])
     if (is_init != tensor_util.is_ref(self.low)
             or is_init != tensor_util.is_ref(self.high)):
         low = tf.convert_to_tensor(self.low) if low is None else low
         high = tf.convert_to_tensor(self.high) if high is None else high
         assertions.append(
             assert_util.assert_greater(
                 high,
                 low,
                 message='TruncatedCauchy not defined when `low >= high`.'))
     return assertions
Example #26
  def testDistribution(self, dist_name, data):
    seed = test_util.test_seed()
    # Explicitly draw event_dim here to avoid relying on _params_event_ndims
    # later, so this test can support distributions that do not implement the
    # slicing protocol.
    event_dim = data.draw(hps.integers(min_value=2, max_value=6))
    dist = data.draw(dhps.distributions(
        dist_name=dist_name, event_dim=event_dim, enable_vars=True))
    batch_shape = dist.batch_shape
    batch_shape2 = data.draw(tfp_hps.broadcast_compatible_shape(batch_shape))
    dist2 = data.draw(
        dhps.distributions(
            dist_name=dist_name,
            batch_shape=batch_shape2,
            event_dim=event_dim,
            enable_vars=True))
    self.evaluate([var.initializer for var in dist.variables])

    # Check that the distribution passes Variables through to the accessor
    # properties (without converting them to Tensor or anything like that).
    for k, v in six.iteritems(dist.parameters):
      if not tensor_util.is_ref(v):
        continue
      self.assertIs(getattr(dist, k), v)

    # Check that standard statistics do not read distribution parameters more
    # than twice (once in the stat itself and up to once in any validation
    # assertions).
    max_permissible = 2 + extra_tensor_conversions_allowed(dist)
    for stat in sorted(data.draw(
        hps.sets(
            hps.one_of(
                map(hps.just, [
                    'covariance', 'entropy', 'mean', 'mode', 'stddev',
                    'variance'
                ])),
            min_size=3,
            max_size=3))):
      hp.note('Testing excessive var usage in {}.{}'.format(dist_name, stat))
      try:
        with tfp_hps.assert_no_excessive_var_usage(
            'statistic `{}` of `{}`'.format(stat, dist),
            max_permissible=max_permissible):
          getattr(dist, stat)()

      except NotImplementedError:
        pass

    # Check that `sample` doesn't read distribution parameters more than twice,
    # and that it produces non-None gradients (if the distribution is fully
    # reparameterized).
    with tf.GradientTape() as tape:
      # TDs do bijector assertions twice (once by distribution.sample, and once
      # by bijector.forward).
      max_permissible = 2 + extra_tensor_conversions_allowed(dist)
      with tfp_hps.assert_no_excessive_var_usage(
          'method `sample` of `{}`'.format(dist),
          max_permissible=max_permissible):
        sample = dist.sample(seed=seed)
    if dist.reparameterization_type == tfd.FULLY_REPARAMETERIZED:
      grads = tape.gradient(sample, dist.variables)
      for grad, var in zip(grads, dist.variables):
        var_name = var.name.rstrip('_0123456789:')
        if var_name in NO_SAMPLE_PARAM_GRADS.get(dist_name, ()):
          continue
        if grad is None:
          raise AssertionError(
              'Missing sample -> {} grad for distribution {}'.format(
                  var_name, dist_name))

    # Turn off validations, since TODO(b/129271256) log_prob can choke on dist's
    # own samples.  Also, to relax conversion counts for KL (might do >2 w/
    # validate_args).
    dist = dist.copy(validate_args=False)
    dist2 = dist2.copy(validate_args=False)

    # Test that KL divergence reads distribution parameters at most once, and
    # that it produces non-None gradients.
    try:
      for d1, d2 in (dist, dist2), (dist2, dist):
        with tf.GradientTape() as tape:
          with tfp_hps.assert_no_excessive_var_usage(
              '`kl_divergence` of (`{}` (vars {}), `{}` (vars {}))'.format(
                  d1, d1.variables, d2, d2.variables),
              max_permissible=1):  # No validation => 1 convert per var.
            kl = d1.kl_divergence(d2)
        wrt_vars = list(d1.variables) + list(d2.variables)
        grads = tape.gradient(kl, wrt_vars)
        for grad, var in zip(grads, wrt_vars):
          if grad is None and dist_name not in NO_KL_PARAM_GRADS:
            raise AssertionError('Missing KL({} || {}) -> {} grad:\n'
                                 '{} vars: {}\n{} vars: {}'.format(
                                     d1, d2, var, d1, d1.variables, d2,
                                     d2.variables))
    except NotImplementedError:
      pass

    # Test that log_prob produces non-None gradients, except for distributions
    # on the NO_LOG_PROB_PARAM_GRADS blacklist.
    if dist_name not in NO_LOG_PROB_PARAM_GRADS:
      with tf.GradientTape() as tape:
        lp = dist.log_prob(tf.stop_gradient(sample))
      grads = tape.gradient(lp, dist.variables)
      for grad, var in zip(grads, dist.variables):
        if grad is None:
          raise AssertionError(
              'Missing log_prob -> {} grad for distribution {}'.format(
                  var, dist_name))

    # Test that all forms of probability evaluation avoid reading distribution
    # parameters more than once.
    for evaluative in sorted(data.draw(
        hps.sets(
            hps.one_of(
                map(hps.just, [
                    'log_prob', 'prob', 'log_cdf', 'cdf',
                    'log_survival_function', 'survival_function'
                ])),
            min_size=3,
            max_size=3))):
      hp.note('Testing excessive var usage in {}.{}'.format(
          dist_name, evaluative))
      try:
        # No validation => 1 convert. But for TD we allow 2:
        # dist.log_prob(bijector.inverse(samp)) + bijector.ildj(samp)
        max_permissible = 2 + extra_tensor_conversions_allowed(dist)
        with tfp_hps.assert_no_excessive_var_usage(
            'evaluative `{}` of `{}`'.format(evaluative, dist),
            max_permissible=max_permissible):
          getattr(dist, evaluative)(sample)
      except NotImplementedError:
        pass
Example #27
    def _maybe_validate_shape_override(self, override_shape, base_is_scalar_fn,
                                       static_base_shape, is_init):
        """Helper which ensures override batch/event_shape are valid."""

        assertions = []
        concretized_shape = None

        # Check valid dtype
        if is_init:  # No xor check because `dtype` cannot change.
            dtype_ = override_shape.dtype
            if dtype_ is None:
                if concretized_shape is None:
                    concretized_shape = tf.convert_to_tensor(override_shape)
                dtype_ = concretized_shape.dtype
            if dtype_util.base_dtype(dtype_) not in {tf.int32, tf.int64}:
                raise TypeError('Shape override must be integer type; '
                                'saw {}.'.format(dtype_util.name(dtype_)))

        # Check non-negative elements
        if is_init != tensor_util.is_ref(override_shape):
            override_shape_ = tf.get_static_value(override_shape)
            msg = 'Shape override must have non-negative elements.'
            if override_shape_ is not None:
                if np.any(np.array(override_shape_) < 0):
                    raise ValueError('{} Saw: {}'.format(msg, override_shape_))
            elif self.validate_args:
                if concretized_shape is None:
                    concretized_shape = tf.convert_to_tensor(override_shape)
                assertions.append(
                    assert_util.assert_non_negative(concretized_shape,
                                                    message=msg))

        # Check valid shape
        override_ndims_ = tensorshape_util.rank(override_shape.shape)
        if is_init != (override_ndims_ is None):
            msg = 'Shape override must be a vector.'
            if override_ndims_ is not None:
                if override_ndims_ != 1:
                    raise ValueError(msg)
            elif self.validate_args:
                if concretized_shape is None:
                    concretized_shape = tf.convert_to_tensor(override_shape)
                override_rank = tf.rank(concretized_shape)
                assertions.append(
                    assert_util.assert_equal(override_rank, 1, message=msg))

        static_base_rank = tensorshape_util.rank(static_base_shape)

        # Determine if the override shape is `[]` (static_override_dims == [0]),
        # in which case the base distribution may be nonscalar.
        static_override_dims = tensorshape_util.dims(override_shape.shape)

        if is_init != (static_base_rank is None
                       or static_override_dims is None):
            msg = 'Base distribution is not scalar.'
            if static_base_rank is not None and static_override_dims is not None:
                if static_base_rank != 0 and static_override_dims != [0]:
                    raise ValueError(msg)
            elif self.validate_args:
                if concretized_shape is None:
                    concretized_shape = tf.convert_to_tensor(override_shape)
                override_is_empty = tf.logical_not(
                    self._has_nonzero_rank(concretized_shape))
                assertions.append(
                    assert_util.assert_equal(tf.logical_or(
                        base_is_scalar_fn(), override_is_empty),
                                             True,
                                             message=msg))
        return assertions
Example #28
    def _parameter_control_dependencies(self, is_init):
        assertions = []

        logits = self._logits
        probs = self._probs
        param, name = (probs, 'probs') if logits is None else (logits,
                                                               'logits')

        # In init, we can always build shape and dtype checks because
        # we assume shape doesn't change for Variable backed args.
        if is_init:
            if not dtype_util.is_floating(param.dtype):
                raise TypeError(
                    'Argument `{}` must have floating type.'.format(name))

            msg = 'Argument `{}` must have rank at least 1.'.format(name)
            shape_static = tensorshape_util.dims(param.shape)
            if shape_static is not None:
                if len(shape_static) < 1:
                    raise ValueError(msg)
            elif self.validate_args:
                param = tf.convert_to_tensor(param)
                assertions.append(
                    assert_util.assert_rank_at_least(param, 1, message=msg))
                with tf.control_dependencies(assertions):
                    param = tf.identity(param)

            msg1 = 'Argument `{}` must have final dimension >= 1.'.format(name)
            msg2 = 'Argument `{}` must have final dimension <= {}.'.format(
                name, dtype_util.max(tf.int32))
            event_size = shape_static[-1] if shape_static is not None else None
            if event_size is not None:
                if event_size < 1:
                    raise ValueError(msg1)
                if event_size > dtype_util.max(tf.int32):
                    raise ValueError(msg2)
            elif self.validate_args:
                param = tf.convert_to_tensor(param)
                assertions.append(
                    assert_util.assert_greater_equal(tf.shape(param)[-1],
                                                     1,
                                                     message=msg1))
                # NOTE: For now, we leave out a runtime assertion that
                # `tf.shape(param)[-1] <= tf.int32.max`.  An earlier `tf.shape` call
                # will fail before we get to this point.

        if not self.validate_args:
            assert not assertions  # Should never happen.
            return []

        if probs is not None:
            probs = param  # reuse tensor conversion from above
            if is_init != tensor_util.is_ref(probs):
                probs = tf.convert_to_tensor(probs)
                one = tf.ones([], dtype=probs.dtype)
                assertions.extend([
                    assert_util.assert_non_negative(probs),
                    assert_util.assert_less_equal(probs, one),
                    assert_util.assert_near(
                        tf.reduce_sum(probs, axis=-1),
                        one,
                        message='Argument `probs` must sum to 1.'),
                ])

        return assertions
Example #29
 def _parameter_control_dependencies(self, is_init):
     if not self.validate_args:
         return []
     if is_init != any(tensor_util.is_ref(v) for v in self.scale.variables):
         return [self.scale.assert_non_singular()]
     return []
Example #30
def as_composite(obj):
    """Returns a `CompositeTensor` equivalent to the given object.

  Note that the returned object will have any `Variable`,
  `tfp.util.DeferredTensor`, or `tfp.util.TransformedVariable` references it
  closes over converted to tensors at the time this function is called. The
  type of the returned object will be a subclass of both `CompositeTensor` and
  `type(obj)`.  For this reason, one should be careful about using
  `as_composite()`, especially for `tf.Module` objects.

  For example, when the composite tensor is created even as part of a
  `tf.Module`, it "fixes" the values of the `DeferredTensor` and `tf.Variable`
  objects it uses:

  ```python
  class M(tf.Module):
    def __init__(self):
      self._v = tf.Variable(1.)
      self._d = tfp.distributions.Normal(
        tfp.util.DeferredTensor(self._v, lambda v: v + 1), 10)
      self._dct = tfp.experimental.as_composite(self._d)

    @tf.function
    def mean(self):
      return self._dct.mean()

  m = M()
  m.mean()
  >>> <tf.Tensor: numpy=2.0>
  m._v.assign(2.)  # Doesn't update the CompositeTensor distribution.
  m.mean()
  >>> <tf.Tensor: numpy=2.0>
  ```

  If, however, the creation of the composite is deferred to a method
  call, then the Variable and DeferredTensor will be properly captured
  and respected by the Module and its `SavedModel` (if it is serialized).

  ```python
  class M(tf.Module):
    def __init__(self):
      self._v = tf.Variable(1.)
      self._d = tfp.distributions.Normal(
        tfp.util.DeferredTensor(self._v, lambda v: v + 1), 10)

    @tf.function
    def d(self):
      return tfp.experimental.as_composite(self._d)

  m = M()
  m.d().mean()
  >>> <tf.Tensor: numpy=2.0>
  m._v.assign(2.)
  m.d().mean()
  >>> <tf.Tensor: numpy=3.0>
  ```

  Note: This method is best-effort and based on a heuristic for what the
  tensor parameters are and what the non-tensor parameters are. Things might be
  broken, especially for meta-distributions like `TransformedDistribution` or
  `Independent`. (We try to raise NotImplementedError in such cases.) If you'd
  benefit from better coverage, please file an issue on github or send an email
  to `[email protected]`.

  Args:
    obj: A `tfp.distributions.Distribution`.

  Returns:
    obj: A `tfp.distributions.Distribution` that extends `CompositeTensor`.
  """
    if isinstance(obj, CompositeTensor):
        return obj
    cls = _make_convertible(type(obj))
    kwargs = dict(obj.parameters)

    def mk_err_msg(suffix=''):
        return (
            'Unable to make a CompositeTensor for "{}" of type `{}`. Email '
            '`[email protected]` or file an issue on github if you '
            'would benefit from this working. {}'.format(
                obj, type(obj), suffix))

    try:
        composite_tensor_params = obj._composite_tensor_params  # pylint: disable=protected-access
    except (AttributeError, NotImplementedError):
        composite_tensor_params = ()
    for k in composite_tensor_params:
        # Use dtype inference from ctor.
        if k in kwargs and kwargs[k] is not None:
            v = getattr(obj, k, kwargs[k])
            try:
                kwargs[k] = tf.convert_to_tensor(v, name=k)
            except (ValueError, TypeError) as e:
                kwargs[k] = v
    for k, v in kwargs.items():

        def composite_helper(v):
            # If we have a parameters attribute, then we may be able to convert to
            # a composite tensor by guessing which of the parameters are tensors.  In
            # essence, we duck-type based on this attribute.
            if hasattr(v, 'parameters'):
                return as_composite(v)
            return v

        kwargs[k] = tf.nest.map_structure(composite_helper, v)
        # Unfortunately, tensor_util.is_ref(v) returns true for a
        # tf.linalg.LinearOperator even though that is not ideal behavior.
        if tensor_util.is_ref(v) and not isinstance(v,
                                                    tf.linalg.LinearOperator):
            try:
                kwargs[k] = tf.convert_to_tensor(v, name=k)
            except TypeError as e:
                raise NotImplementedError(
                    mk_err_msg(
                        '(Unable to convert dependent entry \'{}\' of object '
                        '\'{}\': {})'.format(k, obj, str(e))))
    result = cls(**kwargs)
    struct_coder = nested_structure_coder.StructureCoder()
    try:
        struct_coder.encode_structure(result._type_spec)  # pylint: disable=protected-access
    except nested_structure_coder.NotEncodableError as e:
        raise NotImplementedError(
            mk_err_msg('(Unable to serialize: {})'.format(str(e))))
    return result
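
A short usage sketch for `as_composite`, assuming the `tfp.experimental.as_composite` entry point referenced in the docstring above: because the result is a `CompositeTensor`, it can be passed across a `tf.function` boundary as a structured argument instead of being captured at trace time.

import tensorflow as tf
import tensorflow_probability as tfp

@tf.function
def entropy_of(dist):
  # `dist` is traced as a structured (CompositeTensor) input.
  return dist.entropy()

d = tfp.distributions.Normal(loc=0., scale=2.)
dct = tfp.experimental.as_composite(d)
print(entropy_of(dct))
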