Code Example #1
    def _parameter_control_dependencies(self, is_init):
        assertions = []

        # In init, we can always build shape and dtype checks because
        # we assume shape doesn't change for Variable backed args.
        if is_init and self._is_vector:
            msg = "Argument `loc` must be at least rank 1."
            if tensorshape_util.rank(self.loc.shape) is not None:
                if tensorshape_util.rank(self.loc.shape) < 1:
                    raise ValueError(msg)
            elif self.validate_args:
                assertions.append(
                    assert_util.assert_rank_at_least(self.loc, 1, message=msg))

        if not self.validate_args:
            assert not assertions  # Should never happen.
            return []

        if is_init != tensor_util.is_mutable(self.atol):
            assertions.append(
                assert_util.assert_non_negative(
                    self.atol, message="Argument `atol` must be non-negative."))
        if is_init != tensor_util.is_mutable(self.rtol):
            assertions.append(
                assert_util.assert_non_negative(
                    self.rtol, message="Argument `rtol` must be non-negative."))
        return assertions
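A note on the recurring `is_init != tensor_util.is_mutable(...)` guard seen above and throughout these examples: a Tensor-backed parameter cannot change after construction, so it only needs checking once, at `__init__` time (`is_init=True`); a Variable-backed parameter can be reassigned later, so it must be re-checked at call time (`is_init=False`). The inequality therefore builds each assertion exactly once, in the phase where it can still catch a violation. A minimal sketch of the idiom (illustrative only, not library code):

    def should_build_assertion(is_init, param):
        # Tensor-backed param: is_mutable(param) is False, so the
        # assertion is built only during __init__ (is_init=True).
        # Variable-backed param: is_mutable(param) is True, so the
        # assertion is built only at call time (is_init=False).
        return is_init != tensor_util.is_mutable(param)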
Code Example #2
File: gamma.py Project: slowmoyang/probability
 def _parameter_control_dependencies(self, is_init):
     if not self.validate_args:
         return []
     assertions = []
     if is_init != tensor_util.is_mutable(self.concentration):
         assertions.append(
             assert_util.assert_positive(
                 self.concentration,
                 message="Argument `concentration` must be positive."))
     if is_init != tensor_util.is_mutable(self.rate):
         assertions.append(
             assert_util.assert_positive(
                 self.rate, message="Argument `rate` must be positive."))
     return assertions
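For context, the assertion ops returned by `_parameter_control_dependencies` are intended to gate downstream computation. A hedged sketch of how a caller might consume them (the name `dist` and the `tf.identity` gating are illustrative; TFP's actual plumbing lives in the distribution base class):

    # Run the parameter checks before reading a Variable-backed parameter.
    assertions = dist._parameter_control_dependencies(is_init=False)
    with tf.control_dependencies(assertions):
        rate = tf.identity(dist.rate)  # Only evaluated if all checks pass.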
Code Example #3
def maybe_assert_categorical_param_correctness(is_init, validate_args, probs,
                                               logits):
    """Return assertions for `Categorical`-type distributions."""
    assertions = []

    # In init, we can always build shape and dtype checks because
    # we assume shape doesn't change for Variable backed args.
    if is_init:
        x, name = (probs, 'probs') if logits is None else (logits, 'logits')

        if not dtype_util.is_floating(x.dtype):
            raise TypeError(
                'Argument `{}` must have floating type.'.format(name))

        msg = 'Argument `{}` must have rank at least 1.'.format(name)
        ndims = tensorshape_util.rank(x.shape)
        if ndims is not None:
            if ndims < 1:
                raise ValueError(msg)
        elif validate_args:
            x = tf.convert_to_tensor(x)
            probs = x if logits is None else None  # Retain tensor conversion.
            logits = x if probs is None else None
            assertions.append(
                assert_util.assert_rank_at_least(x, 1, message=msg))

    if not validate_args:
        assert not assertions  # Should never happen.
        return []

    if logits is not None:
        if is_init != tensor_util.is_mutable(logits):
            logits = tf.convert_to_tensor(logits)
            assertions.extend(
                distribution_util.assert_categorical_event_shape(logits))

    if probs is not None:
        if is_init != tensor_util.is_mutable(probs):
            probs = tf.convert_to_tensor(probs)
            assertions.extend([
                assert_util.assert_non_negative(probs),
                assert_util.assert_near(
                    tf.reduce_sum(probs, axis=-1),
                    np.array(1, dtype=dtype_util.as_numpy_dtype(probs.dtype)),
                    message='Argument `probs` must sum to 1.')
            ])
            assertions.extend(
                distribution_util.assert_categorical_event_shape(probs))

    return assertions
Code Example #4
File: bernoulli.py Project: cschwar1/probability
def maybe_assert_bernoulli_param_correctness(
    is_init, validate_args, probs, logits):
  """Return assertions for `Categorical`-type distributions."""
  if is_init:
    x, name = (probs, 'probs') if logits is None else (logits, 'logits')
    if not dtype_util.is_floating(x.dtype):
      raise TypeError(
          'Argument `{}` must have floating type.'.format(name))

  if not validate_args:
    return []

  assertions = []

  if probs is not None:
    if is_init != tensor_util.is_mutable(probs):
      probs = tf.convert_to_tensor(probs)
      one = tf.constant(1., probs.dtype)
      assertions += [
          assert_util.assert_non_negative(
              probs, message='probs has components less than 0.'),
          assert_util.assert_less_equal(
              probs, one, message='probs has components greater than 1.')
      ]

  return assertions
Code Example #5
 def _parameter_control_dependencies(self, is_init):
   if not self.validate_args:
     return []
   assertions = []
   if is_init != tensor_util.is_mutable(self._df):
     assertions.append(assert_util.assert_positive(
         self._df, message='Argument `df` must be positive.'))
   return assertions
Code Example #6
 def _parameter_control_dependencies(self, is_init):
   assertions = categorical_lib.maybe_assert_categorical_param_correctness(
       is_init, self.validate_args, self._probs, self._logits)
   if not self.validate_args:
     return assertions
   if is_init != tensor_util.is_mutable(self.total_count):
     assertions.extend(distribution_util.assert_nonnegative_integer_form(
         self.total_count))
   return assertions
Code Example #7
 def test_various_types(self):
     self.assertFalse(tensor_util.is_mutable(0.))
     self.assertFalse(tensor_util.is_mutable(FakeModule(0.)))
     self.assertFalse(tensor_util.is_mutable([tf.Variable(0.)]))  # Note!
     self.assertFalse(tensor_util.is_mutable(np.array(0., np.float32)))
     self.assertFalse(tensor_util.is_mutable(tf.constant(0.)))
     self.assertTrue(tensor_util.is_mutable(FakeModule(tf.Variable(0.))))
     self.assertTrue(tensor_util.is_mutable(tf.Variable(0.)))
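The test above pins down the semantics of `tensor_util.is_mutable`: a `tf.Variable` itself, or a module that owns a `tf.Variable`, is mutable; Python scalars, numpy arrays, constant Tensors, and even a plain list of Variables are not. A minimal sketch consistent with those cases (illustrative only, not TFP's actual implementation; `FakeModule` is assumed to subclass `tf.Module`):

    import tensorflow as tf

    def is_mutable_sketch(x):
        # A tf.Variable is mutable by definition.
        if isinstance(x, tf.Variable):
            return True
        # A module counts as mutable only if it actually owns Variables:
        # FakeModule(0.) is treated as immutable, FakeModule(tf.Variable(0.))
        # as mutable.
        if isinstance(x, tf.Module):
            return bool(x.variables)
        # Everything else (floats, numpy arrays, tf.constant, and even
        # lists of Variables) is treated as immutable.
        return False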
Code Example #8
File: beta.py Project: slowmoyang/probability
 def _parameter_control_dependencies(self, is_init):
   if not self.validate_args:
     return []
   assertions = []
   for concentration in [self.concentration0, self.concentration1]:
     if is_init != tensor_util.is_mutable(concentration):
       assertions.append(assert_util.assert_positive(
           concentration,
           message="Concentration parameter must be positive."))
   return assertions
Code Example #9
 def _parameter_control_dependencies(self, is_init):
     if not self.validate_args:
         return []
     assertions = []
     if is_init != tensor_util.is_mutable(self.df):
         assertions.append(
             assert_util.assert_greater(
                 self.df,
                 dtype_util.as_numpy_dtype(self.df.dtype)(2.),
                 message='`df` must be greater than 2.'))
     return assertions
Code Example #10
File: logistic.py Project: slowmoyang/probability
 def _parameter_control_dependencies(self, is_init):
     if is_init:
         dtype_util.assert_same_float_dtype([self.loc, self.scale])
     if not self.validate_args:
         return []
     assertions = []
     if is_init != tensor_util.is_mutable(self._scale):
         assertions.append(
             assert_util.assert_positive(
                 self._scale, message='Argument `scale` must be positive.'))
     return assertions
Code Example #11
File: affine_scalar.py Project: ywangV/probability
 def _parameter_control_dependencies(self, is_init):
     if not self.validate_args:
         return []
     assertions = []
     if is_init != tensor_util.is_mutable(self.scale):
         assertions.append(
             assert_util.assert_none_equal(
                 self.scale,
                 tf.zeros([], dtype=self._scale.dtype),
                 message="Argument `scale` must be non-zero."))
     return assertions
Code Example #12
 def _parameter_control_dependencies(self, is_init):
   if is_init and not dtype_util.is_integer(self.axis.dtype):
     raise TypeError('Argument `axis` is not an `int` type.')
   if not self.validate_args:
     return []
   assertions = []
   if is_init != tensor_util.is_mutable(self.axis):
     assertions.append(assert_util.assert_negative(
         self.axis,
         message='Argument `axis` must be negative.'))
   return assertions
Code Example #13
def assert_no_none_grad(bijector, method, wrt_vars, grads):
    for var, grad in zip(wrt_vars, grads):
        if 'log_det_jacobian' in method:
            if tensor_util.is_mutable(var):
                # We check tensor_util.is_mutable to account for xs/ys being in vars.
                var_name = var.name.rstrip('_0123456789:')
            else:
                var_name = '[arg]'
            to_check = bijector.bijector if is_invert(bijector) else bijector
            if var_name in NO_LDJ_GRADS_EXPECTED.get(
                    type(to_check).__name__, ()):
                continue
        if grad is None:
            raise AssertionError(
                'Missing `{}` -> {} grad for bijector {}'.format(
                    method, var, bijector))
Code Example #14
    def _parameter_control_dependencies(self, is_init):
        """Checks the validity of the concentration parameter."""
        assertions = []

        # In init, we can always build shape and dtype checks because
        # we assume shape doesn't change for Variable backed args.
        if is_init:
            if not dtype_util.is_floating(self.concentration.dtype):
                raise TypeError('Argument `concentration` must be float type.')

            msg = 'Argument `concentration` must have rank at least 1.'
            ndims = tensorshape_util.rank(self.concentration.shape)
            if ndims is not None:
                if ndims < 1:
                    raise ValueError(msg)
            elif self.validate_args:
                assertions.append(
                    assert_util.assert_rank_at_least(self.concentration,
                                                     1,
                                                     message=msg))

            msg = 'Argument `concentration` must have `event_size` at least 2.'
            event_size = tf.compat.dimension_value(
                self.concentration.shape[-1])
            if event_size is not None:
                if event_size < 2:
                    raise ValueError(msg)
            elif self.validate_args:
                assertions.append(
                    assert_util.assert_less(1,
                                            tf.shape(self.concentration)[-1],
                                            message=msg))

        if not self.validate_args:
            assert not assertions  # Should never happen.
            return []

        if is_init != tensor_util.is_mutable(self.concentration):
            assertions.append(
                assert_util.assert_positive(
                    self.concentration,
                    message='Argument `concentration` must be positive.'))

        return assertions
Code Example #15
    def testDistribution(self, dist_name, data):
        if tf.executing_eagerly() != (FLAGS.tf_mode == 'eager'):
            return
        tf.compat.v1.set_random_seed(
            data.draw(
                hpnp.arrays(dtype=np.int64,
                            shape=[]).filter(lambda x: x != 0)))
        dist, batch_shape = data.draw(
            distributions(dist_name=dist_name, enable_vars=True))
        del batch_shape
        logging.info(
            'distribution: %s; parameters used: %s', dist,
            [k for k, v in six.iteritems(dist.parameters) if v is not None])
        self.evaluate([var.initializer for var in dist.variables])
        for k, v in six.iteritems(dist.parameters):
            if not tensor_util.is_mutable(v):
                continue
            try:
                self.assertIs(getattr(dist, k), v)
            except AssertionError as e:
                raise AssertionError(
                    'No attr found for parameter {} of distribution {}: \n{}'.
                    format(k, dist_name, e))
            stat = data.draw(
                hps.one_of(
                    map(hps.just,
                        ['mean', 'mode', 'variance', 'covariance', 'entropy'
                         ])))
            try:
                VAR_USAGES.clear()
                getattr(dist, stat)()
                var_nusages = {
                    var: len(usages)
                    for var, usages in VAR_USAGES.items()
                }
                max_permissible = 2  # TODO(jvdillon): Reduce this to 1.
                if any(
                        len(usages) > max_permissible
                        for usages in VAR_USAGES.values()):
                    for var, usages in six.iteritems(VAR_USAGES):
                        if len(usages) > max_permissible:
                            print(
                                'While executing statistic `{}` of `{}`, detected {} '
                                'Tensor conversions for `{}`:'.format(
                                    stat, dist, len(usages), var))
                            for i, usage in enumerate(usages):
                                print('Conversion {} of {}:\n{}'.format(
                                    i + 1, len(usages), ''.join(usage)))
                    raise AssertionError(
                        'Excessive tensor conversions detected for {} {}: {}'.
                        format(dist_name, stat, var_nusages))
            except NotImplementedError:
                pass

        if dist.reparameterization_type == tfd.FULLY_REPARAMETERIZED:
            with tf.GradientTape() as tape:
                samp = dist.sample()
            grads = tape.gradient(samp, dist.variables)
            for grad, var in zip(grads, dist.variables):
                if grad is None:
                    raise AssertionError(
                        'Missing sample -> {} grad for distribution {}'.format(
                            var, dist_name))

        if dist_name not in NO_LOG_PROB_PARAM_GRADS:
            # Turn off validations, since log_prob can choke on dist's own samples.
            dist = dist.copy(validate_args=False)
            with tf.GradientTape() as tape:
                lp = dist.log_prob(tf.stop_gradient(dist.sample()))
            grads = tape.gradient(lp, dist.variables)
            for grad, var in zip(grads, dist.variables):
                if grad is None:
                    raise AssertionError(
                        'Missing log_prob -> {} grad for distribution {}'.
                        format(var, dist_name))
Code Example #16
  def testDistribution(self, dist_name, data):
    if tf.executing_eagerly() != (FLAGS.tf_mode == 'eager'):
      return
    tf1.set_random_seed(
        data.draw(
            hpnp.arrays(dtype=np.int64, shape=[]).filter(lambda x: x != 0)))
    dist, batch_shape = data.draw(
        distributions(dist_name=dist_name, enable_vars=True))
    batch_shape2 = data.draw(tfp_hps.broadcast_compatible_shape(batch_shape))
    dist2, _ = data.draw(
        distributions(
            dist_name=dist_name,
            batch_shape=batch_shape2,
            event_dim=get_event_dim(dist),
            enable_vars=True))
    del batch_shape
    logging.info(
        'distribution: %s; parameters used: %s', dist,
        [k for k, v in six.iteritems(dist.parameters) if v is not None])
    self.evaluate([var.initializer for var in dist.variables])
    for k, v in six.iteritems(dist.parameters):
      if not tensor_util.is_mutable(v):
        continue
      try:
        self.assertIs(getattr(dist, k), v)
      except AssertionError as e:
        raise AssertionError(
            'No attr found for parameter {} of distribution {}: \n{}'.format(
                k, dist_name, e))

    for stat in data.draw(
        hps.sets(
            hps.one_of(
                map(hps.just, [
                    'covariance', 'entropy', 'mean', 'mode', 'stddev',
                    'variance'
                ])),
            min_size=3,
            max_size=3)):
      logging.info('%s.%s', dist_name, stat)
      try:
        with tfp_hps.assert_no_excessive_var_usage(
            'statistic `{}` of `{}`'.format(stat, dist)):
          getattr(dist, stat)()

      except NotImplementedError:
        pass

    with tf.GradientTape() as tape:
      with tfp_hps.assert_no_excessive_var_usage(
          'method `sample` of `{}`'.format(dist)):
        sample = dist.sample()
    if dist.reparameterization_type == tfd.FULLY_REPARAMETERIZED:
      grads = tape.gradient(sample, dist.variables)
      for grad, var in zip(grads, dist.variables):
        var_name = var.name.rstrip('_0123456789:')
        if var_name in NO_SAMPLE_PARAM_GRADS.get(dist_name, ()):
          continue
        if grad is None:
          raise AssertionError(
              'Missing sample -> {} grad for distribution {}'.format(
                  var_name, dist_name))

    # Turn off validations, since log_prob can choke on dist's own samples.
    # Also, to relax conversion counts for KL (might do >2 w/ validate_args).
    dist = dist.copy(validate_args=False)
    dist2 = dist2.copy(validate_args=False)

    try:
      for d1, d2 in (dist, dist2), (dist2, dist):
        with tf.GradientTape() as tape:
          with tfp_hps.assert_no_excessive_var_usage(
              '`kl_divergence` of (`{}` (vars {}), `{}` (vars {}))'.format(
                  d1, d1.variables, d2, d2.variables),
              max_permissible=1):  # No validation => 1 convert per var.
            kl = d1.kl_divergence(d2)
        wrt_vars = list(d1.variables) + list(d2.variables)
        grads = tape.gradient(kl, wrt_vars)
        for grad, var in zip(grads, wrt_vars):
          if grad is None and dist_name not in NO_KL_PARAM_GRADS:
            raise AssertionError('Missing KL({} || {}) -> {} grad:\n'
                                 '{} vars: {}\n{} vars: {}'.format(
                                     d1, d2, var, d1, d1.variables, d2,
                                     d2.variables))
    except NotImplementedError:
      pass

    if dist_name not in NO_LOG_PROB_PARAM_GRADS:
      with tf.GradientTape() as tape:
        lp = dist.log_prob(tf.stop_gradient(sample))
      grads = tape.gradient(lp, dist.variables)
      for grad, var in zip(grads, dist.variables):
        if grad is None:
          raise AssertionError(
              'Missing log_prob -> {} grad for distribution {}'.format(
                  var, dist_name))

    for evaluative in data.draw(
        hps.sets(
            hps.one_of(
                map(hps.just, [
                    'log_prob', 'prob', 'log_cdf', 'cdf',
                    'log_survival_function', 'survival_function'
                ])),
            min_size=3,
            max_size=3)):
      logging.info('%s.%s', dist_name, evaluative)
      try:
        with tfp_hps.assert_no_excessive_var_usage(
            'evaluative `{}` of `{}`'.format(evaluative, dist),
            max_permissible=1):  # No validation => 1 convert
          getattr(dist, evaluative)(sample)
      except NotImplementedError:
        pass
Code Example #17
    def testDistribution(self, dist_name, data):
        if tf.executing_eagerly() != (FLAGS.tf_mode == 'eager'):
            return
        tf.compat.v1.set_random_seed(
            data.draw(
                hpnp.arrays(dtype=np.int64,
                            shape=[]).filter(lambda x: x != 0)))
        dist, batch_shape = data.draw(
            distributions(dist_name=dist_name, enable_vars=True))
        del batch_shape
        logging.info(
            'distribution: %s; parameters used: %s', dist,
            [k for k, v in six.iteritems(dist.parameters) if v is not None])
        self.evaluate([var.initializer for var in dist.variables])
        for k, v in six.iteritems(dist.parameters):
            if not tensor_util.is_mutable(v):
                continue
            try:
                self.assertIs(getattr(dist, k), v)
            except AssertionError as e:
                raise AssertionError(
                    'No attr found for parameter {} of distribution {}: \n{}'.
                    format(k, dist_name, e))

        for stat in data.draw(
                hps.permutations([
                    'covariance', 'entropy', 'mean', 'mode', 'stddev',
                    'variance'
                ]))[:3]:
            logging.info('%s.%s', dist_name, stat)
            try:
                VAR_USAGES.clear()
                getattr(dist, stat)()
                assert_no_excessive_var_usage('statistic `{}` of `{}`'.format(
                    stat, dist))
            except NotImplementedError:
                pass

        VAR_USAGES.clear()
        with tf.GradientTape() as tape:
            sample = dist.sample()
        assert_no_excessive_var_usage('method `sample` of `{}`'.format(dist))
        if dist.reparameterization_type == tfd.FULLY_REPARAMETERIZED:
            grads = tape.gradient(sample, dist.variables)
            for grad, var in zip(grads, dist.variables):
                if grad is None:
                    raise AssertionError(
                        'Missing sample -> {} grad for distribution {}'.format(
                            var, dist_name))

        # Turn off validations, since log_prob can choke on dist's own samples.
        dist = dist.copy(validate_args=False)
        if dist_name not in NO_LOG_PROB_PARAM_GRADS:
            with tf.GradientTape() as tape:
                lp = dist.log_prob(tf.stop_gradient(sample))
            grads = tape.gradient(lp, dist.variables)
            for grad, var in zip(grads, dist.variables):
                if grad is None:
                    raise AssertionError(
                        'Missing log_prob -> {} grad for distribution {}'.
                        format(var, dist_name))

        for evaluative in data.draw(
                hps.permutations([
                    'log_prob', 'prob', 'log_cdf', 'cdf',
                    'log_survival_function', 'survival_function'
                ]))[:3]:
            logging.info('%s.%s', dist_name, evaluative)
            try:
                VAR_USAGES.clear()
                getattr(dist, evaluative)(sample)
                assert_no_excessive_var_usage(
                    'evaluative `{}` of `{}`'.format(evaluative, dist),
                    max_permissible=1)  # No validation => 1 convert.
            except NotImplementedError:
                pass