def _check_shapes(self):
    """Static check that shapes are compatible."""
    # Broadcast shape also checks that u and v are compatible.
    uv_shape = array_ops.broadcast_static_shape(
        self.u.get_shape(), self.v.get_shape())

    batch_shape = array_ops.broadcast_static_shape(
        self.base_operator.batch_shape, uv_shape[:-2])

    self.base_operator.domain_dimension.assert_is_compatible_with(
        uv_shape[-2])

    if self._diag_update is not None:
      uv_shape[-1].assert_is_compatible_with(self._diag_update.get_shape()[-1])
      array_ops.broadcast_static_shape(
          batch_shape, self._diag_update.get_shape()[:-1])
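
These static checks lean on `array_ops.broadcast_static_shape` failing at graph-construction time when shapes cannot broadcast. A minimal sketch of that behaviour with the public `tf.broadcast_static_shape` (TF 2.x API assumed; the shapes below are made up for illustration):

import tensorflow as tf

uv_shape = tf.TensorShape([7, 5, 3])   # stand-in for the broadcast shape of u and v
base_batch = tf.TensorShape([2, 1])    # stand-in for base_operator.batch_shape

# Compatible shapes broadcast statically; no tensors are evaluated.
print(tf.broadcast_static_shape(base_batch, uv_shape[:-2]))  # (2, 7)

# Incompatible shapes raise immediately, which is what makes the check "static".
try:
  tf.broadcast_static_shape(tf.TensorShape([3]), tf.TensorShape([4]))
except ValueError as e:
  print("incompatible:", e)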
Example 2
  def _reduce_jacobian_det_over_event(
      self, y, ildj, min_event_ndims, event_ndims):
    """Reduce jacobian over event_ndims - min_event_ndims."""
    if not self.is_constant_jacobian:
      return math_ops.reduce_sum(
          ildj,
          self._get_event_reduce_dims(min_event_ndims, event_ndims))

    # In this case, we need to tile the jacobian over the event and reduce.
    y_rank = array_ops.rank(y)
    y_shape = array_ops.shape(y)[
        y_rank - event_ndims : y_rank - min_event_ndims]

    ones = array_ops.ones(y_shape, ildj.dtype)
    reduced_ildj = math_ops.reduce_sum(
        ones * ildj,
        axis=self._get_event_reduce_dims(min_event_ndims, event_ndims))
    # The multiplication by ones can change the inferred static shape so we try
    # to recover as much as possible.
    if (isinstance(event_ndims, int) and
        y.get_shape().ndims and ildj.get_shape().ndims):
      y_shape = y.get_shape()
      y_shape = y_shape[y_shape.ndims - event_ndims :
                        y_shape.ndims - min_event_ndims]
      ildj_shape = ildj.get_shape()
      broadcast_shape = array_ops.broadcast_static_shape(
          ildj_shape, y_shape)
      reduced_ildj.set_shape(
          broadcast_shape[: broadcast_shape.ndims - (
              event_ndims - min_event_ndims)])

    return reduced_ildj
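
When `is_constant_jacobian` holds, `ildj` carries one value per batch member, so multiplying by `ones` over the extra event dimensions and summing is just scaling by the number of reduced elements. A small NumPy sketch of that identity (the names below are illustrative, not taken from the method above):

import numpy as np

ildj = np.float32(-1.7)        # constant inverse log-det-Jacobian per example
extra_event_shape = (3, 4)     # dims between min_event_ndims and event_ndims

tiled = np.ones(extra_event_shape, np.float32) * ildj
reduced = tiled.sum()          # what the reduce_sum over event dims computes

# Equivalent to scaling by the number of reduced event elements.
assert np.isclose(reduced, ildj * np.prod(extra_event_shape))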
  def benchmarkBatchMatMulBroadcast(self):
    for (a_shape, b_shape) in self.shape_pairs:
      with compat.forward_compatibility_horizon(2019, 4, 26):
        with ops.Graph().as_default(), \
            session.Session(config=benchmark.benchmark_config()) as sess, \
            ops.device("/cpu:0"):
          matrix_a = variables.Variable(
              GetRandomNormalInput(a_shape, np.float32))
          matrix_b = variables.Variable(
              GetRandomNormalInput(b_shape, np.float32))
          variables.global_variables_initializer().run()

          # Use batch matmul op's internal broadcasting.
          self.run_op_benchmark(
              sess,
              math_ops.matmul(matrix_a, matrix_b),
              min_iters=50,
              name="batch_matmul_cpu_{}_{}".format(a_shape, b_shape))

          # Manually broadcast the input matrices using the broadcast_to op.
          broadcasted_batch_shape = array_ops.broadcast_static_shape(
              matrix_a.shape[:-2], matrix_b.shape[:-2])
          broadcasted_a_shape = broadcasted_batch_shape.concatenate(
              matrix_a.shape[-2:])
          broadcasted_b_shape = broadcasted_batch_shape.concatenate(
              matrix_b.shape[-2:])
          self.run_op_benchmark(
              sess,
              math_ops.matmul(
                  array_ops.broadcast_to(matrix_a, broadcasted_a_shape),
                  array_ops.broadcast_to(matrix_b, broadcasted_b_shape)),
              min_iters=50,
              name="batch_matmul_manual_broadcast_cpu_{}_{}".format(
                  a_shape, b_shape))
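
The two benchmarked paths compute the same product; a quick eager-mode sanity check with the public TF 2.x API (shapes chosen arbitrarily):

import tensorflow as tf

a = tf.random.normal([2, 3, 1, 4, 5])
b = tf.random.normal([1, 3, 2, 5, 6])

# Path 1: let batch matmul broadcast the leading dims internally.
implicit = tf.matmul(a, b)

# Path 2: broadcast the batch dims by hand, then matmul.
batch = tf.broadcast_static_shape(a.shape[:-2], b.shape[:-2])   # (2, 3, 2)
explicit = tf.matmul(
    tf.broadcast_to(a, batch.concatenate(a.shape[-2:])),
    tf.broadcast_to(b, batch.concatenate(b.shape[-2:])))

print(implicit.shape)                                      # (2, 3, 2, 4, 6)
print(tf.reduce_max(tf.abs(implicit - explicit)).numpy())  # expect 0.0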
  def _possibly_broadcast_batch_shape(self, x):
    """Return 'x', possibly after broadcasting the leading dimensions."""
    # If we have no batch shape, our batch shape broadcasts with everything!
    if self._batch_shape_arg is None:
      return x

    # Static attempt:
    #   If we determine that no broadcast is necessary, pass x through
    #   If we need a broadcast, add to an array of zeros.
    #
    # special_shape is the shape that, when broadcast with x's shape, will give
    # the correct broadcast_shape.  Note that
    #   We have already verified the second to last dimension of self.shape
    #   matches x's shape in assert_compatible_matrix_dimensions.
    #   Also, the final dimension of 'x' can have any shape.
    #   Therefore, the final two dimensions of special_shape are 1's.
    special_shape = self.batch_shape.concatenate([1, 1])
    bshape = array_ops.broadcast_static_shape(x.get_shape(), special_shape)
    if special_shape.is_fully_defined():
      # bshape.is_fully_defined iff special_shape.is_fully_defined.
      if bshape == x.get_shape():
        return x
      # Use the built in broadcasting of addition.
      zeros = array_ops.zeros(shape=special_shape, dtype=self.dtype)
      return x + zeros

    # Dynamic broadcast:
    #   Always add to an array of zeros, rather than using a "cond", since a
    #   cond would require copying data from GPU --> CPU.
    special_shape = array_ops.concat((self.batch_shape_dynamic(), [1, 1]), 0)
    zeros = array_ops.zeros(shape=special_shape, dtype=self.dtype)
    return x + zeros
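
The "add zeros" trick above broadcasts the leading dimensions without a `cond`, and hence without copying data back to the host. A standalone sketch of the same idea using public TF 2.x ops (the shapes are hypothetical):

import tensorflow as tf

x = tf.random.normal([4, 3])                  # a single matrix, no batch dims
batch_shape = tf.constant([2, 5], tf.int32)   # batch shape known only at runtime

# The trailing [1, 1] leaves the matrix dims alone; only batch dims broadcast.
special_shape = tf.concat([batch_shape, [1, 1]], axis=0)   # [2, 5, 1, 1]
x_bc = x + tf.zeros(special_shape, dtype=x.dtype)

print(x_bc.shape)   # (2, 5, 4, 3)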
Example 5
  def _reduce_jacobian_det_over_event(
      self, y, ildj, min_event_ndims, event_ndims):
    """Reduce jacobian over event_ndims - min_event_ndims."""
    # In this case, we need to tile the Jacobian over the event and reduce.
    y_rank = array_ops.rank(y)
    y_shape = array_ops.shape(y)[
        y_rank - event_ndims : y_rank - min_event_ndims]

    ones = array_ops.ones(y_shape, ildj.dtype)
    reduced_ildj = math_ops.reduce_sum(
        ones * ildj,
        axis=self._get_event_reduce_dims(min_event_ndims, event_ndims))
    # The multiplication by ones can change the inferred static shape so we try
    # to recover as much as possible.
    event_ndims_ = self._maybe_get_static_event_ndims(event_ndims)
    if (event_ndims_ is not None and
        y.shape.ndims is not None and
        ildj.shape.ndims is not None):
      y_shape = y.shape[y.shape.ndims - event_ndims_ :
                        y.shape.ndims - min_event_ndims]
      broadcast_shape = array_ops.broadcast_static_shape(ildj.shape, y_shape)
      reduced_ildj.set_shape(
          broadcast_shape[: broadcast_shape.ndims - (
              event_ndims_ - min_event_ndims)])

    return reduced_ildj
Example 6
def _broadcast_shape(shape1, shape2):
  """Convenience function which statically broadcasts shape when possible."""
  if (tensor_util.constant_value(shape1) is not None and
      tensor_util.constant_value(shape2) is not None):
    return array_ops.broadcast_static_shape(
        tensor_shape.TensorShape(tensor_util.constant_value(shape1)),
        tensor_shape.TensorShape(tensor_util.constant_value(shape2)))
  return array_ops.broadcast_dynamic_shape(shape1, shape2)
def _static_check_for_broadcastable_batch_shape(operators):
  """ValueError if operators determined to have non-broadcastable shapes."""
  if len(operators) < 2:
    return

  # This will fail if they cannot be broadcast together.
  batch_shape = operators[0].batch_shape
  for op in operators[1:]:
    batch_shape = array_ops.broadcast_static_shape(batch_shape, op.batch_shape)
 def _finish_log_prob_for_one_fiber(self, y, x, ildj):
   """Finish computation of log_prob on one element of the inverse image."""
   x = self._maybe_rotate_dims(x, rotate_right=True)
   log_prob = self.distribution.log_prob(x)
   if self._is_maybe_event_override:
     log_prob = math_ops.reduce_sum(log_prob, self._reduce_event_indices)
   log_prob += math_ops.cast(ildj, log_prob.dtype)
   if self._is_maybe_event_override:
     log_prob.set_shape(array_ops.broadcast_static_shape(
         y.get_shape().with_rank_at_least(1)[:-1], self.batch_shape))
   return log_prob
 def _finish_prob_for_one_fiber(self, y, x, ildj, event_ndims):
   """Finish computation of prob on one element of the inverse image."""
   x = self._maybe_rotate_dims(x, rotate_right=True)
   prob = self.distribution.prob(x)
   if self._is_maybe_event_override:
     prob = math_ops.reduce_prod(prob, self._reduce_event_indices)
   prob *= math_ops.exp(math_ops.cast(ildj, prob.dtype))
   if self._is_maybe_event_override and isinstance(event_ndims, int):
     prob.set_shape(array_ops.broadcast_static_shape(
         y.get_shape().with_rank_at_least(1)[:-event_ndims], self.batch_shape))
   return prob
 def _prob(self, y):
   x, ildj = self.bijector.inverse_and_inverse_log_det_jacobian(y)
   x = self._maybe_rotate_dims(x, rotate_right=True)
   prob = self.distribution.prob(x)
   if self._is_maybe_event_override:
     prob = math_ops.reduce_prod(prob, self._reduce_event_indices)
   prob *= math_ops.exp(ildj)
   if self._is_maybe_event_override:
     prob.set_shape(array_ops.broadcast_static_shape(
         y.get_shape().with_rank_at_least(1)[:-1], self.batch_shape))
   return prob
 def _log_prob(self, y):
   x = self.bijector.inverse(y)
   ildj = self.bijector.inverse_log_det_jacobian(y)
   x = self._maybe_rotate_dims(x, rotate_right=True)
   log_prob = self.distribution.log_prob(x)
   if self._is_maybe_event_override:
     log_prob = math_ops.reduce_sum(log_prob, self._reduce_event_indices)
   log_prob = ildj + log_prob
   if self._is_maybe_event_override:
     log_prob.set_shape(array_ops.broadcast_static_shape(
         y.get_shape().with_rank_at_least(1)[:-1], self.batch_shape))
   return log_prob
 def _log_prob(self, y):
   # For caching to work, it is imperative that the bijector is the first to
   # modify the input.
   x = self.bijector.inverse(y)
   ildj = self.bijector.inverse_log_det_jacobian(y)
   x = self._maybe_rotate_dims(x, rotate_right=True)
   log_prob = self.distribution.log_prob(x)
   if self._is_maybe_event_override:
     log_prob = math_ops.reduce_sum(log_prob, self._reduce_event_indices)
   log_prob = ildj + log_prob
   if self._is_maybe_event_override:
     log_prob.set_shape(array_ops.broadcast_static_shape(
         y.get_shape().with_rank_at_least(1)[:-1], self.batch_shape))
   return log_prob
Example 13
def determine_batch_event_shapes(grid, endpoint_affine):
  """Helper to infer batch_shape and event_shape."""
  with ops.name_scope(name="determine_batch_event_shapes"):
    # grid  # shape: [B, k, q]
    # endpoint_affine     # len=k, shape: [B, d, d]
    batch_shape = grid.shape[:-2]
    batch_shape_tensor = array_ops.shape(grid)[:-2]
    event_shape = None
    event_shape_tensor = None

    def _set_event_shape(shape, shape_tensor):
      if event_shape is None:
        return shape, shape_tensor
      return (array_ops.broadcast_static_shape(event_shape, shape),
              array_ops.broadcast_dynamic_shape(
                  event_shape_tensor, shape_tensor))

    for aff in endpoint_affine:
      if aff.shift is not None:
        batch_shape = array_ops.broadcast_static_shape(
            batch_shape, aff.shift.shape[:-1])
        batch_shape_tensor = array_ops.broadcast_dynamic_shape(
            batch_shape_tensor, array_ops.shape(aff.shift)[:-1])
        event_shape, event_shape_tensor = _set_event_shape(
            aff.shift.shape[-1:], array_ops.shape(aff.shift)[-1:])

      if aff.scale is not None:
        batch_shape = array_ops.broadcast_static_shape(
            batch_shape, aff.scale.batch_shape)
        batch_shape_tensor = array_ops.broadcast_dynamic_shape(
            batch_shape_tensor, aff.scale.batch_shape_tensor())
        event_shape, event_shape_tensor = _set_event_shape(
            tensor_shape.TensorShape([aff.scale.range_dimension]),
            aff.scale.range_dimension_tensor()[array_ops.newaxis])

    return batch_shape, batch_shape_tensor, event_shape, event_shape_tensor
Example 14
def prefer_static_broadcast_shape(
    shape1, shape2, name="prefer_static_broadcast_shape"):
  """Convenience function which statically broadcasts shape when possible.

  Args:
    shape1:  `1-D` integer `Tensor`.  Already converted to tensor!
    shape2:  `1-D` integer `Tensor`.  Already converted to tensor!
    name:  A string name to prepend to created ops.

  Returns:
    The broadcast shape, either as `TensorShape` (if broadcast can be done
      statically), or as a `Tensor`.
  """
  with ops.name_scope(name, values=[shape1, shape2]):
    if (tensor_util.constant_value(shape1) is not None and
        tensor_util.constant_value(shape2) is not None):
      return array_ops.broadcast_static_shape(
          tensor_shape.TensorShape(tensor_util.constant_value(shape1)),
          tensor_shape.TensorShape(tensor_util.constant_value(shape2)))
    return array_ops.broadcast_dynamic_shape(shape1, shape2)
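
A usage sketch of the static-versus-dynamic dispatch described in the docstring, assuming the function above is defined alongside the TF-internal modules it references (`ops`, `tensor_util`, `tensor_shape`, `array_ops`) and is run in graph mode:

import tensorflow.compat.v1 as tf1
tf1.disable_eager_execution()

shape_a = tf1.constant([2, 1, 3])
shape_b = tf1.constant([5, 3])
# Both shape tensors have constant values, so a static TensorShape comes back.
print(prefer_static_broadcast_shape(shape_a, shape_b))   # (2, 5, 3)

# A shape that is only known when the graph runs falls back to the dynamic op.
shape_c = tf1.placeholder(tf1.int32, shape=[2])
print(prefer_static_broadcast_shape(shape_a, shape_c))   # an int32 Tensor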
Example 15
def prefer_static_broadcast_shape(
    shape1, shape2, name="prefer_static_broadcast_shape"):
  """Convenience function which statically broadcasts shape when possible.

  Args:
    shape1:  `1-D` integer `Tensor`.  Already converted to tensor!
    shape2:  `1-D` integer `Tensor`.  Already converted to tensor!
    name:  A string name to prepend to created ops.

  Returns:
    The broadcast shape, either as `TensorShape` (if broadcast can be done
      statically), or as a `Tensor`.
  """
  with ops.name_scope(name, values=[shape1, shape2]):
    def make_shape_tensor(x):
      return ops.convert_to_tensor(x, name="shape", dtype=dtypes.int32)

    def get_tensor_shape(s):
      if isinstance(s, tensor_shape.TensorShape):
        return s
      s_ = tensor_util.constant_value(make_shape_tensor(s))
      if s_ is not None:
        return tensor_shape.TensorShape(s_)
      return None

    def get_shape_tensor(s):
      if not isinstance(s, tensor_shape.TensorShape):
        return make_shape_tensor(s)
      if s.is_fully_defined():
        return make_shape_tensor(s.as_list())
      raise ValueError("Cannot broadcast from partially "
                       "defined `TensorShape`.")

    shape1_ = get_tensor_shape(shape1)
    shape2_ = get_tensor_shape(shape2)
    if shape1_ is not None and shape2_ is not None:
      return array_ops.broadcast_static_shape(shape1_, shape2_)

    shape1_ = get_shape_tensor(shape1)
    shape2_ = get_shape_tensor(shape2)
    return array_ops.broadcast_dynamic_shape(shape1_, shape2_)
Example 16
def get_broadcast_shape(*tensors):
  """Get broadcast shape as a Python list of integers (preferred) or `Tensor`.

  Args:
    *tensors:  One or more `Tensor` objects (already converted!).

  Returns:
    broadcast shape:  Python list (if shapes determined statically), otherwise
      an `int32` `Tensor`.
  """
  # Try static.
  s_shape = tensors[0].shape
  for t in tensors[1:]:
    s_shape = array_ops.broadcast_static_shape(s_shape, t.shape)
  if s_shape.is_fully_defined():
    return s_shape.as_list()

  # Fallback on dynamic.
  d_shape = array_ops.shape(tensors[0])
  for t in tensors[1:]:
    d_shape = array_ops.broadcast_dynamic_shape(d_shape, array_ops.shape(t))
  return d_shape
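
A usage sketch for the helper above (again assuming it is defined with its TF-internal imports and run eagerly); with fully defined shapes the static path returns a plain Python list:

import tensorflow as tf

x = tf.ones([2, 1, 3])
y = tf.ones([5, 1])

print(get_broadcast_shape(x, y))   # [2, 5, 3], a Python list

# With partially unknown shapes (e.g. inside a tf.function traced with
# unspecified dimensions) the static shape is not fully defined and an
# int32 shape Tensor is returned instead.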
Example 17
 def _batch_shape(self):
   return array_ops.broadcast_static_shape(
       self.distribution.batch_shape,
       self.mixture_distribution.logits.shape)[:-1]
Example 18
 def _batch_shape(self):
     return array_ops.broadcast_static_shape(self.low.get_shape(),
                                             self.high.get_shape())
Example 19
 def _get_batch_shape(self):
   return array_ops.broadcast_static_shape(
       array_ops.broadcast_static_shape(
           self.df.get_shape(),
           self.mu.get_shape()),
       self.sigma.get_shape())
Example 20
 def _get_batch_shape(self):
     return array_ops.broadcast_static_shape(
         array_ops.broadcast_static_shape(self.df.get_shape(),
                                          self.mu.get_shape()),
         self.sigma.get_shape())
 def _batch_shape(self):
   return array_ops.broadcast_static_shape(
       self.total_count.get_shape(),
       self.probs.get_shape())
def broadcast_matrix_batch_dims(batch_matrices, name=None):
  """Broadcast leading dimensions of zero or more [batch] matrices.

  Example broadcasting one batch dim of two simple matrices.

  ```python
  x = [[1, 2],
       [3, 4]]  # Shape [2, 2], no batch dims

  y = [[[1]]]   # Shape [1, 1, 1], 1 batch dim of shape [1]

  x_bc, y_bc = broadcast_matrix_batch_dims([x, y])

  x_bc
  ==> [[[1, 2],
        [3, 4]]]  # Shape [1, 2, 2], 1 batch dim of shape [1].

  y_bc
  ==> same as y
  ```

  Example broadcasting many batch dims

  ```python
  x = tf.random.normal(shape=(2, 3, 1, 4, 4))
  y = tf.random.normal(shape=(1, 3, 2, 5, 5))
  x_bc, y_bc = broadcast_matrix_batch_dims([x, y])

  x_bc.shape
  ==> (2, 3, 2, 4, 4)

  y_bc.shape
  ==> (2, 3, 2, 5, 5)
  ```

  Args:
    batch_matrices:  Iterable of `Tensor`s, each having two or more dimensions.
    name:  A string name to prepend to created ops.

  Returns:
    bcast_matrices: List of `Tensor`s, with `bcast_matrices[i]` containing
      the values from `batch_matrices[i]`, with possibly broadcast batch dims.

  Raises:
    ValueError:  If any input `Tensor` is statically determined to have less
      than two dimensions.
  """
  with ops.name_scope(name or "broadcast_matrix_batch_dims",
                      values=batch_matrices):
    check_ops.assert_proper_iterable(batch_matrices)
    batch_matrices = list(batch_matrices)

    for i, mat in enumerate(batch_matrices):
      batch_matrices[i] = ops.convert_to_tensor(mat)
      assert_is_batch_matrix(batch_matrices[i])

    if len(batch_matrices) < 2:
      return batch_matrices

    # Try static broadcasting.
    # bcast_batch_shape is the broadcast batch shape of ALL matrices.
    # E.g. if batch_matrices = [x, y], with
    # x.shape =    [2, j, k]  (batch shape =    [2])
    # y.shape = [3, 1, l, m]  (batch shape = [3, 1])
    # ==> bcast_batch_shape = [3, 2]
    bcast_batch_shape = batch_matrices[0].shape[:-2]
    for mat in batch_matrices[1:]:
      bcast_batch_shape = array_ops.broadcast_static_shape(
          bcast_batch_shape, mat.shape[:-2])
    if bcast_batch_shape.is_fully_defined():
      for i, mat in enumerate(batch_matrices):
        if mat.shape[:-2] != bcast_batch_shape:
          bcast_shape = array_ops.concat(
              [bcast_batch_shape.as_list(), array_ops.shape(mat)[-2:]], axis=0)
          batch_matrices[i] = array_ops.broadcast_to(mat, bcast_shape)
      return batch_matrices

    # Since static didn't work, do dynamic, which always copies data.
    bcast_batch_shape = array_ops.shape(batch_matrices[0])[:-2]
    for mat in batch_matrices[1:]:
      bcast_batch_shape = array_ops.broadcast_dynamic_shape(
          bcast_batch_shape, array_ops.shape(mat)[:-2])
    for i, mat in enumerate(batch_matrices):
      batch_matrices[i] = array_ops.broadcast_to(
          mat,
          array_ops.concat(
              [bcast_batch_shape, array_ops.shape(mat)[-2:]], axis=0))

    return batch_matrices
Example 23
 def _batch_shape(self):
   return array_ops.broadcast_static_shape(
       self.concentration.get_shape(),
       self.rate.get_shape())
Example 24
def _kl_brute_force(a, b, name=None):
  """Batched KL divergence `KL(a || b)` for multivariate Normals.

  With `X`, `Y` both multivariate Normals in `R^k` with means `mu_a`, `mu_b` and
  covariance `C_a`, `C_b` respectively,

  ```
  KL(a || b) = 0.5 * ( L - k + T + Q ),
  L := Log[Det(C_b)] - Log[Det(C_a)]
  T := trace(C_b^{-1} C_a),
  Q := (mu_b - mu_a)^T C_b^{-1} (mu_b - mu_a),
  ```

  This `Op` computes the trace by solving `C_b^{-1} C_a`. Although efficient
  methods for solving systems with `C_b` may be available, a dense version of
  (the square root of) `C_a` is used, so performance is `O(B s k**2)` where `B`
  is the batch size, and `s` is the cost of solving `C_b x = y` for vectors `x`
  and `y`.

  Args:
    a: Instance of `MultivariateNormalLinearOperator`.
    b: Instance of `MultivariateNormalLinearOperator`.
    name: (optional) name to use for created ops. Default "kl_mvn".

  Returns:
    Batchwise `KL(a || b)`.
  """

  def squared_frobenius_norm(x):
    """Helper to make KL calculation slightly more readable."""
    # http://mathworld.wolfram.com/FrobeniusNorm.html
    # The gradient of KL[p,q] is not defined when p==q. The culprit is
    # linalg_ops.norm, i.e., we cannot use the commented out code.
    # return math_ops.square(linalg_ops.norm(x, ord="fro", axis=[-2, -1]))
    return math_ops.reduce_sum(math_ops.square(x), axis=[-2, -1])

  # TODO(b/35041439): See also b/35040945. Remove this function once LinOp
  # supports something like:
  #   A.inverse().solve(B).norm(order='fro', axis=[-1, -2])
  def is_diagonal(x):
    """Helper to identify if `LinearOperator` has only a diagonal component."""
    return (isinstance(x, linalg.LinearOperatorIdentity) or
            isinstance(x, linalg.LinearOperatorScaledIdentity) or
            isinstance(x, linalg.LinearOperatorDiag))

  with ops.name_scope(name, "kl_mvn", values=[a.loc, b.loc] +
                      a.scale.graph_parents + b.scale.graph_parents):
    # Calculation is based on:
    # http://stats.stackexchange.com/questions/60680/kl-divergence-between-two-multivariate-gaussians
    # and,
    # https://en.wikipedia.org/wiki/Matrix_norm#Frobenius_norm
    # i.e.,
    #   If Ca = AA', Cb = BB', then
    #   tr[inv(Cb) Ca] = tr[inv(B)' inv(B) A A']
    #                  = tr[inv(B) A A' inv(B)']
    #                  = tr[(inv(B) A) (inv(B) A)']
    #                  = sum_{ij} (inv(B) A)_{ij}**2
    #                  = ||inv(B) A||_F**2
    # where ||.||_F is the Frobenius norm and the second equality follows from
    # the cyclic permutation property.
    if is_diagonal(a.scale) and is_diagonal(b.scale):
      # Using `stddev` because it handles expansion of Identity cases.
      b_inv_a = (a.stddev() / b.stddev())[..., array_ops.newaxis]
    else:
      b_inv_a = b.scale.solve(a.scale.to_dense())
    kl_div = (b.scale.log_abs_determinant()
              - a.scale.log_abs_determinant()
              + 0.5 * (
                  - math_ops.cast(a.scale.domain_dimension_tensor(), a.dtype)
                  + squared_frobenius_norm(b_inv_a)
                  + squared_frobenius_norm(b.scale.solve(
                      (b.mean() - a.mean())[..., array_ops.newaxis]))))
    kl_div.set_shape(array_ops.broadcast_static_shape(
        a.batch_shape, b.batch_shape))
    return kl_div
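
For diagonal covariances the formula in the docstring reduces to the familiar closed form; a small NumPy check of the `0.5 * (L - k + T + Q)` decomposition (illustrative numbers only):

import numpy as np

k = 3
mu_a = np.array([0.0, 1.0, 2.0])
mu_b = np.array([1.0, 1.0, 0.0])
var_a = np.array([1.0, 2.0, 0.5])   # diagonal of C_a
var_b = np.array([2.0, 1.0, 1.0])   # diagonal of C_b

L = np.sum(np.log(var_b)) - np.sum(np.log(var_a))  # Log[Det(C_b)] - Log[Det(C_a)]
T = np.sum(var_a / var_b)                          # trace(C_b^{-1} C_a)
Q = np.sum((mu_b - mu_a) ** 2 / var_b)             # Mahalanobis term
print(0.5 * (L - k + T + Q))                       # KL(a || b)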
Example 25
 def _batch_shape(self):
   return array_ops.broadcast_static_shape(
       self.low.get_shape(),
       self.high.get_shape())
Example 26
 def _get_batch_shape(self):
   return array_ops.broadcast_static_shape(
       self.n.get_shape(), self.p.get_shape())
Example 27
 def _batch_shape(self):
   return array_ops.broadcast_static_shape(
       self.total_count.get_shape(),
       self.probs.get_shape())
 def _set_event_shape(shape, shape_tensor):
     if event_shape is None:
         return shape, shape_tensor
     return (array_ops.broadcast_static_shape(event_shape, shape),
             array_ops.broadcast_dynamic_shape(event_shape_tensor,
                                               shape_tensor))
Example 29
 def _get_batch_shape(self):
   return array_ops.broadcast_static_shape(
       self.alpha.get_shape(), self.beta.get_shape())
 def _shape(self):
     # If v_shape = [5, 3], we return [5, 3, 3].
     v_shape = array_ops.broadcast_static_shape(self.row.shape,
                                                self.col.shape)
     return v_shape.concatenate(v_shape[-1:])
def broadcast_matrix_batch_dims(batch_matrices, name=None):
  """Broadcast leading dimensions of zero or more [batch] matrices.

  Example broadcasting one batch dim of two simple matrices.

  ```python
  x = [[1, 2],
       [3, 4]]  # Shape [2, 2], no batch dims

  y = [[[1]]]   # Shape [1, 1, 1], 1 batch dim of shape [1]

  x_bc, y_bc = broadcast_matrix_batch_dims([x, y])

  x_bc
  ==> [[[1, 2],
        [3, 4]]]  # Shape [1, 2, 2], 1 batch dim of shape [1].

  y_bc
  ==> same as y
  ```

  Example broadcasting many batch dims

  ```python
  x = tf.random_normal(shape=(2, 3, 1, 4, 4))
  y = tf.random_normal(shape=(1, 3, 2, 5, 5))
  x_bc, y_bc = broadcast_matrix_batch_dims([x, y])

  x_bc.shape
  ==> (2, 3, 2, 4, 4)

  y_bc.shape
  ==> (2, 3, 2, 5, 5)
  ```

  Args:
    batch_matrices:  Iterable of `Tensor`s, each having two or more dimensions.
    name:  A string name to prepend to created ops.

  Returns:
    bcast_matrices: List of `Tensor`s, with `bcast_matrices[i]` containing
      the values from `batch_matrices[i]`, with possibly broadcast batch dims.

  Raises:
    ValueError:  If any input `Tensor` is statically determined to have less
      than two dimensions.
  """
  with ops.name_scope(
      name or "broadcast_matrix_batch_dims", values=batch_matrices):
    check_ops.assert_proper_iterable(batch_matrices)
    batch_matrices = list(batch_matrices)

    for i, mat in enumerate(batch_matrices):
      batch_matrices[i] = ops.convert_to_tensor(mat)
      assert_is_batch_matrix(batch_matrices[i])

    if len(batch_matrices) < 2:
      return batch_matrices

    # Try static broadcasting.
    # bcast_batch_shape is the broadcast batch shape of ALL matrices.
    # E.g. if batch_matrices = [x, y], with
    # x.shape =    [2, j, k]  (batch shape =    [2])
    # y.shape = [3, 1, l, m]  (batch shape = [3, 1])
    # ==> bcast_batch_shape = [3, 2]
    bcast_batch_shape = batch_matrices[0].get_shape()[:-2]
    for mat in batch_matrices[1:]:
      bcast_batch_shape = array_ops.broadcast_static_shape(
          bcast_batch_shape,
          mat.get_shape()[:-2])
    if bcast_batch_shape.is_fully_defined():
      # The [1, 1] at the end will broadcast with anything.
      bcast_shape = bcast_batch_shape.concatenate([1, 1])
      for i, mat in enumerate(batch_matrices):
        if mat.get_shape()[:-2] != bcast_batch_shape:
          batch_matrices[i] = _broadcast_to_shape(mat, bcast_shape)
      return batch_matrices

    # Since static didn't work, do dynamic, which always copies data.
    bcast_batch_shape = array_ops.shape(batch_matrices[0])[:-2]
    for mat in batch_matrices[1:]:
      bcast_batch_shape = array_ops.broadcast_dynamic_shape(
          bcast_batch_shape,
          array_ops.shape(mat)[:-2])
    bcast_shape = array_ops.concat([bcast_batch_shape, [1, 1]], axis=0)
    for i, mat in enumerate(batch_matrices):
      batch_matrices[i] = _broadcast_to_shape(mat, bcast_shape)

    return batch_matrices
Example 32
 def _set_event_shape(shape, shape_tensor):
   if event_shape is None:
     return shape, shape_tensor
   return (array_ops.broadcast_static_shape(event_shape, shape),
           array_ops.broadcast_dynamic_shape(
               event_shape_tensor, shape_tensor))
Example 33
 def _batch_shape(self):
     return array_ops.broadcast_static_shape(
         self.distribution.batch_shape,
         self.mixture_distribution.logits.shape)[:-1]
Example 34
 def _shape(self):
     batch_shape = array_ops.broadcast_static_shape(
         self.base_operator.batch_shape,
         self.u.get_shape()[:-2])
     return batch_shape.concatenate(self.base_operator.shape[-2:])
Example 35
 def _get_batch_shape(self):
     return array_ops.broadcast_static_shape(self.n.get_shape(),
                                             self.p.get_shape())
Example 36
 def _get_batch_shape(self):
     return array_ops.broadcast_static_shape(self._sigma.get_shape(),
                                             self._xi.get_shape())
 def _shape(self):
   batch_shape = array_ops.broadcast_static_shape(
       self.base_operator.batch_shape,
       self.u.get_shape()[:-2])
   return batch_shape.concatenate(self.base_operator.shape[-2:])
Example 38
 def _batch_shape(self):
     return array_ops.broadcast_static_shape(self.concentration.get_shape(),
                                             self.rate.get_shape())
Example 39
def _kl_brute_force(a, b, name=None):
  """Batched KL divergence `KL(a || b)` for multivariate Normals.

  With `X`, `Y` both multivariate Normals in `R^k` with means `mu_a`, `mu_b` and
  covariance `C_a`, `C_b` respectively,

  ```
  KL(a || b) = 0.5 * ( L - k + T + Q ),
  L := Log[Det(C_b)] - Log[Det(C_a)]
  T := trace(C_b^{-1} C_a),
  Q := (mu_b - mu_a)^T C_b^{-1} (mu_b - mu_a),
  ```

  This `Op` computes the trace by solving `C_b^{-1} C_a`. Although efficient
  methods for solving systems with `C_b` may be available, a dense version of
  (the square root of) `C_a` is used, so performance is `O(B s k**2)` where `B`
  is the batch size, and `s` is the cost of solving `C_b x = y` for vectors `x`
  and `y`.

  Args:
    a: Instance of `MultivariateNormalLinearOperator`.
    b: Instance of `MultivariateNormalLinearOperator`.
    name: (optional) name to use for created ops. Default "kl_mvn".

  Returns:
    Batchwise `KL(a || b)`.
  """

  def squared_frobenius_norm(x):
    """Helper to make KL calculation slightly more readable."""
    # http://mathworld.wolfram.com/FrobeniusNorm.html
    return math_ops.square(linalg_ops.norm(x, ord="fro", axis=[-2, -1]))

  # TODO(b/35041439): See also b/35040945. Remove this function once LinOp
  # supports something like:
  #   A.inverse().solve(B).norm(order='fro', axis=[-1, -2])
  def is_diagonal(x):
    """Helper to identify if `LinearOperator` has only a diagonal component."""
    return (isinstance(x, linalg.LinearOperatorIdentity) or
            isinstance(x, linalg.LinearOperatorScaledIdentity) or
            isinstance(x, linalg.LinearOperatorDiag))

  with ops.name_scope(name, "kl_mvn",
                      values=[a.loc, b.loc] + a.scale.graph_parents +
                      b.scale.graph_parents):
    # Calculation is based on:
    # http://stats.stackexchange.com/questions/60680/kl-divergence-between-two-multivariate-gaussians
    # and,
    # https://en.wikipedia.org/wiki/Matrix_norm#Frobenius_norm
    # i.e.,
    #   If Ca = AA', Cb = BB', then
    #   tr[inv(Cb) Ca] = tr[inv(B)' inv(B) A A']
    #                  = tr[inv(B) A A' inv(B)']
    #                  = tr[(inv(B) A) (inv(B) A)']
    #                  = sum_{ij} (inv(B) A)_{ij}**2
    #                  = ||inv(B) A||_F**2
    # where ||.||_F is the Frobenius norm and the second equality follows from
    # the cyclic permutation property.
    if is_diagonal(a.scale) and is_diagonal(b.scale):
      # Using `stddev` because it handles expansion of Identity cases.
      b_inv_a = (a.stddev() / b.stddev())[..., array_ops.newaxis]
    else:
      b_inv_a = b.scale.solve(a.scale.to_dense())
    kl_div = (b.scale.log_abs_determinant() -
              a.scale.log_abs_determinant() + 0.5 *
              (-math_ops.cast(a.scale.domain_dimension_tensor(), a.dtype) +
               squared_frobenius_norm(b_inv_a) + squared_frobenius_norm(
                   b.scale.solve(
                       (b.mean() - a.mean())[..., array_ops.newaxis]))))
    kl_div.set_shape(
        array_ops.broadcast_static_shape(a.batch_shape, b.batch_shape))
    return kl_div
Example 40
 def _get_batch_shape(self):
   return array_ops.broadcast_static_shape(
       self.alpha.get_shape(), self.beta.get_shape())
Example 41
 def _batch_shape(self):
   return array_ops.broadcast_static_shape(
       self.loc.get_shape(),
       self.scale.get_shape())
Example 42
 def _batch_shape(self):
     return array_ops.broadcast_static_shape(self.loc.get_shape(),
                                             self.scale.get_shape())