Example #1
    def testCholeskyExtensionRandomized(self, data):
        """Property test: `tfp.math.cholesky_concat` matches a fresh Cholesky.

        Draws two broadcasting batch shapes and two random point sets, factors
        the kernel matrix over the first `n` points, then checks that extending
        that Cholesky factor by `m` new points agrees (to tolerance) with
        factoring the full `z = n + m` kernel matrix from scratch.
        """
        # Small diagonal bump to keep the kernel matrices positive definite
        # so tf.linalg.cholesky succeeds.
        jitter = lambda n: tf.linalg.eye(n, dtype=self.dtype) * 1e-5
        # Target batch shape; prev_bs and new_bs are drawn so that they
        # broadcast together to target_bs.
        target_bs = data.draw(hpnp.array_shapes())
        prev_bs, new_bs = data.draw(
            tfp_test_util.broadcasting_shapes(target_bs, 2))
        ones = tf.TensorShape([1] * len(target_bs))
        # Per-axis minimum over the rank-padded shapes: the largest shape that
        # broadcasts against BOTH prev_bs and new_bs, used to sample base data.
        smallest_shared_shp = tuple(
            np.min([
                tf.broadcast_static_shape(ones, shp).as_list()
                for shp in [prev_bs, new_bs]
            ],
                   axis=0))

        # Split z total points into n "previous" points and m >= 1 "new" ones
        # (n may be 0, i.e. extending from nothing is exercised too).
        z = data.draw(hps.integers(min_value=1, max_value=12))
        n = data.draw(hps.integers(min_value=0, max_value=z - 1))
        m = z - n

        # Seed numpy from a drawn value so Hypothesis can replay failures
        # deterministically.
        np.random.seed(
            data.draw(hps.integers(min_value=0, max_value=2**32 - 1)))
        xs = np.random.uniform(size=smallest_shared_shp + (n, ))
        # NOTE(review): drawing hps.just(xs) presumably records the sampled
        # array in the Hypothesis example for failure reporting — confirm.
        data.draw(hps.just(xs))
        # Broadcast up to prev_bs; the trailing newaxis makes each point a
        # length-1 feature vector, as the kernel's `matrix` expects.
        xs = (xs + np.zeros(prev_bs.as_list() + [n]))[..., np.newaxis]
        xs = xs.astype(self.dtype)
        # Optionally drop static shape info to exercise the dynamic-shape path.
        xs = tf1.placeholder_with_default(
            xs, shape=xs.shape if self.use_static_shape else None)

        k = tfp.positive_semidefinite_kernels.MaternOneHalf()
        # Cholesky factor of the kernel matrix over the original n points.
        mat = k.matrix(xs, xs) + jitter(n)
        chol = tf.linalg.cholesky(mat)

        # Same sampling/broadcasting treatment for the m new points, but
        # against new_bs.
        ys = np.random.uniform(size=smallest_shared_shp + (m, ))
        data.draw(hps.just(ys))
        ys = (ys + np.zeros(new_bs.as_list() + [m]))[..., np.newaxis]
        ys = ys.astype(self.dtype)
        ys = tf1.placeholder_with_default(
            ys, shape=ys.shape if self.use_static_shape else None)

        # All z points, with both operands broadcast to the full target batch
        # shape, concatenated along the points axis.
        xsys = tf.concat([
            xs + tf.zeros(target_bs + (n, 1), dtype=self.dtype),
            ys + tf.zeros(target_bs + (m, 1), dtype=self.dtype)
        ],
                         axis=-2)
        # Reference answer: factor the full z x z kernel matrix directly.
        new_chol_expected = tf.linalg.cholesky(
            k.matrix(xsys, xsys) + jitter(z))

        # Under test: extend the existing factor with the new cross/self
        # columns. jitter(z)[:, n:] applies the bump only to the new points'
        # diagonal block, matching the reference's jitter(z).
        new_chol = tfp.math.cholesky_concat(
            chol,
            k.matrix(xsys, ys) + jitter(z)[:, n:])
        self.assertAllClose(new_chol_expected, new_chol, rtol=1e-5, atol=1e-5)
def broadcasting_shapes(draw, batch_shape, param_names):
    """Draws a set of parameter batch shapes that broadcast to `batch_shape`.

  For each parameter we need to choose its batch rank, and whether or not each
  axis i is 1 or batch_shape[i]. This function chooses a set of shapes that
  have possibly mismatched ranks, and possibly broadcasting axes, with the
  promise that the broadcast of the set of all shapes matches `batch_shape`.

  Args:
    draw: Hypothesis sampler.
    batch_shape: `tf.TensorShape`, the target (fully-defined) batch shape.
    param_names: Iterable of `str`, the parameters whose batch shapes need
      determination.

  Returns:
    param_batch_shapes: `dict` of `str->tf.TensorShape` where the set of
        shapes broadcast to `batch_shape`. The shapes are fully defined.
  """
    # Draw order matters for Hypothesis reproducibility: shuffle the names
    # first, then draw the matching number of mutually-broadcasting shapes.
    shuffled_names = draw(hps.permutations(param_names))
    shapes = draw(
        tfp_test_util.broadcasting_shapes(batch_shape, len(shuffled_names)))
    return dict(zip(shuffled_names, shapes))
def broadcast_compatible_shape(draw, batch_shape):
    """Draws a shape which is broadcast-compatible with `batch_shape`."""
    # `broadcasting_shapes` draws a sequence of shapes whose last element
    # "completes" the broadcast to fill out batch_shape; draw a pair and keep
    # only the first (possibly incomplete) one.
    incomplete, _ = draw(tfp_test_util.broadcasting_shapes(batch_shape, 2))
    return incomplete