Example #1
    def testVarianceAllowNanStatsTrueReturnsNaNforUndefinedBatchMembers(self):
        with self.test_session():
            # df = 0.5 ==> undefined mean ==> undefined variance.
            # df = 1.5 ==> infinite variance.
            df = [0.5, 1.5, 3., 5., 7.]
            mu = [-2, 0., 1., 3.3, 4.4]
            sigma = [5., 4., 3., 2., 1.]
            student = student_t.StudentT(df=df,
                                         loc=mu,
                                         scale=sigma,
                                         allow_nan_stats=True)
            var = student.variance().eval()
            # scipy uses inf for variance when the mean is undefined.  When mean is
            # undefined we say variance is undefined as well.  So test the first
            # member of var, making sure it is NaN, then replace with inf and compare
            # to scipy.
            self.assertTrue(np.isnan(var[0]))
            var[0] = np.inf

            if not stats:
                return
            expected_var = [
                stats.t.var(d, loc=m, scale=s)
                for (d, m, s) in zip(df, mu, sigma)
            ]
            self.assertAllClose(expected_var, var)
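The comments above rely on the usual convention for the variance of a location-scale Student's t: sigma**2 * df / (df - 2) when df > 2, infinite when 1 < df <= 2, and undefined when df <= 1 (the mean itself does not exist there). A minimal sketch of that rule, assuming nothing beyond NumPy and not part of the test file:

import numpy as np

def student_t_variance(df, sigma):
    # sigma**2 * df / (df - 2) when df > 2; infinite when 1 < df <= 2;
    # NaN when df <= 1, where the mean (and hence the variance) is undefined.
    if df > 2.:
        return sigma**2 * df / (df - 2.)
    if df > 1.:
        return np.inf
    return np.nan

print([student_t_variance(d, s) for d, s in zip([0.5, 1.5, 3.], [5., 4., 3.])])
# -> [nan, inf, 27.0]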
Example #2
 def testStudentSampleMultiDimensional(self):
   with self.test_session():
     batch_size = 7
     df = constant_op.constant([[3., 7.]] * batch_size)
     mu = constant_op.constant([[3., -3.]] * batch_size)
     sigma = constant_op.constant([[math.sqrt(10.), math.sqrt(15.)]] *
                                  batch_size)
     df_v = [3., 7.]
     mu_v = [3., -3.]
     sigma_v = [np.sqrt(10.), np.sqrt(15.)]
     n = constant_op.constant(200000)
     student = student_t.StudentT(df=df, loc=mu, scale=sigma)
     samples = student.sample(n, seed=123456)
     sample_values = self.evaluate(samples)
     self.assertEqual(samples.get_shape(), (200000, batch_size, 2))
     self.assertAllClose(
         sample_values[:, 0, 0].mean(), mu_v[0], rtol=1e-2, atol=0)
     self.assertAllClose(
         sample_values[:, 0, 0].var(),
         sigma_v[0]**2 * df_v[0] / (df_v[0] - 2),
         rtol=1e-1,
         atol=0)
     self._checkKLApprox(df_v[0], mu_v[0], sigma_v[0], sample_values[:, 0, 0])
     self.assertAllClose(
         sample_values[:, 0, 1].mean(), mu_v[1], rtol=1e-2, atol=0)
     self.assertAllClose(
         sample_values[:, 0, 1].var(),
         sigma_v[1]**2 * df_v[1] / (df_v[1] - 2),
         rtol=1e-1,
         atol=0)
     self._checkKLApprox(df_v[1], mu_v[1], sigma_v[1], sample_values[:, 0, 1])
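_checkKLApprox is referenced in this test but its body is not part of these excerpts. One plausible implementation, shown purely as an illustration and assuming the np and scipy stats names already used by the surrounding tests: histogram the samples and require the empirical KL divergence from the analytic density to be small.

 def _checkKLApprox(self, df, mu, sigma, samples):
   # Hypothetical helper, not the file's actual code: compare a histogram of
   # the samples against the analytic density via an approximate KL divergence.
   if not stats:
     return
   hist, edges = np.histogram(samples, bins=100, density=True)
   centers = 0.5 * (edges[:-1] + edges[1:])
   widths = np.diff(edges)
   p = hist * widths                                           # empirical bin mass
   q = stats.t.pdf(centers, df, loc=mu, scale=sigma) * widths  # analytic bin mass
   mask = (p > 0.) & (q > 0.)
   kl = np.sum(p[mask] * np.log(p[mask] / q[mask]))
   self.assertLess(kl, 1e-2)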
Example #3
 def testMeanAllowNanStatsIsFalseWorksWhenAllBatchMembersAreDefined(self):
     mu = [1., 3.3, 4.4]
     student = student_t.StudentT(df=[3., 5., 7.],
                                  loc=mu,
                                  scale=[3., 2., 1.])
     mean = self.evaluate(student.mean())
     self.assertAllClose([1., 3.3, 4.4], mean)
Example #4
 def testPdfOfSampleMultiDims(self):
   student = student_t.StudentT(df=[7., 11.], loc=[[5.], [6.]], scale=3.)
   self.assertAllEqual([], student.event_shape)
   self.assertAllEqual([], self.evaluate(student.event_shape_tensor()))
   self.assertAllEqual([2, 2], student.batch_shape)
   self.assertAllEqual([2, 2], self.evaluate(student.batch_shape_tensor()))
   num = 50000
   samples = student.sample(num, seed=123456)
   pdfs = student.prob(samples)
   sample_vals, pdf_vals = self.evaluate([samples, pdfs])
   self.assertEqual(samples.get_shape(), (num, 2, 2))
   self.assertEqual(pdfs.get_shape(), (num, 2, 2))
   self.assertNear(5., np.mean(sample_vals[:, 0, :]), err=.03)
   self.assertNear(6., np.mean(sample_vals[:, 1, :]), err=.03)
   self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.02)
   self._assertIntegral(sample_vals[:, 0, 1], pdf_vals[:, 0, 1], err=0.02)
   self._assertIntegral(sample_vals[:, 1, 0], pdf_vals[:, 1, 0], err=0.02)
   self._assertIntegral(sample_vals[:, 1, 1], pdf_vals[:, 1, 1], err=0.02)
   if not stats:
     return
   self.assertNear(
        stats.t.var(7., loc=0., scale=3.),  # loc does not affect the variance
       np.var(sample_vals[:, :, 0]),
       err=.4)
   self.assertNear(
        stats.t.var(11., loc=0., scale=3.),  # loc does not affect the variance
       np.var(sample_vals[:, :, 1]),
       err=.4)
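_assertIntegral is likewise only referenced here. A sketch of what such a helper could look like, again an assumption rather than the file's code: a trapezoidal estimate of the integral of the evaluated pdf over the sorted sample locations, which should come out close to 1 for a dense sample.

 def _assertIntegral(self, sample_vals, pdf_vals, err=1.5e-3):
   # Hypothetical helper: trapezoid-rule integral of pdf(x) over the sorted
   # sample locations; for a dense sample this should be approximately 1.
   order = np.argsort(sample_vals)
   xs = sample_vals[order]
   ps = pdf_vals[order]
   total = np.sum(0.5 * (ps[1:] + ps[:-1]) * np.diff(xs))
   self.assertNear(1., total, err=err)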
Example #5
  def testStudentPDFAndLogPDF(self):
    with self.test_session():
      batch_size = 6
      df = constant_op.constant([3.] * batch_size)
      mu = constant_op.constant([7.] * batch_size)
      sigma = constant_op.constant([8.] * batch_size)
      df_v = 3.
      mu_v = 7.
      sigma_v = 8.
      t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32)
      student = student_t.StudentT(df, loc=mu, scale=-sigma)

      log_pdf = student.log_prob(t)
      self.assertEqual(log_pdf.get_shape(), (6,))
      log_pdf_values = self.evaluate(log_pdf)
      pdf = student.prob(t)
      self.assertEqual(pdf.get_shape(), (6,))
      pdf_values = self.evaluate(pdf)

      if not stats:
        return

      expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v)
      expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v)
      self.assertAllClose(expected_log_pdf, log_pdf_values)
      self.assertAllClose(np.log(expected_pdf), log_pdf_values)
      self.assertAllClose(expected_pdf, pdf_values)
      self.assertAllClose(np.exp(expected_log_pdf), pdf_values)
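The density these assertions target is the ordinary location-scale Student's t pdf. A minimal NumPy version, with scipy used only for the cross-check (illustrative, not the implementation under test); the abs on sigma reflects what the test itself exercises by passing `scale=-sigma` and still comparing against scipy with a positive scale:

import numpy as np
from scipy import special, stats

def student_t_logpdf(x, df, mu, sigma):
    # log Gamma((df+1)/2) - log Gamma(df/2) - 0.5*log(df*pi) - log|sigma|
    #   - (df+1)/2 * log(1 + ((x - mu)/sigma)**2 / df)
    y = (x - mu) / sigma
    return (special.gammaln(0.5 * (df + 1.)) - special.gammaln(0.5 * df)
            - 0.5 * np.log(df * np.pi) - np.log(np.abs(sigma))
            - 0.5 * (df + 1.) * np.log1p(y**2 / df))

x = np.array([-2.5, 2.5, 8., 0., -1., 2.])
print(np.allclose(student_t_logpdf(x, 3., 7., 8.), stats.t.logpdf(x, 3., loc=7., scale=8.)))  # True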
Example #6
    def testStudentCDFAndLogCDF(self):
        batch_size = 6
        df = constant_op.constant([3.] * batch_size)
        mu = constant_op.constant([7.] * batch_size)
        sigma = constant_op.constant([-8.] * batch_size)
        df_v = 3.
        mu_v = 7.
        sigma_v = 8.
        t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32)
        student = student_t.StudentT(df, loc=mu, scale=sigma)

        log_cdf = student.log_cdf(t)
        self.assertEqual(log_cdf.get_shape(), (6,))
        log_cdf_values = self.evaluate(log_cdf)
        cdf = student.cdf(t)
        self.assertEqual(cdf.get_shape(), (6,))
        cdf_values = self.evaluate(cdf)

        if not stats:
            return
        expected_log_cdf = stats.t.logcdf(t, df_v, loc=mu_v, scale=sigma_v)
        expected_cdf = stats.t.cdf(t, df_v, loc=mu_v, scale=sigma_v)
        self.assertAllClose(expected_log_cdf,
                            log_cdf_values,
                            atol=0.,
                            rtol=1e-5)
        self.assertAllClose(np.log(expected_cdf),
                            log_cdf_values,
                            atol=0.,
                            rtol=1e-5)
        self.assertAllClose(expected_cdf, cdf_values, atol=0., rtol=1e-5)
        self.assertAllClose(np.exp(expected_log_cdf),
                            cdf_values,
                            atol=0.,
                            rtol=1e-5)
Example #7
    def testVarianceAllowNanStatsFalseRaisesForUndefinedBatchMembers(self):
        # df <= 1 ==> variance not defined
        student = student_t.StudentT(df=1.,
                                     loc=0.,
                                     scale=1.,
                                     allow_nan_stats=False)
        with self.assertRaisesOpError("x < y"):
            self.evaluate(student.variance())

        # df <= 1 ==> variance not defined
        student = student_t.StudentT(df=0.5,
                                     loc=0.,
                                     scale=1.,
                                     allow_nan_stats=False)
        with self.assertRaisesOpError("x < y"):
            self.evaluate(student.variance())
Example #8
  def testStudentLogPDFMultidimensional(self):
    with self.test_session():
      batch_size = 6
      df = constant_op.constant([[1.5, 7.2]] * batch_size)
      mu = constant_op.constant([[3., -3.]] * batch_size)
      sigma = constant_op.constant([[-math.sqrt(10.), math.sqrt(15.)]] *
                                   batch_size)
      df_v = np.array([1.5, 7.2])
      mu_v = np.array([3., -3.])
      sigma_v = np.array([np.sqrt(10.), np.sqrt(15.)])
      t = np.array([[-2.5, 2.5, 4., 0., -1., 2.]], dtype=np.float32).T
      student = student_t.StudentT(df, loc=mu, scale=sigma)
      log_pdf = student.log_prob(t)
      log_pdf_values = self.evaluate(log_pdf)
      self.assertEqual(log_pdf.get_shape(), (6, 2))
      pdf = student.prob(t)
      pdf_values = self.evaluate(pdf)
      self.assertEqual(pdf.get_shape(), (6, 2))

      if not stats:
        return
      expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v)
      expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v)
      self.assertAllClose(expected_log_pdf, log_pdf_values)
      self.assertAllClose(np.log(expected_pdf), log_pdf_values)
      self.assertAllClose(expected_pdf, pdf_values)
      self.assertAllClose(np.exp(expected_log_pdf), pdf_values)
Example #9
 def testNegativeDofFails(self):
     with self.assertRaisesOpError(r"Condition x > 0 did not hold"):
         student = student_t.StudentT(df=[2, -5.],
                                      loc=0.,
                                      scale=1.,
                                      validate_args=True,
                                      name="S")
         self.evaluate(student.mean())
Example #10
 def testMeanAllowNanStatsIsFalseRaisesWhenBatchMemberIsUndefined(self):
   with self.test_session():
     mu = [1., 3.3, 4.4]
     student = student_t.StudentT(
         df=[0.5, 5., 7.], loc=mu, scale=[3., 2., 1.],
         allow_nan_stats=False)
     with self.assertRaisesOpError("x < y"):
       self.evaluate(student.mean())
Example #11
 def testMode(self):
     df = [0.5, 1., 3]
     mu = [-1, 0., 1]
     sigma = [5., 4., 3.]
     student = student_t.StudentT(df=df, loc=mu, scale=sigma)
     # Test broadcast of mu across shape of df/sigma
     mode = self.evaluate(student.mode())
     self.assertAllClose([-1., 0, 1], mode)
Example #12
 def testMeanAllowNanStatsIsTrueReturnsNaNForUndefinedBatchMembers(self):
   with self.test_session():
     mu = [-2, 0., 1., 3.3, 4.4]
     sigma = [5., 4., 3., 2., 1.]
     student = student_t.StudentT(
         df=[0.5, 1., 3., 5., 7.], loc=mu, scale=sigma,
         allow_nan_stats=True)
     mean = self.evaluate(student.mean())
     self.assertAllClose([np.nan, np.nan, 1., 3.3, 4.4], mean)
Example #13
  def testStudentSampleMultipleTimes(self):
    with self.test_session():
      df = constant_op.constant(4.)
      mu = constant_op.constant(3.)
      sigma = constant_op.constant(math.sqrt(10.))
      n = constant_op.constant(100)

      random_seed.set_random_seed(654321)
      student = student_t.StudentT(
          df=df, loc=mu, scale=sigma, name="student_t1")
      samples1 = self.evaluate(student.sample(n, seed=123456))

      random_seed.set_random_seed(654321)
      student2 = student_t.StudentT(
          df=df, loc=mu, scale=sigma, name="student_t2")
      samples2 = self.evaluate(student2.sample(n, seed=123456))

      self.assertAllClose(samples1, samples2)
Example #14
 def testStudentSampleSmallDfNoNan(self):
     df_v = [1e-1, 1e-5, 1e-10, 1e-20]
     df = constant_op.constant(df_v)
     n = constant_op.constant(200000)
     student = student_t.StudentT(df=df, loc=1., scale=1.)
     samples = student.sample(n, seed=123456)
     sample_values = self.evaluate(samples)
     n_val = 200000
     self.assertEqual(sample_values.shape, (n_val, 4))
     self.assertTrue(np.all(np.logical_not(np.isnan(sample_values))))
Example #15
  def testBroadcastingPdfArgs(self):

    def _assert_shape(student, arg, shape):
      self.assertEqual(student.log_prob(arg).get_shape(), shape)
      self.assertEqual(student.prob(arg).get_shape(), shape)

    def _check(student):
      _assert_shape(student, 2., (3,))
      xs = np.array([2., 3., 4.], dtype=np.float32)
      _assert_shape(student, xs, (3,))
      xs = np.array([xs])
      _assert_shape(student, xs, (1, 3))
      xs = xs.T
      _assert_shape(student, xs, (3, 3))

    _check(student_t.StudentT(df=[2., 3., 4.,], loc=2., scale=1.))
    _check(student_t.StudentT(df=7., loc=[2., 3., 4.,], scale=1.))
    _check(student_t.StudentT(df=7., loc=3., scale=[2., 3., 4.,]))

    def _check2d(student):
      _assert_shape(student, 2., (1, 3))
      xs = np.array([2., 3., 4.], dtype=np.float32)
      _assert_shape(student, xs, (1, 3))
      xs = np.array([xs])
      _assert_shape(student, xs, (1, 3))
      xs = xs.T
      _assert_shape(student, xs, (3, 3))

    _check2d(student_t.StudentT(df=[[2., 3., 4.,]], loc=2., scale=1.))
    _check2d(student_t.StudentT(df=7., loc=[[2., 3., 4.,]], scale=1.))
    _check2d(student_t.StudentT(df=7., loc=3., scale=[[2., 3., 4.,]]))

    def _check2d_rows(student):
      _assert_shape(student, 2., (3, 1))
      xs = np.array([2., 3., 4.], dtype=np.float32)  # (3,)
      _assert_shape(student, xs, (3, 3))
      xs = np.array([xs])  # (1,3)
      _assert_shape(student, xs, (3, 3))
      xs = xs.T  # (3,1)
      _assert_shape(student, xs, (3, 1))

    _check2d_rows(student_t.StudentT(df=[[2.], [3.], [4.]], loc=2., scale=1.))
    _check2d_rows(student_t.StudentT(df=7., loc=[[2.], [3.], [4.]], scale=1.))
    _check2d_rows(student_t.StudentT(df=7., loc=3., scale=[[2.], [3.], [4.]]))
Example #16
 def testFullyReparameterized(self):
     df = constant_op.constant(2.0)
     mu = constant_op.constant(1.0)
     sigma = constant_op.constant(3.0)
     with backprop.GradientTape() as tape:
         tape.watch(df)
         tape.watch(mu)
         tape.watch(sigma)
         student = student_t.StudentT(df=df, loc=mu, scale=sigma)
         samples = student.sample(100)
     grad_df, grad_mu, grad_sigma = tape.gradient(samples, [df, mu, sigma])
     self.assertIsNotNone(grad_df)
     self.assertIsNotNone(grad_mu)
     self.assertIsNotNone(grad_sigma)
Example #17
    def testVarianceAllowNanStatsFalseGivesCorrectValueForDefinedBatchMembers(
            self):
        # df = 1.5 ==> infinite variance.
        df = [1.5, 3., 5., 7.]
        mu = [0., 1., 3.3, 4.4]
        sigma = [4., 3., 2., 1.]
        student = student_t.StudentT(df=df, loc=mu, scale=sigma)
        var = self.evaluate(student.variance())

        if not stats:
            return
        expected_var = [
            stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
        ]
        self.assertAllClose(expected_var, var)
Example #18
    def testStd(self):
        # Defined for all batch members.
        df = [3.5, 5., 3., 5., 7.]
        mu = [-2.2]
        sigma = [5., 4., 3., 2., 1.]
        student = student_t.StudentT(df=df, loc=mu, scale=sigma)
        # Test broadcast of mu across shape of df/sigma
        stddev = self.evaluate(student.stddev())
        mu *= len(df)

        if not stats:
            return
        expected_stddev = [
            stats.t.std(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
        ]
        self.assertAllClose(expected_stddev, stddev)
Example #19
 def testStudentSample(self):
     df = constant_op.constant(4.)
     mu = constant_op.constant(3.)
     sigma = constant_op.constant(-math.sqrt(10.))
     df_v = 4.
     mu_v = 3.
     sigma_v = np.sqrt(10.)
     n = constant_op.constant(200000)
     student = student_t.StudentT(df=df, loc=mu, scale=sigma)
     samples = student.sample(n, seed=123456)
     sample_values = self.evaluate(samples)
     n_val = 200000
     self.assertEqual(sample_values.shape, (n_val, ))
     self.assertAllClose(sample_values.mean(), mu_v, rtol=0.1, atol=0)
     self.assertAllClose(sample_values.var(),
                         sigma_v**2 * df_v / (df_v - 2),
                         rtol=0.1,
                         atol=0)
     self._checkKLApprox(df_v, mu_v, sigma_v, sample_values)
Example #20
    def testVarianceAllowNanStatsTrueReturnsNaNforUndefinedBatchMembers(self):
        # df = 0.5 ==> undefined mean ==> undefined variance.
        # df = 1.5 ==> infinite variance.
        df = [0.5, 1.5, 3., 5., 7.]
        mu = [-2, 0., 1., 3.3, 4.4]
        sigma = [5., 4., 3., 2., 1.]
        student = student_t.StudentT(df=df,
                                     loc=mu,
                                     scale=sigma,
                                     allow_nan_stats=True)
        var = self.evaluate(student.variance())

        if not stats:
            return
        expected_var = [
            stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)
        ]
        # Slicing off first element due to nan/inf mismatch in different SciPy
        # versions.
        self.assertAllClose(expected_var[1:], var[1:])
Example #21
    def testStudentEntropy(self):
        df_v = np.array([[2., 3., 7.]])  # 1x3
        mu_v = np.array([[1., -1, 0]])  # 1x3
        sigma_v = np.array([[1., -2., 3.]]).T  # transposed => 3x1
        student = student_t.StudentT(df=df_v, loc=mu_v, scale=sigma_v)
        ent = student.entropy()
        ent_values = self.evaluate(ent)

        # Help scipy broadcast to 3x3
        ones = np.array([[1, 1, 1]])
        sigma_bc = np.abs(sigma_v) * ones
        mu_bc = ones.T * mu_v
        df_bc = ones.T * df_v
        if not stats:
            return
        expected_entropy = stats.t.entropy(np.reshape(df_bc, [-1]),
                                           loc=np.reshape(mu_bc, [-1]),
                                           scale=np.reshape(sigma_bc, [-1]))
        expected_entropy = np.reshape(expected_entropy, df_bc.shape)
        self.assertAllClose(expected_entropy, ent_values)
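For reference, the value scipy's `t.entropy` returns here matches the standard Student's t differential entropy plus the log of the scale. A small check of that identity, assuming scipy is available (not part of the test file):

import numpy as np
from scipy import special, stats

def student_t_entropy(df, sigma):
    # (df+1)/2 * (digamma((df+1)/2) - digamma(df/2))
    #   + log(sqrt(df) * Beta(df/2, 1/2)) + log(sigma)
    return (0.5 * (df + 1.) * (special.digamma(0.5 * (df + 1.)) - special.digamma(0.5 * df))
            + np.log(np.sqrt(df) * special.beta(0.5 * df, 0.5)) + np.log(sigma))

print(np.isclose(student_t_entropy(3., 2.), stats.t.entropy(3., scale=2.)))  # True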
Example #22
 def testPdfOfSample(self):
   student = student_t.StudentT(df=3., loc=np.pi, scale=1.)
   num = 20000
   samples = student.sample(num, seed=123456)
   pdfs = student.prob(samples)
   mean = student.mean()
   mean_pdf = student.prob(student.mean())
   sample_vals, pdf_vals, mean_val, mean_pdf_val = self.evaluate(
       [samples, pdfs, student.mean(), mean_pdf])
   self.assertEqual(samples.get_shape(), (num,))
   self.assertEqual(pdfs.get_shape(), (num,))
   self.assertEqual(mean.get_shape(), ())
   self.assertNear(np.pi, np.mean(sample_vals), err=0.02)
   self.assertNear(np.pi, mean_val, err=1e-6)
   # Verify integral over sample*pdf ~= 1.
   # Tolerance increased since eager was getting a value of 1.002041.
   self._assertIntegral(sample_vals, pdf_vals, err=3e-3)
   if not stats:
     return
   self.assertNear(stats.t.pdf(np.pi, 3., loc=np.pi), mean_pdf_val, err=1e-6)
Example #23
 def testPdfOfSample(self):
     with self.test_session() as sess:
         student = student_t.StudentT(df=3., loc=np.pi, scale=1.)
         num = 20000
         samples = student.sample(num, seed=123456)
         pdfs = student.prob(samples)
         mean = student.mean()
         mean_pdf = student.prob(student.mean())
         sample_vals, pdf_vals, mean_val, mean_pdf_val = sess.run(
             [samples, pdfs, student.mean(), mean_pdf])
         self.assertEqual(samples.get_shape(), (num, ))
         self.assertEqual(pdfs.get_shape(), (num, ))
         self.assertEqual(mean.get_shape(), ())
         self.assertNear(np.pi, np.mean(sample_vals), err=0.02)
         self.assertNear(np.pi, mean_val, err=1e-6)
         # Verify integral over sample*pdf ~= 1.
         self._assertIntegral(sample_vals, pdf_vals, err=2e-3)
         if not stats:
             return
         self.assertNear(stats.t.pdf(np.pi, 3., loc=np.pi),
                         mean_pdf_val,
                         err=1e-6)
Example #24
    def __init__(self,
                 df,
                 loc=None,
                 scale_identity_multiplier=None,
                 scale_diag=None,
                 scale_tril=None,
                 scale_perturb_factor=None,
                 scale_perturb_diag=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="VectorStudentT"):
        """Instantiates the vector Student's t-distributions on `R^k`.

    The `batch_shape` is the broadcast between `df.batch_shape` and
    `Affine.batch_shape` where `Affine` is constructed from `loc` and
    `scale_*` arguments.

    The `event_shape` is the event shape of `Affine.event_shape`.

    Args:
      df: Floating-point `Tensor`. The degrees of freedom of the
        distribution(s). `df` must contain only positive values. It must be
        scalar if `loc`, `scale_*` imply a non-scalar `batch_shape`, or must
        have the same `batch_shape` as implied by `loc`, `scale_*`.
      loc: Floating-point `Tensor`. If this is set to `None`, no `loc` is
        applied.
      scale_identity_multiplier: floating-point rank-0 `Tensor` representing a
        scaling done to the identity matrix. When `scale_identity_multiplier =
        scale_diag = scale_tril = None` then `scale += IdentityMatrix`.
        Otherwise no scaled-identity-matrix is added to `scale`.
      scale_diag: Floating-point `Tensor` representing the diagonal matrix.
        `scale_diag` has shape [N1, N2, ..., k], which represents a k x k
        diagonal matrix. When `None` no diagonal term is added to `scale`.
      scale_tril: Floating-point `Tensor` representing the lower triangular
        matrix. `scale_tril` has shape [N1, N2, ..., k, k], which represents a
        k x k lower triangular matrix. When `None` no `scale_tril` term is
        added to `scale`. The upper triangular elements above the diagonal are
        ignored.
      scale_perturb_factor: Floating-point `Tensor` representing factor matrix
        with last two dimensions of shape `(k, r)`. When `None`, no rank-r
        update is added to `scale`.
      scale_perturb_diag: Floating-point `Tensor` representing the diagonal
        matrix. `scale_perturb_diag` has shape [N1, N2, ..., r], which
        represents an r x r diagonal matrix. When `None`, low rank updates will
        take the form `scale_perturb_factor * scale_perturb_factor.T`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
        parameters = dict(locals())
        graph_parents = [
            df, loc, scale_identity_multiplier, scale_diag, scale_tril,
            scale_perturb_factor, scale_perturb_diag
        ]
        with ops.name_scope(name) as name:
            with ops.name_scope("init", values=graph_parents):
                # The shape of the _VectorStudentT distribution is governed by the
                # relationship between df.batch_shape and affine.batch_shape. In
                # pseudocode the basic procedure is:
                #   if df.batch_shape is scalar:
                #     if affine.batch_shape is not scalar:
                #       # broadcast distribution.sample so
                #       # it has affine.batch_shape.
                #     self.batch_shape = affine.batch_shape
                #   else:
                #     if affine.batch_shape is scalar:
                #       # let affine broadcasting do its thing.
                #     self.batch_shape = df.batch_shape
                # All of the above magic is actually handled by TransformedDistribution.
                # Here we really only need to collect the affine.batch_shape and decide
                # what we're going to pass in to TransformedDistribution's
                # (override) batch_shape arg.
                affine = bijectors.Affine(
                    shift=loc,
                    scale_identity_multiplier=scale_identity_multiplier,
                    scale_diag=scale_diag,
                    scale_tril=scale_tril,
                    scale_perturb_factor=scale_perturb_factor,
                    scale_perturb_diag=scale_perturb_diag,
                    validate_args=validate_args)
                distribution = student_t.StudentT(
                    df=df,
                    loc=array_ops.zeros([], dtype=affine.dtype),
                    scale=array_ops.ones([], dtype=affine.dtype))
                batch_shape, override_event_shape = (
                    distribution_util.shapes_from_loc_and_scale(
                        affine.shift, affine.scale))
                override_batch_shape = distribution_util.pick_vector(
                    distribution.is_scalar_batch(), batch_shape,
                    constant_op.constant([], dtype=dtypes.int32))
                super(_VectorStudentT,
                      self).__init__(distribution=distribution,
                                     bijector=affine,
                                     batch_shape=override_batch_shape,
                                     event_shape=override_event_shape,
                                     validate_args=validate_args,
                                     name=name)
                self._parameters = parameters
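Finally, a short usage sketch of the constructor documented above. The example values, and the assumption that `_VectorStudentT` can be imported from its defining module, are illustrative only:

# Illustrative only: a 3-dimensional vector Student's t with a diagonal scale.
vst = _VectorStudentT(
    df=3.,
    loc=[0., 1., -1.],
    scale_diag=[1., 2., 0.5],
    validate_args=True)
samples = vst.sample(5, seed=123456)    # shape: [5, 3]
log_probs = vst.log_prob(samples)       # shape: [5]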