Example #1
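All of the snippets below come from TensorFlow Probability test and library code. They assume roughly the following imports (a sketch; exact module paths vary between TFP versions):

import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp

from tensorflow_probability.python.distributions import batch_concat
from tensorflow_probability.python.internal import distribution_util as dist_util
from tensorflow_probability.python.internal import hypothesis_testlib as tfp_hps
from tensorflow_probability.python.internal import test_util

tfd = tfp.distributions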
  def testExcessiveConcretizationWithDefaultReinterpretedBatchNdims(self):
    loc = tfp_hps.defer_and_count_usage(
        tf.Variable(np.zeros((5, 2, 3)), shape=tf.TensorShape(None)))
    scale = tfp_hps.defer_and_count_usage(
        tf.Variable(np.ones([]), shape=tf.TensorShape(None)))
    dist = tfd.Independent(
        tfd.Logistic(loc=loc, scale=scale, validate_args=True),
        reinterpreted_batch_ndims=None, validate_args=True)

    for method in ('batch_shape_tensor', 'event_shape_tensor',
                   'mean', 'variance', 'sample'):
      with tfp_hps.assert_no_excessive_var_usage(method, max_permissible=4):
        getattr(dist, method)()

    # In addition to the four reads of `loc`, `scale` described above in
    # `testExcessiveConcretizationOfParams`, the methods below have two more
    # reads of these parameters -- from computing a default value for
    # `reinterpreted_batch_ndims`, which requires calling
    # `dist.distribution.batch_shape_tensor()`.

    for method in ('log_prob', 'log_cdf', 'prob', 'cdf'):
      with tfp_hps.assert_no_excessive_var_usage(method, max_permissible=6):
        getattr(dist, method)(np.zeros((4, 5, 2, 3)))

    with tfp_hps.assert_no_excessive_var_usage('entropy', max_permissible=6):
      dist.entropy()

    # `Distribution.survival_function` and `Distribution.log_survival_function`
    # will call `Distribution.cdf` and `Distribution.log_cdf`, resulting in
    # one additional call to `Independent._parameter_control_dependencies`,
    # and thus two additional concretizations of the parameters.

    for method in ('survival_function', 'log_survival_function'):
      with tfp_hps.assert_no_excessive_var_usage(method, max_permissible=8):
        getattr(dist, method)(np.zeros((4, 5, 2, 3)))
Example #2
  def test_scalar_distributions(self):
    self.dist1 = tfd.Normal(
        loc=self.maybe_static(
            tf.zeros(self.batch_dim_1, dtype=self.dtype),
            self.is_static),
        scale=self.maybe_static(
            tf.ones(self.batch_dim_1, dtype=self.dtype),
            self.is_static)
    )
    self.dist2 = tfd.Logistic(
        loc=self.maybe_static(
            tf.zeros(self.batch_dim_2, dtype=self.dtype),
            self.is_static),
        scale=self.maybe_static(
            tf.ones(self.batch_dim_2, dtype=self.dtype),
            self.is_static)
    )
    self.dist3 = tfd.Exponential(
        rate=self.maybe_static(
            tf.ones(self.batch_dim_3, dtype=self.dtype),
            self.is_static)
    )
    concat_dist = batch_concat.BatchConcat(
        distributions=[self.dist1, self.dist2, self.dist3], axis=1,
        validate_args=False)
    self.assertAllEqual(
        self.evaluate(concat_dist.batch_shape_tensor()),
        [2, 6, 4])

    seed = test_util.test_seed()
    samples = concat_dist.sample(seed=seed)
    self.assertAllEqual(self.evaluate(tf.shape(samples)), [2, 6, 4])
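The `batch_dim_*`, `dtype`, `is_static`, and `maybe_static` fixtures above are defined in the test class's setUp, which is not shown here. A minimal sketch consistent with the asserted [2, 6, 4] batch shape (the concrete dimensions and the placeholder trick are illustrative assumptions):

  def setUp(self):
    super().setUp()
    self.dtype = tf.float32
    self.is_static = True
    # BatchConcat concatenates along axis=1, so the batch dims must match
    # everywhere else and their axis-1 sizes must sum to 6.
    self.batch_dim_1 = [2, 1, 4]
    self.batch_dim_2 = [2, 2, 4]
    self.batch_dim_3 = [2, 3, 4]

  def maybe_static(self, x, is_static):
    if is_static:
      return x
    # Hide the static shape so only the dynamic shape is visible.
    return tf.compat.v1.placeholder_with_default(x, shape=None)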
Example #3
  def testGradientsThroughParams(self):
    loc = tf.Variable(np.zeros((4, 5, 2, 3)), shape=tf.TensorShape(None))
    scale = tf.Variable(np.ones([]), shape=tf.TensorShape(None))
    ndims = tf.Variable(2, trainable=False, shape=tf.TensorShape(None))
    dist = tfd.Independent(tfd.Logistic(loc=loc, scale=scale),
                           reinterpreted_batch_ndims=ndims,
                           validate_args=True)
    with tf.GradientTape() as tape:
      loss = -dist.log_prob(np.zeros((4, 5, 2, 3)))
    grad = tape.gradient(loss, dist.trainable_variables)
    # `ndims` is created with trainable=False, so only `loc` and `scale`
    # appear in `dist.trainable_variables`.
    self.assertLen(grad, 2)
    self.assertAllNotNone(grad)
Example #4
    def testExcessiveConcretizationOfParams(self):
        loc = tfp_hps.defer_and_count_usage(
            tf.Variable(np.zeros((4, 2, 2)), shape=tf.TensorShape(None)))
        scale = tfp_hps.defer_and_count_usage(
            tf.Variable(np.ones([]), shape=tf.TensorShape(None)))
        ndims = tf.Variable(1, trainable=False, shape=tf.TensorShape(None))
        dist = tfd.Independent(tfd.Logistic(loc=loc,
                                            scale=scale,
                                            validate_args=True),
                               reinterpreted_batch_ndims=ndims,
                               validate_args=True)

        # TODO(b/140579567): All methods of `dist` may require four concretizations
        # of parameters `loc` and `scale`:
        #  - `Independent._parameter_control_dependencies` calls
        #    `Logistic.batch_shape_tensor`, which:
        #    * Reads `loc`, `scale` in `Logistic._parameter_control_dependencies`.
        #    * Reads `loc`, `scale` in `Logistic._batch_shape_tensor`.
        #  - The method `dist.m` will call `dist.distribution.m`, which:
        #    * Reads `loc`, `scale` in `Logistic._parameter_control_dependencies`.
        #    * Reads `loc`, `scale` in the implementation of method `Logistic._m`.
        #
        # NOTE: If `dist.distribution` had dynamic batch shape and event shape,
        # there could be two more reads of the parameters of `dist.distribution`
        # in `dist.event_shape_tensor`, from calling
        # `dist.distribution.event_shape_tensor()`.

        for method in ('batch_shape_tensor', 'event_shape_tensor', 'mode',
                       'stddev', 'entropy'):
            with tfp_hps.assert_no_excessive_var_usage(method,
                                                       max_permissible=4):
                getattr(dist, method)()

        with tfp_hps.assert_no_excessive_var_usage('sample',
                                                   max_permissible=4):
            dist.sample(seed=test_util.test_seed())

        for method in ('log_prob', 'log_cdf', 'prob', 'cdf'):
            with tfp_hps.assert_no_excessive_var_usage(method,
                                                       max_permissible=4):
                getattr(dist, method)(np.zeros((3, 4, 2, 2)))

        # `Distribution.survival_function` and `Distribution.log_survival_function`
        # will call `Distribution.cdf` and `Distribution.log_cdf`, resulting in
        # one additional call to `Independent._parameter_control_dependencies`,
        # and thus two additional concretizations of the parameters.

        for method in ('survival_function', 'log_survival_function'):
            with tfp_hps.assert_no_excessive_var_usage(method,
                                                       max_permissible=6):
                getattr(dist, method)(np.zeros((3, 4, 2, 2)))
Example #5
  def test_variable_sample_shape_exception(self):
    loc = tf.Variable(tf.zeros([4, 5, 3]), shape=tf.TensorShape(None))
    scale = tf.Variable(tf.ones([]), shape=tf.TensorShape(None))
    sample_shape = tf.Variable([[1, 2]], shape=tf.TensorShape(None))
    with self.assertRaisesWithPredicateMatch(
        Exception,
        'Argument `sample_shape` must be either a scalar or a vector.'):
      dist = tfd.Sample(
          tfd.Independent(tfd.Logistic(loc=loc, scale=scale),
                          reinterpreted_batch_ndims=1),
          sample_shape=sample_shape,
          validate_args=True)
      self.evaluate([v.initializer for v in dist.trainable_variables])
      self.evaluate(dist.mean())
Example #6
  def test_variable_shape_change(self):
    loc = tf.Variable(tf.zeros([4, 5, 3]), shape=tf.TensorShape(None))
    scale = tf.Variable(tf.ones([]), shape=tf.TensorShape(None))
    # In real life, you'd really always want `sample_shape` to be
    # `trainable=False`.
    sample_shape = tf.Variable([1, 2], shape=tf.TensorShape(None))
    dist = tfd.Sample(
        tfd.Independent(tfd.Logistic(loc=loc, scale=scale),
                        reinterpreted_batch_ndims=1),
        sample_shape=sample_shape,
        validate_args=True)
    self.evaluate([v.initializer for v in dist.trainable_variables])

    x = dist.mean()
    y = dist.sample([7, 2], seed=test_util.test_seed())
    loss_x = -dist.log_prob(x)
    loss_0 = -dist.log_prob(0.)
    batch_shape = dist.batch_shape_tensor()
    event_shape = dist.event_shape_tensor()
    [x_, y_, loss_x_, loss_0_, batch_shape_, event_shape_] = self.evaluate([
        x, y, loss_x, loss_0, batch_shape, event_shape])
    self.assertAllEqual([4, 5, 1, 2, 3], x_.shape)
    self.assertAllEqual([7, 2, 4, 5, 1, 2, 3], y_.shape)
    self.assertAllEqual([4, 5], loss_x_.shape)
    self.assertAllEqual([4, 5], loss_0_.shape)
    self.assertAllEqual([4, 5], batch_shape_)
    self.assertAllEqual([1, 2, 3], event_shape_)
    self.assertLen(dist.trainable_variables, 3)

    with tf.control_dependencies([
        loc.assign(tf.zeros([])),
        scale.assign(tf.ones([3, 1, 2])),
        sample_shape.assign(6),
    ]):
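      # With `loc` now scalar and `scale` of shape [3, 1, 2], the inner
      # Logistic has batch shape [3, 1, 2]; `reinterpreted_batch_ndims=1`
      # moves the trailing [2] into the event, leaving batch shape [3, 1];
      # and `sample_shape=6` prepends 6 to the event, giving event shape [6, 2].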
      x = dist.mean()
      y = dist.sample([7, 2], seed=test_util.test_seed())
      loss_x = -dist.log_prob(x)
      loss_0 = -dist.log_prob(0.)
      batch_shape = dist.batch_shape_tensor()
      event_shape = dist.event_shape_tensor()
    [x_, y_, loss_x_, loss_0_, batch_shape_, event_shape_] = self.evaluate([
        x, y, loss_x, loss_0, batch_shape, event_shape])
    self.assertAllEqual([3, 1, 6, 2], x_.shape)
    self.assertAllEqual([7, 2, 3, 1, 6, 2], y_.shape)
    self.assertAllEqual([3, 1], loss_x_.shape)
    self.assertAllEqual([3, 1], loss_0_.shape)
    self.assertAllEqual([3, 1], batch_shape_)
    self.assertAllEqual([6, 2], event_shape_)
    self.assertLen(dist.trainable_variables, 3)
Example #7
  def test_gradients_through_params(self):
    loc = tf.Variable(tf.zeros([4, 5, 3]), shape=tf.TensorShape(None))
    scale = tf.Variable(tf.ones([]), shape=tf.TensorShape(None))
    # In real life, you'd really always want `sample_shape` to be
    # `trainable=False`.
    sample_shape = tf.Variable([1, 2], shape=tf.TensorShape(None))
    dist = tfd.Sample(tfd.Independent(tfd.Logistic(loc=loc, scale=scale),
                                      reinterpreted_batch_ndims=1),
                      sample_shape=sample_shape,
                      validate_args=True)
    with tf.GradientTape() as tape:
      loss = -dist.log_prob(0.)
    self.assertLen(dist.trainable_variables, 3)
    grad = tape.gradient(loss, [loc, scale, sample_shape])
    # `sample_shape` is integer-valued and only determines shapes, so no
    # gradient flows to it.
    self.assertAllNotNone(grad[:-1])
    self.assertIs(grad[-1], None)
Example #8
    def testChangingVariableShapes(self):
        if not tf.executing_eagerly():
            return

        loc = tf.Variable(np.zeros((4, 5, 2, 3)), shape=tf.TensorShape(None))
        scale = tf.Variable(np.ones([]), shape=tf.TensorShape(None))
        dist = tfd.Independent(tfd.Logistic(loc=loc, scale=scale),
                               reinterpreted_batch_ndims=None,
                               validate_args=True)

        self.assertAllEqual((4, ), dist.batch_shape_tensor())
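        # With `reinterpreted_batch_ndims=None`, the default is resolved
        # dynamically as rank(loc) - 1, so reassigning `loc` to a rank-5
        # value below changes the batch shape from (4,) to (3,) in place.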

        loc.assign(np.zeros((3, 7, 1, 1, 1)))
        self.assertAllEqual((3, ), dist.batch_shape_tensor())
        self.assertAllEqual(
            (2, 3), tf.shape(dist.log_prob(np.zeros((2, 3, 7, 1, 1, 1)))))
Example #9
  def new(params, event_shape=(), validate_args=False, name=None):
    """Create the distribution instance from a `params` vector."""
    with tf.name_scope(name, 'IndependentLogistic', [params, event_shape]):
      params = tf.convert_to_tensor(params, name='params')
      event_shape = dist_util.expand_to_vector(
          tf.convert_to_tensor(
              event_shape, name='event_shape', preferred_dtype=tf.int32),
          tensor_name='event_shape')
      output_shape = tf.concat([
          tf.shape(params)[:-1],
          event_shape,
      ], axis=0)
      loc_params, scale_params = tf.split(params, 2, axis=-1)
      return tfd.Independent(
          tfd.Logistic(
              loc=tf.reshape(loc_params, output_shape),
              scale=tf.math.softplus(tf.reshape(scale_params, output_shape)),
              validate_args=validate_args),
          reinterpreted_batch_ndims=tf.size(event_shape),
          validate_args=validate_args)
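A hypothetical call, for illustration only (this helper comes from an older, TF1-compatible TFP layers codebase -- note the three-argument `tf.name_scope` -- and the shapes below are assumptions, not part of the original source):

# With event_shape=(3,), the last axis of `params` must hold 2 * 3 = 6
# values: loc and pre-softplus scale, split in half.
params = tf.random.normal([7, 6])
dist = new(params, event_shape=(3,))
x = dist.sample()        # shape [7, 3]: batch shape [7], event shape [3]
lp = dist.log_prob(x)    # shape [7]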