def _underdetermined(op, grad):
    """Gradients for the underdetermined case of MatrixSolveLs.

    This is the backprop for the solution to the normal equations of the second
    kind:
      X = F(A, B) = A^T * (A*A^T + lambda*I)^{-1} * B
    that (for lambda=0) solve the least squares problem
      min ||X||_F subject to A*X = B.
    """
    a = op.inputs[0]
    b = op.inputs[1]
    l2_regularizer = math_ops.cast(op.inputs[2], a.dtype.base_dtype)
    a_shape = array_ops.shape(a)
    batch_shape = a_shape[:-2]
    m = a_shape[-2]

    identity = linalg_ops.eye(m, batch_shape=batch_shape, dtype=a.dtype)
    gramian = math_ops.batch_matmul(
        a, a, adj_y=True) + l2_regularizer * identity
    chol = linalg_ops.cholesky(gramian)
    grad_b = linalg_ops.cholesky_solve(chol, math_ops.batch_matmul(a, grad))
    # Temporary z = (A * A^T + lambda * I)^{-1} * B.
    z = linalg_ops.cholesky_solve(chol, b)
    bz = -math_ops.batch_matmul(grad_b, z, adj_y=True)
    bz_sym = bz + array_ops.matrix_transpose(bz)
    grad_a = math_ops.batch_matmul(bz_sym, a) + math_ops.batch_matmul(
        z, grad, adj_y=True)
    return (grad_a, grad_b, None)
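
As a quick sanity check on the formula in the docstring, the sketch below verifies in plain NumPy that `A^T (A A^T)^{-1} B` is the minimum-norm solution of `A X = B` when `lambda = 0`. It is purely illustrative and not part of the op.

import numpy as np

rng = np.random.default_rng(0)
m, n, k = 3, 5, 2            # underdetermined: fewer equations (m) than unknowns (n)
a = rng.standard_normal((m, n))
b = rng.standard_normal((m, k))

# X = A^T (A A^T)^{-1} B, the minimum-norm solution of A X = B (lambda = 0).
x = a.T @ np.linalg.solve(a @ a.T, b)

assert np.allclose(a @ x, b)                   # feasibility: A X = B
assert np.allclose(x, np.linalg.pinv(a) @ b)   # pinv returns the same minimum-norm solution
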
Example #2
def _BatchMatrixInverseGrad(op, grad):
  """Gradient for BatchMatrixInverse."""
  ainv = op.outputs[0]
  return -math_ops.batch_matmul(
      ainv,
      math_ops.batch_matmul(grad, ainv, adj_y=True),
      adj_x=True)
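
For reference, the gradient above is the standard identity d(A^{-1}) = -A^{-1} dA A^{-1}, so with upstream gradient G the result is -A^{-T} G A^{-T}. A throwaway NumPy finite-difference check of that formula (illustrative only; the names here are ad hoc):

import numpy as np

rng = np.random.default_rng(0)
n = 4
a = rng.standard_normal((n, n)) + n * np.eye(n)   # well-conditioned test matrix
g = rng.standard_normal((n, n))                   # upstream gradient dL/d(A^{-1})

ainv = np.linalg.inv(a)
grad_a = -ainv.T @ g @ ainv.T                     # the formula implemented above

# Finite-difference check of dL/dA for L = sum(G * inv(A)).
eps = 1e-6
fd = np.zeros_like(a)
for i in range(n):
    for j in range(n):
        da = np.zeros_like(a)
        da[i, j] = eps
        fd[i, j] = (np.sum(g * np.linalg.inv(a + da)) -
                    np.sum(g * np.linalg.inv(a - da))) / (2 * eps)
assert np.allclose(grad_a, fd, atol=1e-4)
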
 def matmul(self, x, name='matmul'):
   """Left (batch) matrix multiplication of `x` by this operator."""
   chol = self._chol
   with ops.name_scope(self.name):
     with ops.op_scope(self.inputs, name):
       a_times_x = math_ops.batch_matmul(chol, x, adj_x=True)
       return math_ops.batch_matmul(chol, a_times_x)
  def _batch_sqrt_solve(self, rhs):
    # Recall the square root of this operator is M + VDV^T.
    # The Woodbury formula gives:
    # (M + VDV^T)^{-1}
    # = M^{-1} - M^{-1} V (D^{-1} + V^T M^{-1} V)^{-1} V^T M^{-1}
    # = M^{-1} - M^{-1} V C^{-1} V^T M^{-1}
    # where C is the capacitance matrix.
    m = self._operator
    v = self._v
    cchol = self._chol_capacitance(batch_mode=True)

    # The operators will use batch/singleton mode automatically.  We don't
    # override.
    # M^{-1} rhs
    minv_rhs = m.solve(rhs)
    # V^T M^{-1} rhs
    vt_minv_rhs = math_ops.batch_matmul(v, minv_rhs, adj_x=True)
    # C^{-1} V^T M^{-1} rhs
    cinv_vt_minv_rhs = linalg_ops.batch_cholesky_solve(cchol, vt_minv_rhs)
    # V C^{-1} V^T M^{-1} rhs
    v_cinv_vt_minv_rhs = math_ops.batch_matmul(v, cinv_vt_minv_rhs)
    # M^{-1} V C^{-1} V^T M^{-1} rhs
    minv_v_cinv_vt_minv_rhs = m.solve(v_cinv_vt_minv_rhs)

    # M^{-1} - M^{-1} V C^{-1} V^T M^{-1}
    return minv_rhs - minv_v_cinv_vt_minv_rhs
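
The Woodbury identity quoted in the comment above is easy to check numerically. A minimal NumPy sketch, not tied to the operator classes (all names below are invented for the illustration):

import numpy as np

rng = np.random.default_rng(0)
k, r = 5, 2
m = rng.standard_normal((k, k))
m = m @ m.T + k * np.eye(k)                       # positive definite M
v = rng.standard_normal((k, r))
d = np.diag(rng.uniform(1.0, 2.0, size=r))        # positive definite diagonal D

lhs = np.linalg.inv(m + v @ d @ v.T)
minv = np.linalg.inv(m)
capacitance = np.linalg.inv(d) + v.T @ minv @ v   # C, the "capacitance" matrix
rhs = minv - minv @ v @ np.linalg.solve(capacitance, v.T @ minv)
assert np.allclose(lhs, rhs)
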
Example #5
  def _overdetermined(op, grad):
    """Gradients for the overdetermined case of MatrixSolveLs.

    This is the backprop for the solution to the normal equations of the first
    kind:
       X = F(A, B) = (A^T * A + lambda * I)^{-1} * A^T * B
    which solve the least squares problem
       min ||A * X - B||_F^2 + lambda ||X||_F^2.
    """
    a = op.inputs[0]
    b = op.inputs[1]
    l2_regularizer = math_ops.cast(op.inputs[2], a.dtype.base_dtype)
    x = op.outputs[0]
    a_shape = array_ops.shape(a)
    batch_shape = a_shape[:-2]
    n = a_shape[-1]

    identity = linalg_ops.eye(n, batch_shape=batch_shape, dtype=a.dtype)
    gramian = math_ops.batch_matmul(
        a, a, adj_x=True) + l2_regularizer * identity
    chol = linalg_ops.cholesky(gramian)
    # Temporary z = (A^T * A + lambda * I)^{-1} * grad.
    z = linalg_ops.cholesky_solve(chol, grad)
    xzt = math_ops.batch_matmul(x, z, adj_y=True)
    zx_sym = xzt + array_ops.matrix_transpose(xzt)
    grad_a = -math_ops.batch_matmul(a, zx_sym) + math_ops.batch_matmul(
        b, z, adj_y=True)
    grad_b = math_ops.batch_matmul(a, z)
    return (grad_a, grad_b, None)
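
A quick NumPy sanity check on the forward formula in the docstring: `X = (A^T A + lambda I)^{-1} A^T B` satisfies the regularized normal equations `A^T (A X - B) + lambda X = 0`, i.e. the gradient of the objective vanishes. Purely illustrative:

import numpy as np

rng = np.random.default_rng(0)
m, n, k, lam = 8, 3, 2, 0.1       # overdetermined: more rows than columns
a = rng.standard_normal((m, n))
b = rng.standard_normal((m, k))

x = np.linalg.solve(a.T @ a + lam * np.eye(n), a.T @ b)

# At the minimizer, the gradient of ||A X - B||_F^2 + lam * ||X||_F^2 vanishes:
assert np.allclose(a.T @ (a @ x - b) + lam * x, 0.0, atol=1e-10)
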
 def matmul(self, x, name='matmul'):
     """Left (batch) matrix multiplication of `x` by this operator."""
     chol = self._chol
     with ops.name_scope(self.name):
         with ops.op_scope(self.inputs, name):
             a_times_x = math_ops.batch_matmul(chol, x, adj_x=True)
             return math_ops.batch_matmul(chol, a_times_x)
Example #7
def _BatchMatrixInverseGrad(op, grad):
  """Gradient for BatchMatrixInverse."""
  ainv = op.outputs[0]
  return -math_ops.batch_matmul(ainv,
                                math_ops.batch_matmul(grad,
                                                      ainv,
                                                      adj_y=True),
                                adj_x=True)
 def _batch_matmul(self, x, transpose_x=False):
     # tf.batch_matmul is defined x * y, so "y" is on the right, not "x".
     chol = array_ops.batch_matrix_band_part(self._chol, -1, 0)
     chol_times_x = math_ops.batch_matmul(chol,
                                          x,
                                          adj_x=True,
                                          adj_y=transpose_x)
     return math_ops.batch_matmul(chol, chol_times_x)
Example #9
def _BatchMatrixSolveGrad(op, grad):
    """Gradient for BatchMatrixSolve."""
    a = op.inputs[0]
    adjoint_a = op.get_attr("adjoint")
    c = op.outputs[0]
    grad_b = linalg_ops.batch_matrix_solve(a, grad, adjoint=not adjoint_a)
    if adjoint_a:
        grad_a = -math_ops.batch_matmul(c, grad_b, adj_y=True)
    else:
        grad_a = -math_ops.batch_matmul(grad_b, c, adj_y=True)
    return (grad_a, grad_b)
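
For `C = A^{-1} B` these gradients follow from `dC = -A^{-1} dA C + A^{-1} dB`, giving `grad_B = A^{-T} G` and `grad_A = -grad_B C^T` (the adjoint branch swaps the roles accordingly). A throwaway NumPy finite-difference check of the non-adjoint case, for illustration only:

import numpy as np

rng = np.random.default_rng(0)
n, k = 4, 3
a = rng.standard_normal((n, n)) + n * np.eye(n)
b = rng.standard_normal((n, k))
g = rng.standard_normal((n, k))          # upstream gradient dL/dC

c = np.linalg.solve(a, b)
grad_b = np.linalg.solve(a.T, g)         # A^{-T} G
grad_a = -grad_b @ c.T                   # -grad_B C^T

eps = 1e-6
fd = np.zeros_like(a)
for i in range(n):
    for j in range(n):
        da = np.zeros_like(a)
        da[i, j] = eps
        fd[i, j] = (np.sum(g * np.linalg.solve(a + da, b)) -
                    np.sum(g * np.linalg.solve(a - da, b))) / (2 * eps)
assert np.allclose(grad_a, fd, atol=1e-4)
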
Example #10
def _BatchMatrixSolveGrad(op, grad):
  """Gradient for BatchMatrixSolve."""
  a = op.inputs[0]
  c = op.outputs[0]
  # TODO(rmlarsen): Replace the following two lines with
  # a single call to batch_matrix_solve after adding
  # in an option to solve for A^T X = Y.
  ainv = linalg_ops.batch_matrix_inverse(a)
  grad_b = math_ops.batch_matmul(ainv, grad, adj_x=True)
  grad_a = -math_ops.batch_matmul(grad_b, c, adj_y=True)
  return (grad_a, grad_b)
Example #11
def _BatchMatrixSolveGrad(op, grad):
    """Gradient for BatchMatrixSolve."""
    a = op.inputs[0]
    c = op.outputs[0]
    # TODO(rmlarsen): Replace the following two lines with
    # a single call to batch_matrix_solve after adding
    # in an option to solve for A^T X = Y.
    ainv = linalg_ops.batch_matrix_inverse(a)
    grad_b = math_ops.batch_matmul(ainv, grad, adj_x=True)
    grad_a = -math_ops.batch_matmul(grad_b, c, adj_y=True)
    return (grad_a, grad_b)
Example #12
def _MatrixSolveGrad(op, grad):
    """Gradient for MatrixSolve."""
    a = op.inputs[0]
    adjoint_a = op.get_attr("adjoint")
    c = op.outputs[0]
    grad_b = linalg_ops.matrix_solve(a, grad, adjoint=not adjoint_a)
    if adjoint_a:
        grad_a = -math_ops.batch_matmul(c, grad_b, adj_y=True)
    else:
        grad_a = -math_ops.batch_matmul(grad_b, c, adj_y=True)
    return (grad_a, grad_b)
Example #13
def _test1(op, grad_e, grad_v):
    """Gradient for SelfAdjointEigV2 derived with Joan, with no adjustment for the subspace."""
    e = op.outputs[0]
    v = op.outputs[1]
    # dim = v.get_shape()
    with ops.control_dependencies([grad_e.op, grad_v.op]):
        if grad_v is not None:
            E = array_ops.diag(e)
            v_proj = array_ops.slice(v, [0, 0], [20, 2])
            grad_grassman = grad_v - math_ops.batch_matmul(
                math_ops.batch_matmul(v_proj, array_ops.transpose(v_proj)), grad_v)
            grad_a = (
                math_ops.batch_matmul(
                    grad_grassman,
                    math_ops.batch_matmul(E, array_ops.transpose(grad_v))) +
                math_ops.batch_matmul(
                    grad_v,
                    math_ops.batch_matmul(E, array_ops.transpose(grad_grassman))))
    return grad_a
  def _batch_sqrt_matmul(self, x, transpose_x=False):
    v = self._v
    m = self._operator
    d = self._diag_operator
    # The operators call the appropriate matmul/batch_matmul automatically.  We
    # cannot override.
    # batch_matmul is defined as:  x * y, so adj_x and adj_y are the ways to
    # transpose the left and right.
    mx = m.matmul(x, transpose_x=transpose_x)
    vt_x = math_ops.batch_matmul(v, x, adj_x=True, adj_y=transpose_x)
    d_vt_x = d.matmul(vt_x)
    v_d_vt_x = math_ops.batch_matmul(v, d_vt_x)

    return mx + v_d_vt_x
Example #15
 def _variance(self):
   p = self.p * array_ops.expand_dims(array_ops.ones_like(self.n), -1)
   outer_prod = math_ops.batch_matmul(
       array_ops.expand_dims(self._mean_val, -1),
       array_ops.expand_dims(p, -2))
   return array_ops.batch_matrix_set_diag(
       -outer_prod, self._mean_val - self._mean_val * p)
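
The construction above is the Multinomial covariance `n * (diag(p) - p p^T)`: the outer product supplies the `-n p_i p_j` off-diagonal entries and the diagonal is overwritten with `n p_i (1 - p_i)`. A small NumPy sketch of the same matrix, compared against an empirical estimate (illustrative only):

import numpy as np

rng = np.random.default_rng(0)
n_trials = 10
p = np.array([0.2, 0.3, 0.5])

mean = n_trials * p
cov = -np.outer(mean, p)                 # off-diagonal: -n * p_i * p_j
np.fill_diagonal(cov, mean - mean * p)   # diagonal: n * p_i * (1 - p_i)

samples = rng.multinomial(n_trials, p, size=200_000)
assert np.allclose(cov, np.cov(samples, rowvar=False), atol=0.05)
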
Example #16
    def __call__(self, inputs, state, scope=None):
        state, fw = state
        with vs.variable_scope(scope or type(self).__name__) as scope:
            """Wh(t) + Cx(t)"""
            linear = self.fw_calc([state, inputs], self._hidden_units, False)
            """h_0(t+1) = f(Wh(t) + Cx(t))"""
            if not self._norm_re:
                h = self._activation(self._norm(linear, scope="Norm0"))
            else:
                h = self._activation(self._norm(linear))
            h = self._vec2mat(h)
            linear = self._vec2mat(linear)
            for i in range(self._S):
                """
        h_{s+1}(t+1) = f([Wh(t) + Cx(t)] + A(t) h_s(t+1)), S times.
        From Eqn (2).
        """
                if not self._norm_re:
                    h = self._activation(
                        self._norm(linear + tf.matmul(fw, h),
                                   scope="Norm%d" % (i + 1)))
                else:
                    h = self._activation(
                        self._norm(linear + math_ops.batch_matmul(fw, h)))
            """
      Compute A(t+1)  according to Eqn (4)
      """
            state = self._vec2mat(state)
            new_fw = self._lambda * fw + self._eta * tf.matmul(
                state, state, adjoint_b=True)

            h = self._mat2vec(h)

            return h, (h, new_fw)
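
The cell above follows the fast-weights recurrence referenced in its comments: an inner loop `h_{s+1} = f([Wh + Cx] + A(t) h_s)` (Eqn 2) and the Hebbian update `A(t+1) = lambda * A(t) + eta * h h^T` (Eqn 4). Below is a bare-bones NumPy sketch of those two updates for a single example, stripped of the layer-norm and variable-scope machinery; the function name and defaults are invented for the illustration.

import numpy as np

def fast_weights_step(prev_h, linear, fast_a, lam=0.95, eta=0.5, s=1, f=np.tanh):
    """One cell step for a single example: inner loop (Eqn 2) and fast-weight update (Eqn 4)."""
    h = f(linear)                                  # h_0(t+1) = f(Wh(t) + Cx(t))
    for _ in range(s):
        h = f(linear + fast_a @ h)                 # h_{s+1}(t+1) = f([Wh + Cx] + A(t) h_s(t+1))
    fast_a = lam * fast_a + eta * np.outer(prev_h, prev_h)   # A(t+1) = lambda * A(t) + eta * h h^T
    return h, fast_a

rng = np.random.default_rng(0)
d = 8
h, a = fast_weights_step(rng.standard_normal(d), rng.standard_normal(d), np.zeros((d, d)))
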
Example #17
    def variance(self, name="variance"):
        """Variance of the Wishart distribution.

    This function should not be confused with the covariance of the Wishart. The
    covariance matrix would have shape `q x q` where
    `q = dimension * (dimension+1) / 2`, with elements corresponding to some
    mapping from a lower-triangular matrix to a vector space.

    This function returns the diagonal of the covariance matrix but shaped
    as a `dimension x dimension` matrix.

    Args:
      name: The name of this op.

    Returns:
      variance: `Tensor` of dtype `self.dtype`.
    """
        with ops.name_scope(self.name):
            with ops.name_scope(name, values=list(self.inputs.values())):
                x = math_ops.sqrt(self.df) * self.scale_operator_pd.to_dense()
                d = array_ops.expand_dims(array_ops.batch_matrix_diag_part(x), -1)
                v = math_ops.square(x) + math_ops.batch_matmul(d, d, adj_y=True)
                if self.cholesky_input_output_matrices:
                    return linalg_ops.batch_cholesky(v)
                else:
                    return v
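
The expression above encodes the element-wise Wishart variance `Var(W_ij) = df * (S_ij^2 + S_ii * S_jj)` for scale matrix `S`: squaring `sqrt(df) * S` gives the first term and the outer product of its diagonal gives the second. A Monte Carlo spot check with SciPy, illustrative only and with a deliberately loose tolerance:

import numpy as np
from scipy.stats import wishart

df = 7.0
scale = np.array([[2.0, 0.5],
                  [0.5, 1.0]])

# Closed form, mirroring the construction above: df * (S_ij^2 + S_ii * S_jj).
x = np.sqrt(df) * scale
d = np.diag(x)[:, None]
var = x ** 2 + d @ d.T

samples = wishart(df=df, scale=scale).rvs(size=200_000, random_state=0)
assert np.allclose(var, samples.var(axis=0), rtol=0.05)
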
Example #18
  def variance(self, name='variance'):
    """Variance of the Wishart distribution.

    This function should not be confused with the covariance of the Wishart. The
    covariance matrix would have shape `q x q` where
    `q = dimension * (dimension+1) / 2`, with elements corresponding to some
    mapping from a lower-triangular matrix to a vector space.

    This function returns the diagonal of the covariance matrix but shaped
    as a `dimension x dimension` matrix.

    Args:
      name: The name of this op.

    Returns:
      variance: `Tensor` of dtype `self.dtype`.
    """
    with ops.name_scope(self.name):
      with ops.name_scope(name, values=list(self.inputs.values())):
        x = math_ops.sqrt(self.df) * self.scale_operator_pd.to_dense()
        d = array_ops.expand_dims(array_ops.batch_matrix_diag_part(x), -1)
        v = math_ops.square(x) + math_ops.batch_matmul(d, d, adj_y=True)
        if self.cholesky_input_output_matrices:
          return linalg_ops.batch_cholesky(v)
        else:
          return v
Example #19
def _BatchMatrixSolveGrad(op, grad):
  """Gradient for BatchMatrixSolve."""
  a = op.inputs[0]
  c = op.outputs[0]
  grad_b = linalg_ops.batch_matrix_solve(a, grad, adjoint=True)
  grad_a = -math_ops.batch_matmul(grad_b, c, adj_y=True)
  return (grad_a, grad_b)
Example #20
def _MatrixTriangularSolveGrad(op, grad):
    """Gradient for MatrixTriangularSolve."""
    a = op.inputs[0]
    adjoint_a = op.get_attr("adjoint")
    lower_a = op.get_attr("lower")
    c = op.outputs[0]
    grad_b = linalg_ops.matrix_triangular_solve(a, grad, lower=lower_a, adjoint=not adjoint_a)
    if adjoint_a:
        grad_a = -math_ops.batch_matmul(c, grad_b, adj_y=True)
    else:
        grad_a = -math_ops.batch_matmul(grad_b, c, adj_y=True)
    if lower_a:
        grad_a = array_ops.matrix_band_part(grad_a, -1, 0)
    else:
        grad_a = array_ops.matrix_band_part(grad_a, 0, -1)
    return (grad_a, grad_b)
Example #21
 def _forward(self, x):
     x, sample_shape = self.shaper.make_batch_of_event_sample_matrices(x)
     x = math_ops.batch_matmul(self.scale, x)
     x = self.shaper.undo_make_batch_of_event_sample_matrices(
         x, sample_shape)
     x += self.loc
     return x
Example #22
 def _variance(self):
     x = math_ops.sqrt(self.df) * self.scale_operator_pd.to_dense()
     d = array_ops.expand_dims(array_ops.matrix_diag_part(x), -1)
     v = math_ops.square(x) + math_ops.batch_matmul(d, d, adj_y=True)
     if self.cholesky_input_output_matrices:
         return linalg_ops.cholesky(v)
     return v
Example #23
 def _variance(self):
     scale = self.alpha_sum * math_ops.sqrt(1.0 + self.alpha_sum)
     alpha = self.alpha / scale
      outer_prod = -math_ops.batch_matmul(
          array_ops.expand_dims(alpha, dim=-1),  # column
          array_ops.expand_dims(alpha, dim=-2))  # row
      return array_ops.batch_matrix_set_diag(
          outer_prod, alpha * (self.alpha_sum / scale - alpha))
Example #24
 def _variance(self):
   x = math_ops.sqrt(self.df) * self.scale_operator_pd.to_dense()
   d = array_ops.expand_dims(array_ops.matrix_diag_part(x), -1)
   v = math_ops.square(x) + math_ops.batch_matmul(d, d, adj_y=True)
   if self.cholesky_input_output_matrices:
     return linalg_ops.cholesky(v)
   return v
Example #25
def _MatrixTriangularSolveGrad(op, grad):
  """Gradient for MatrixTriangularSolve."""
  a = op.inputs[0]
  adjoint_a = op.get_attr("adjoint")
  lower_a = op.get_attr("lower")
  c = op.outputs[0]
  grad_b = linalg_ops.matrix_triangular_solve(
      a, grad, lower=lower_a, adjoint=not adjoint_a)
  if adjoint_a:
    grad_a = -math_ops.batch_matmul(c, grad_b, adj_y=True)
  else:
    grad_a = -math_ops.batch_matmul(grad_b, c, adj_y=True)
  if lower_a:
    grad_a = array_ops.matrix_band_part(grad_a, -1, 0)
  else:
    grad_a = array_ops.matrix_band_part(grad_a, 0, -1)
  return (grad_a, grad_b)
Example #26
 def _variance(self):
     scale = self.alpha_sum * math_ops.sqrt(1. + self.alpha_sum)
     alpha = self.alpha / scale
     outer_prod = -math_ops.batch_matmul(
         array_ops.expand_dims(alpha, dim=-1),  # column
         array_ops.expand_dims(alpha, dim=-2))  # row
     return array_ops.batch_matrix_set_diag(
         outer_prod, alpha * (self.alpha_sum / scale - alpha))
Example #27
  def _sample_n(self, n, seed):
    batch_shape = self.batch_shape()
    event_shape = self.event_shape()
    batch_ndims = array_ops.shape(batch_shape)[0]

    ndims = batch_ndims + 3  # sample_ndims=1, event_ndims=2
    shape = array_ops.concat(0, ((n,), batch_shape, event_shape))

    # Complexity: O(nbk^2)
    x = random_ops.random_normal(shape=shape,
                                 mean=0.,
                                 stddev=1.,
                                 dtype=self.dtype,
                                 seed=seed)

    # Complexity: O(nbk)
    # This parametrization is equivalent to Chi2, i.e.,
    # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2)
    g = random_ops.random_gamma(shape=(n,),
                                alpha=self._multi_gamma_sequence(
                                    0.5 * self.df, self.dimension),
                                beta=0.5,
                                dtype=self.dtype,
                                seed=distribution_util.gen_new_seed(
                                    seed, "wishart"))

    # Complexity: O(nbk^2)
    x = array_ops.matrix_band_part(x, -1, 0)  # Tri-lower.

    # Complexity: O(nbk)
    x = array_ops.matrix_set_diag(x, math_ops.sqrt(g))

    # Make batch-op ready.
    # Complexity: O(nbk^2)
    perm = array_ops.concat(0, (math_ops.range(1, ndims), (0,)))
    x = array_ops.transpose(x, perm)
    shape = array_ops.concat(0, (batch_shape, (event_shape[0], -1)))
    x = array_ops.reshape(x, shape)

    # Complexity: O(nbM) where M is the complexity of the operator solving a
    # vector system.  E.g., for OperatorPDDiag, each matmul is O(k^2), so
    # this complexity is O(nbk^2). For OperatorPDCholesky, each matmul is
    # O(k^3) so this step has complexity O(nbk^3).
    x = self.scale_operator_pd.sqrt_matmul(x)

    # Undo make batch-op ready.
    # Complexity: O(nbk^2)
    shape = array_ops.concat(0, (batch_shape, event_shape, (n,)))
    x = array_ops.reshape(x, shape)
    perm = array_ops.concat(0, ((ndims-1,), math_ops.range(0, ndims-1)))
    x = array_ops.transpose(x, perm)

    if not self.cholesky_input_output_matrices:
      # Complexity: O(nbk^3)
      x = math_ops.batch_matmul(x, x, adj_y=True)

    return x
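
The sampler above is the Bartlett decomposition: a lower-triangular factor with standard normals below the diagonal and square roots of chi-squared (drawn via Gamma) variates on the diagonal, multiplied by a square root of the scale matrix. Below is a compact NumPy version of the same construction for a single, non-batch draw; it is only a sketch, and the helper name and interface are made up.

import numpy as np

def sample_wishart_bartlett(df, scale, rng):
    """One Wishart(df, scale) draw via the Bartlett decomposition (sketch only)."""
    k = scale.shape[-1]
    chol = np.linalg.cholesky(scale)                 # a square root of the scale matrix
    a = np.tril(rng.standard_normal((k, k)), k=-1)   # N(0, 1) strictly below the diagonal
    # Diagonal: sqrt of ChiSquared(df - i), i.e. Gamma(alpha=(df - i)/2, beta=1/2) variates.
    a[np.diag_indices(k)] = np.sqrt(rng.chisquare(df - np.arange(k)))
    la = chol @ a
    return la @ la.T                                 # the final adj_y=True matmul above

rng = np.random.default_rng(0)
w = sample_wishart_bartlett(df=7.0, scale=np.array([[2.0, 0.5], [0.5, 1.0]]), rng=rng)
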
Example #28
    def _sample_n(self, n, seed):
        batch_shape = self.batch_shape()
        event_shape = self.event_shape()
        batch_ndims = array_ops.shape(batch_shape)[0]

        ndims = batch_ndims + 3  # sample_ndims=1, event_ndims=2
        shape = array_ops.concat(0, ((n, ), batch_shape, event_shape))

        # Complexity: O(nbk^2)
        x = random_ops.random_normal(shape=shape,
                                     mean=0.,
                                     stddev=1.,
                                     dtype=self.dtype,
                                     seed=seed)

        # Complexity: O(nbk)
        # This parametrization is equivalent to Chi2, i.e.,
        # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2)
        g = random_ops.random_gamma(
            shape=(n, ),
            alpha=self._multi_gamma_sequence(0.5 * self.df, self.dimension),
            beta=0.5,
            dtype=self.dtype,
            seed=distribution_util.gen_new_seed(seed, "wishart"))

        # Complexity: O(nbk^2)
        x = array_ops.matrix_band_part(x, -1, 0)  # Tri-lower.

        # Complexity: O(nbk)
        x = array_ops.matrix_set_diag(x, math_ops.sqrt(g))

        # Make batch-op ready.
        # Complexity: O(nbk^2)
        perm = array_ops.concat(0, (math_ops.range(1, ndims), (0, )))
        x = array_ops.transpose(x, perm)
        shape = array_ops.concat(0, (batch_shape, (event_shape[0], -1)))
        x = array_ops.reshape(x, shape)

        # Complexity: O(nbM) where M is the complexity of the operator solving a
        # vector system.  E.g., for OperatorPDDiag, each matmul is O(k^2), so
        # this complexity is O(nbk^2). For OperatorPDCholesky, each matmul is
        # O(k^3) so this step has complexity O(nbk^3).
        x = self.scale_operator_pd.sqrt_matmul(x)

        # Undo make batch-op ready.
        # Complexity: O(nbk^2)
        shape = array_ops.concat(0, (batch_shape, event_shape, (n, )))
        x = array_ops.reshape(x, shape)
        perm = array_ops.concat(0,
                                ((ndims - 1, ), math_ops.range(0, ndims - 1)))
        x = array_ops.transpose(x, perm)

        if not self.cholesky_input_output_matrices:
            # Complexity: O(nbk^3)
            x = math_ops.batch_matmul(x, x, adj_y=True)

        return x
  def _sqrt_to_dense(self):
    v = self._v
    d = self._diag_operator
    m = self._operator

    d_vt = d.matmul(v, transpose_x=True)
    # Batch op won't be efficient for singletons.  Currently we don't break
    # to_dense into batch/singleton methods.
    v_d_vt = math_ops.batch_matmul(v, d_vt)
    m_plus_v_d_vt = m.to_dense() + v_d_vt
    return m_plus_v_d_vt
Example #30
 def variance(self, name="variance"):
   """Variance of the distribution."""
   with ops.name_scope(self.name):
     with ops.op_scope([self._n, self._p, self._mean], name):
       p = array_ops.expand_dims(
           self._p * array_ops.expand_dims(
               array_ops.ones_like(self._n), -1), -1)
       variance = -math_ops.batch_matmul(
           array_ops.expand_dims(self._mean, -1), p, adj_y=True)
       variance += array_ops.batch_matrix_diag(self._mean)
       return variance
Example #31
 def variance(self, name="variance"):
   """Variance of the distribution."""
   with ops.name_scope(self.name):
     with ops.name_scope(name, values=[self._n, self._p, self._mean]):
       p = array_ops.expand_dims(
           self._p * array_ops.expand_dims(
               array_ops.ones_like(self._n), -1), -1)
       variance = -math_ops.batch_matmul(
           array_ops.expand_dims(self._mean, -1), p, adj_y=True)
       variance += array_ops.batch_matrix_diag(self._mean)
       return variance
 def _variance(self):
   alpha_sum = array_ops.expand_dims(self.alpha_sum, -1)
   normalized_alpha = self.alpha / alpha_sum
   variance = -math_ops.batch_matmul(
       array_ops.expand_dims(normalized_alpha, -1),
       array_ops.expand_dims(normalized_alpha, -2))
   variance = array_ops.matrix_set_diag(variance, normalized_alpha *
                                        (1. - normalized_alpha))
   shared_factor = (self.n * (alpha_sum + self.n) /
                    (alpha_sum + 1) * array_ops.ones_like(self.alpha))
   variance *= array_ops.expand_dims(shared_factor, -1)
   return variance
 def _variance(self):
     alpha_sum = array_ops.expand_dims(self.alpha_sum, -1)
     normalized_alpha = self.alpha / alpha_sum
     variance = -math_ops.batch_matmul(
         array_ops.expand_dims(normalized_alpha, -1),
         array_ops.expand_dims(normalized_alpha, -2))
     variance = array_ops.batch_matrix_set_diag(
         variance, normalized_alpha * (1. - normalized_alpha))
     shared_factor = (self.n * (alpha_sum + self.n) / (alpha_sum + 1) *
                      array_ops.ones_like(self.alpha))
     variance *= array_ops.expand_dims(shared_factor, -1)
     return variance
Example #34
  def _underdetermined(op, grad):
    """Gradients for the underdetermined case of MatrixSolveLs.

    This is the backprop for the solution to the normal equations of the second
    kind:
      X = F(A, B) = A^T * (A*A^T + lambda*I)^{-1} * B
    that (for lambda=0) solve the least squares problem
      min ||X||_F subject to A*X = B.
    """
    a = op.inputs[0]
    b = op.inputs[1]
    l2_regularizer = math_ops.cast(op.inputs[2], a.dtype.base_dtype)
    a_shape = array_ops.shape(a)
    batch_shape = a_shape[:-2]
    m = a_shape[-2]

    identity = linalg_ops.eye(m, batch_shape=batch_shape, dtype=a.dtype)
    gramian = math_ops.batch_matmul(
        a, a, adj_y=True) + l2_regularizer * identity
    chol = linalg_ops.cholesky(gramian)
    grad_b = linalg_ops.cholesky_solve(chol, math_ops.batch_matmul(a, grad))
    # Temporary tmp = (A * A^T + lambda * I)^{-1} * B.
    tmp = linalg_ops.cholesky_solve(chol, b)
    a1 = math_ops.batch_matmul(tmp, a, adj_x=True)
    a1 = -math_ops.batch_matmul(grad_b, a1)
    a2 = grad - math_ops.batch_matmul(a, grad_b, adj_x=True)
    a2 = math_ops.batch_matmul(tmp, a2, adj_y=True)
    grad_a = a1 + a2
    return (grad_a, grad_b, None)
Example #35
    def sample(self, n, seed=None, name=None):
        """Sample `n` observations from the Multivariate Normal Distributions.

    Args:
      n: `Scalar`, type int32, the number of observations to sample.
      seed: Python integer, the random seed.
      name: The name to give this op.

    Returns:
      samples: `[n, ...]`, a `Tensor` of `n` samples for each
        of the distributions determined by broadcasting the hyperparameters.
    """
        with ops.op_scope([self._mu, self._sigma_chol, n], name,
                          "MultivariateNormalSample"):
            # TODO(ebrevdo): Is there a better way to get broadcast_shape?
            broadcast_shape = self.mu.get_shape()
            n = ops.convert_to_tensor(n)
            sigma_shape_left = array_ops.slice(
                array_ops.shape(self._sigma_chol), [0],
                array_ops.pack([array_ops.rank(self._sigma_chol) - 2]))

            k_n = array_ops.pack([self._k, n])
            shape = array_ops.concat(0, [sigma_shape_left, k_n])
            white_samples = random_ops.random_normal(shape=shape,
                                                     mean=0,
                                                     stddev=1,
                                                     dtype=self._mu.dtype,
                                                     seed=seed)

            correlated_samples = math_ops.batch_matmul(self._sigma_chol,
                                                       white_samples)

            # Move the last dimension to the front
            perm = array_ops.concat(
                0, (array_ops.pack([array_ops.rank(correlated_samples) - 1]),
                    math_ops.range(0,
                                   array_ops.rank(correlated_samples) - 1)))

            # TODO(ebrevdo): Once we get a proper tensor contraction op,
            # perform the inner product using that instead of batch_matmul
            # and this slow transpose can go away!
            correlated_samples = array_ops.transpose(correlated_samples, perm)

            samples = correlated_samples + self.mu

            # Provide some hints to shape inference
            n_val = tensor_util.constant_value(n)
            final_shape = tensor_shape.vector(n_val).concatenate(
                broadcast_shape)
            samples.set_shape(final_shape)

            return samples
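
Conceptually, the sampler above draws standard normal vectors and colors them with the Cholesky factor: `x = mu + L z` with `Sigma = L L^T`. A small NumPy sketch of the same idea without the batch/transpose bookkeeping, for illustration only:

import numpy as np

rng = np.random.default_rng(0)
mu = np.array([1.0, -2.0])
sigma = np.array([[2.0, 0.3],
                  [0.3, 0.5]])
chol = np.linalg.cholesky(sigma)

z = rng.standard_normal((100_000, 2))    # white samples
samples = mu + z @ chol.T                # colored samples: each row is mu + L z

assert np.allclose(np.cov(samples, rowvar=False), sigma, atol=0.05)
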
Example #36
def _SelfAdjointEigV2Grad(op, grad_e, grad_v):
    """Gradient for SelfAdjointEigV2."""
    e = op.outputs[0]
    v = op.outputs[1]
    # a = op.inputs[0], which satisfies
    # a[...,:,:] * v[...,:,i] = e[...,i] * v[...,i]
    with ops.control_dependencies([grad_e.op, grad_v.op]):
        if grad_v is not None:
            # Construct the matrix f(i,j) = (i != j ? 1 / (e_i - e_j) : 0).
            # Notice that because of the term involving f, the gradient becomes
            # infinite (or NaN in practice) when eigenvalues are not unique.
            # Mathematically this should not be surprising, since for (k-fold)
            # degenerate eigenvalues, the corresponding eigenvectors are only defined
            # up to arbitrary rotation in a (k-dimensional) subspace.
            f = array_ops.matrix_set_diag(
                math_ops.inv(
                    array_ops.expand_dims(e, -2) -
                    array_ops.expand_dims(e, -1)), array_ops.zeros_like(e))
            grad_a = math_ops.batch_matmul(
                v,
                math_ops.batch_matmul(
                    array_ops.matrix_diag(grad_e) +
                    f * math_ops.batch_matmul(v, grad_v, adj_x=True),
                    v,
                    adj_y=True))
        else:
            grad_a = math_ops.batch_matmul(
                v,
                math_ops.batch_matmul(array_ops.matrix_diag(grad_e),
                                      v,
                                      adj_y=True))
        # The forward op only depends on the lower triangular part of a, so here we
        # symmetrize and take the lower triangle
        grad_a = array_ops.matrix_band_part(
            grad_a + array_ops.matrix_transpose(grad_a), -1, 0)
        grad_a = array_ops.matrix_set_diag(
            grad_a, 0.5 * array_ops.matrix_diag_part(grad_a))
        return grad_a
Example #37
  def variance(self, name="variance"):
    """Variance of the distribution."""
    with ops.name_scope(self.name):
      with ops.name_scope(name, values=[self._alpha, self._alpha_0]):
        alpha = array_ops.expand_dims(self._alpha, -1)
        alpha_0 = array_ops.expand_dims(self._alpha_0, -1)

        expanded_alpha_0 = array_ops.expand_dims(alpha_0, -1)

        variance = -math_ops.batch_matmul(alpha, alpha, adj_y=True) / (
            expanded_alpha_0 ** 2 * (expanded_alpha_0 + 1))
        diagonal = self._alpha / (alpha_0 * (alpha_0 + 1))
        variance += array_ops.batch_matrix_diag(diagonal)
        return variance
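
The expression above is the Dirichlet covariance `Cov(X_i, X_j) = (diag(a/a0) - (a/a0)(a/a0)^T) / (a0 + 1)` with `a0 = sum(a)`: the outer-product term gives `-a_i a_j / (a0^2 (a0 + 1))` off the diagonal, and the added diagonal yields variances `a_i (a0 - a_i) / (a0^2 (a0 + 1))`. A NumPy Monte Carlo spot check, illustrative only:

import numpy as np

alpha = np.array([1.0, 2.0, 3.0])
a0 = alpha.sum()

mean = alpha / a0
cov = (np.diag(mean) - np.outer(mean, mean)) / (a0 + 1.0)

rng = np.random.default_rng(0)
samples = rng.dirichlet(alpha, size=200_000)
assert np.allclose(cov, np.cov(samples, rowvar=False), atol=1e-3)
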
Example #38
  def variance(self, name="variance"):
    """Variance of the distribution."""
    with ops.name_scope(self.name):
      with ops.op_scope([self._alpha, self._alpha_0], name):
        alpha = array_ops.expand_dims(self._alpha, -1)
        alpha_0 = array_ops.expand_dims(self._alpha_0, -1)

        expanded_alpha_0 = array_ops.expand_dims(alpha_0, -1)

        variance = -math_ops.batch_matmul(alpha, alpha, adj_y=True) / (
            expanded_alpha_0 ** 2 * (expanded_alpha_0 + 1))
        diagonal = self._alpha / (alpha_0 * (alpha_0 + 1))
        variance += array_ops.batch_matrix_diag(diagonal)
        return variance
def _SelfAdjointEigV2Grad(op, grad_e, grad_v):
  """Gradient for SelfAdjointEigV2."""
  e = op.outputs[0]
  v = op.outputs[1]
  # a = op.inputs[0], which satisfies
  # a[...,:,:] * v[...,:,i] = e[...,i] * v[...,i]
  with ops.control_dependencies([grad_e.op, grad_v.op]):
    if grad_v is not None:
      # Construct the matrix f(i,j) = (i != j ? 1 / (e_i - e_j) : 0).
      # Notice that because of the term involving f, the gradient becomes
      # infinite (or NaN in practice) when eigenvalues are not unique.
      # Mathematically this should not be surprising, since for (k-fold)
      # degenerate eigenvalues, the corresponding eigenvectors are only defined
      # up to arbitrary rotation in a (k-dimensional) subspace.
      f = array_ops.matrix_set_diag(
          math_ops.inv(
              array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)),
          array_ops.zeros_like(e))
      grad_a = math_ops.batch_matmul(
          v,
          math_ops.batch_matmul(
              array_ops.matrix_diag(grad_e) + f * math_ops.batch_matmul(
                  v, grad_v, adj_x=True),
              v,
              adj_y=True))
    else:
      grad_a = math_ops.batch_matmul(
          v,
          math_ops.batch_matmul(
              array_ops.matrix_diag(grad_e), v, adj_y=True))
    # The forward op only depends on the lower triangular part of a, so here we
    # symmetrize and take the lower triangle
    grad_a = array_ops.matrix_band_part(
        grad_a + array_ops.matrix_transpose(grad_a), -1, 0)
    grad_a = array_ops.matrix_set_diag(grad_a, 0.5 *
                                       array_ops.matrix_diag_part(grad_a))
    return grad_a
  def sqrt_matmul(self, x, name='sqrt_matmul'):
    """Left (batch) matmul `x` by a sqrt of this matrix:  `Sx` where `A = S S^T.

    Args:
      x: `Tensor` with shape broadcastable to `[N1,...,Nb, k]` and same `dtype`
        as self.
      name:  A name scope to use for ops added by this method.

    Returns:
      Shape `[N1,...,Nb, k]` `Tensor` holding the product `S x`.
    """
    with ops.name_scope(self.name):
      with ops.op_scope([x] + self.inputs, name):
        chol_lower = array_ops.batch_matrix_band_part(self._chol, -1, 0)
        return math_ops.batch_matmul(chol_lower, x)
Example #41
    def __call__(self, inputs, state, scope=None):
        state, fast_weights = state
        with vs.variable_scope(scope or type(self).__name__) as scope:
            """Compute Wh(t) + Cx(t)"""
            linear = self._fwlinear([state, inputs], self._num_units, False)
            """Compute h_0(t+1) = f(Wh(t) + Cx(t))"""
            if not self._reuse_norm:
                h = self._activation(self._norm(linear, scope="Norm0"))
            else:
                h = self._activation(self._norm(linear))
            h = self._vector2matrix(h)
            linear = self._vector2matrix(linear)
            for i in range(self._S):
                """
        Compute h_{s+1}(t+1) = f([Wh(t) + Cx(t)] + A(t) h_s(t+1)), S times.
        See Eqn (2) in the paper.
        """
                if not self._reuse_norm:
                    h = self._activation(
                        self._norm(linear +
                                   math_ops.batch_matmul(fast_weights, h),
                                   scope="Norm%d" % (i + 1)))
                else:
                    h = self._activation(
                        self._norm(linear +
                                   math_ops.batch_matmul(fast_weights, h)))
            """
      Compute A(t+1)  according to Eqn (4)
      """
            state = self._vector2matrix(state)
            new_fast_weights = self._lambda * fast_weights + self._eta * math_ops.batch_matmul(
                state, state, adj_y=True)

            h = self._matrix2vector(h)

            return h, (h, new_fast_weights)
Example #42
  def sample(self, n, seed=None, name=None):
    """Sample `n` observations from the Multivariate Normal Distributions.

    Args:
      n: `Scalar`, type int32, the number of observations to sample.
      seed: Python integer, the random seed.
      name: The name to give this op.

    Returns:
      samples: `[n, ...]`, a `Tensor` of `n` samples for each
        of the distributions determined by broadcasting the hyperparameters.
    """
    with ops.op_scope(
        [self._mu, self._sigma_chol, n], name, "MultivariateNormalSample"):
      # TODO(ebrevdo): Is there a better way to get broadcast_shape?
      broadcast_shape = self.mu.get_shape()
      n = ops.convert_to_tensor(n)
      sigma_shape_left = array_ops.slice(
          array_ops.shape(self._sigma_chol),
          [0], array_ops.pack([array_ops.rank(self._sigma_chol) - 2]))

      k_n = array_ops.pack([self._k, n])
      shape = array_ops.concat(0, [sigma_shape_left, k_n])
      white_samples = random_ops.random_normal(
          shape=shape, mean=0, stddev=1, dtype=self._mu.dtype, seed=seed)

      correlated_samples = math_ops.batch_matmul(
          self._sigma_chol, white_samples)

      # Move the last dimension to the front
      perm = array_ops.concat(
          0,
          (array_ops.pack([array_ops.rank(correlated_samples) - 1]),
           math_ops.range(0, array_ops.rank(correlated_samples) - 1)))

      # TODO(ebrevdo): Once we get a proper tensor contraction op,
      # perform the inner product using that instead of batch_matmul
      # and this slow transpose can go away!
      correlated_samples = array_ops.transpose(correlated_samples, perm)

      samples = correlated_samples + self.mu

      # Provide some hints to shape inference
      n_val = tensor_util.constant_value(n)
      final_shape = tensor_shape.vector(n_val).concatenate(broadcast_shape)
      samples.set_shape(final_shape)

      return samples
    def sqrt_matmul(self, x, name='sqrt_matmul'):
        """Left (batch) matmul `x` by a sqrt of this matrix:  `Sx` where `A = S S^T.

    Args:
      x: `Tensor` with shape broadcastable to `[N1,...,Nb, k]` and same `dtype`
        as self.
      name:  A name scope to use for ops added by this method.

    Returns:
      Shape `[N1,...,Nb, k]` `Tensor` holding the product `S x`.
    """
        with ops.name_scope(self.name):
            with ops.op_scope([x] + self.inputs, name):
                chol_lower = array_ops.batch_matrix_band_part(
                    self._chol, -1, 0)
                return math_ops.batch_matmul(chol_lower, x)
  def _chol_capacitance(self, batch_mode):
    """Cholesky factorization of the capacitance term."""
    # Cholesky factor for (D^{-1} + V^T M^{-1} V), which is sometimes
    # known as the "capacitance" matrix.

    # self._operator will use batch if need be. Automatically.  We cannot force
    # that here.
    # M^{-1} V
    minv_v = self._operator.solve(self._v)
    # V^T M^{-1} V
    if batch_mode:
      vt_minv_v = math_ops.batch_matmul(self._v, minv_v, adj_x=True)
    else:
      vt_minv_v = math_ops.matmul(self._v, minv_v, transpose_a=True)

    # D^{-1} + V^T M^{-1} V
    capacitance = self._diag_inv_operator.add_to_tensor(vt_minv_v)
    # Cholesky[D^{-1} + V^T M^{-1} V]
    return linalg_ops.cholesky(capacitance)
Example #45
    def variance(self, name="mean"):
        """Class variances for every batch member.

    The variance for each batch member is defined as the following:

    ```
    Var(X_j) = n * alpha_j / alpha_0 * (1 - alpha_j / alpha_0) *
      (n + alpha_0) / (1 + alpha_0)
    ```

    where `alpha_0 = sum_j alpha_j`.

    The covariance between elements in a batch is defined as:

    ```
    Cov(X_i, X_j) = -n * alpha_i * alpha_j / alpha_0 ** 2 *
      (n + alpha_0) / (1 + alpha_0)
    ```

    Args:
      name: The name for this op.

    Returns:
      A `Tensor` representing the variances for each batch member.
    """
        alpha = self._alpha
        alpha_sum = self._alpha_sum
        n = self._n
        with ops.name_scope(self.name):
            with ops.name_scope(name, values=[alpha, alpha_sum, n]):
                expanded_alpha_sum = array_ops.expand_dims(alpha_sum, -1)
                shared_factor = n * (expanded_alpha_sum + n) / (
                    expanded_alpha_sum + 1) * array_ops.ones_like(alpha)

                mean_no_n = alpha / expanded_alpha_sum
                expanded_mean_no_n = array_ops.expand_dims(mean_no_n, -1)
                variance = -math_ops.batch_matmul(
                    expanded_mean_no_n, expanded_mean_no_n, adj_y=True)
                variance += array_ops.batch_matrix_diag(mean_no_n)
                variance *= array_ops.expand_dims(shared_factor, -1)
                return variance
  def variance(self, name='variance'):
    """Class variances for every batch member.

    The variance for each batch member is defined as the following:

    ```
    Var(X_j) = n * alpha_j / alpha_0 * (1 - alpha_j / alpha_0) *
      (n + alpha_0) / (1 + alpha_0)
    ```

    where `alpha_0 = sum_j alpha_j`.

    The covariance between elements in a batch is defined as:

    ```
    Cov(X_i, X_j) = -n * alpha_i * alpha_j / alpha_0 ** 2 *
      (n + alpha_0) / (1 + alpha_0)
    ```

    Args:
      name: The name for this op.

    Returns:
      A `Tensor` representing the variances for each batch member.
    """
    alpha = self._alpha
    alpha_sum = self._alpha_sum
    n = self._n
    with ops.name_scope(self.name):
      with ops.op_scope([alpha, alpha_sum, n], name):
        expanded_alpha_sum = array_ops.expand_dims(alpha_sum, -1)
        shared_factor = n * (expanded_alpha_sum + n) / (
            expanded_alpha_sum + 1) * array_ops.ones_like(alpha)

        mean_no_n = alpha / expanded_alpha_sum
        expanded_mean_no_n = array_ops.expand_dims(mean_no_n, -1)
        variance = -math_ops.batch_matmul(
            expanded_mean_no_n, expanded_mean_no_n, adj_y=True)
        variance += array_ops.batch_matrix_diag(mean_no_n)
        variance *= array_ops.expand_dims(shared_factor, -1)
        return variance
Example #47
def _BatchMatMul(op, grad):
    """Returns the gradient of x and y given the gradient of x * y."""
    x = op.inputs[0]
    y = op.inputs[1]
    adj_x = op.get_attr("adj_x")
    adj_y = op.get_attr("adj_y")

    if not adj_x:
        if not adj_y:
            grad_x = math_ops.batch_matmul(grad, y, False, True)
            grad_y = math_ops.batch_matmul(x, grad, True, False)
        else:
            grad_x = math_ops.batch_matmul(grad, y, False, False)
            grad_y = math_ops.batch_matmul(grad, x, True, False)
    else:
        if not adj_y:
            grad_x = math_ops.batch_matmul(y, grad, False, True)
            grad_y = math_ops.batch_matmul(x, grad, False, False)
        else:
            grad_x = math_ops.batch_matmul(y, grad, True, True)
            grad_y = math_ops.batch_matmul(grad, x, True, True)

    return grad_x, grad_y
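
In the plain case (`adj_x = adj_y = False`) these are the usual matmul gradients: for `Z = X Y` with upstream gradient `G`, `grad_X = G Y^T` and `grad_Y = X^T G`; the other branches are the same identities with the adjoints folded in. A NumPy finite-difference check of the plain case, illustrative only:

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((3, 4))
y = rng.standard_normal((4, 2))
g = rng.standard_normal((3, 2))          # upstream gradient dL/dZ for Z = X @ Y

grad_x = g @ y.T                         # batch_matmul(grad, y, False, True)
grad_y = x.T @ g                         # batch_matmul(x, grad, True, False)

eps = 1e-6
fd = np.zeros_like(x)
for i in range(x.shape[0]):
    for j in range(x.shape[1]):
        dx = np.zeros_like(x)
        dx[i, j] = eps
        fd[i, j] = (np.sum(g * ((x + dx) @ y)) - np.sum(g * ((x - dx) @ y))) / (2 * eps)
assert np.allclose(grad_x, fd, atol=1e-6)
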
Example #48
def _BatchMatMul(op, grad):
  """Returns the gradient of x and y given the gradient of x * y."""
  x = op.inputs[0]
  y = op.inputs[1]
  adj_x = op.get_attr("adj_x")
  adj_y = op.get_attr("adj_y")

  if not adj_x:
    if not adj_y:
      grad_x = math_ops.batch_matmul(grad, y, False, True)
      grad_y = math_ops.batch_matmul(x, grad, True, False)
    else:
      grad_x = math_ops.batch_matmul(grad, y, False, False)
      grad_y = math_ops.batch_matmul(grad, x, True, False)
  else:
    if not adj_y:
      grad_x = math_ops.batch_matmul(y, grad, False, True)
      grad_y = math_ops.batch_matmul(x, grad, False, False)
    else:
      grad_x = math_ops.batch_matmul(y, grad, True, True)
      grad_y = math_ops.batch_matmul(grad, x, True, True)

  return grad_x, grad_y
Example #49
  def log_pdf(self, x, name=None):
    """Log pdf of observations `x` given these Multivariate Normals.

    Args:
      x: tensor of dtype `dtype`, must be broadcastable with `mu`.
      name: The name to give this op.

    Returns:
      log_pdf: tensor of dtype `dtype`, the log-PDFs of `x`.
    """
    with ops.op_scope(
        [self._mu, self._sigma_chol, x], name, "MultivariateNormalLogPdf"):
      x = ops.convert_to_tensor(x)
      contrib_tensor_util.assert_same_float_dtype((self._mu, x))

      x_centered = x - self.mu

      x_rank = array_ops.rank(x_centered)
      sigma_rank = array_ops.rank(self._sigma_chol)

      x_rank_vec = array_ops.pack([x_rank])
      sigma_rank_vec = array_ops.pack([sigma_rank])
      x_shape = array_ops.shape(x_centered)

      # sigma_chol is shaped [D, E, F, ..., k, k]
      # x_centered shape is one of:
      #   [D, E, F, ..., k], or [F, ..., k], or
      #   [A, B, C, D, E, F, ..., k]
      # and we need to convert x_centered to shape:
      #   [D, E, F, ..., k, A*B*C] (or 1 if A, B, C don't exist)
      # then transpose and reshape x_whitened back to one of the shapes:
      #   [D, E, F, ..., k], or [1, 1, F, ..., k], or
      #   [A, B, C, D, E, F, ..., k]

      # This helper handles the case where rank(x_centered) < rank(sigma)
      def _broadcast_x_not_higher_rank_than_sigma():
        return array_ops.reshape(
            x_centered,
            array_ops.concat(
                # Reshape to ones(deficient x rank) + x_shape + [1]
                0, (array_ops.ones(array_ops.pack([sigma_rank - x_rank - 1]),
                                   dtype=x_rank.dtype),
                    x_shape,
                    [1])))

      # These helpers handle the case where rank(x_centered) >= rank(sigma)
      def _broadcast_x_higher_rank_than_sigma():
        x_shape_left = array_ops.slice(
            x_shape, [0], sigma_rank_vec - 1)
        x_shape_right = array_ops.slice(
            x_shape, sigma_rank_vec - 1, x_rank_vec - 1)
        x_shape_perm = array_ops.concat(
            0, (math_ops.range(sigma_rank - 1, x_rank),
                math_ops.range(0, sigma_rank - 1)))
        return array_ops.reshape(
            # Convert to [D, E, F, ..., k, B, C]
            array_ops.transpose(
                x_centered, perm=x_shape_perm),
            # Reshape to [D, E, F, ..., k, B*C]
            array_ops.concat(
                0, (x_shape_right,
                    array_ops.pack([
                        math_ops.reduce_prod(x_shape_left, 0)]))))

      def _unbroadcast_x_higher_rank_than_sigma():
        x_shape_left = array_ops.slice(
            x_shape, [0], sigma_rank_vec - 1)
        x_shape_right = array_ops.slice(
            x_shape, sigma_rank_vec - 1, x_rank_vec - 1)
        x_shape_perm = array_ops.concat(
            0, (math_ops.range(sigma_rank - 1, x_rank),
                math_ops.range(0, sigma_rank - 1)))
        return array_ops.transpose(
            # [D, E, F, ..., k, B, C] => [B, C, D, E, F, ..., k]
            array_ops.reshape(
                # convert to [D, E, F, ..., k, B, C]
                x_whitened_broadcast,
                array_ops.concat(0, (x_shape_right, x_shape_left))),
            perm=x_shape_perm)

      # Step 1: reshape x_centered
      x_centered_broadcast = control_flow_ops.cond(
          # x_centered == [D, E, F, ..., k] => [D, E, F, ..., k, 1]
          # or         == [F, ..., k] => [1, 1, F, ..., k, 1]
          x_rank <= sigma_rank - 1,
          _broadcast_x_not_higher_rank_than_sigma,
          # x_centered == [B, C, D, E, F, ..., k] => [D, E, F, ..., k, B*C]
          _broadcast_x_higher_rank_than_sigma)

      x_whitened_broadcast = linalg_ops.batch_matrix_triangular_solve(
          self._sigma_chol, x_centered_broadcast)

      # Reshape x_whitened_broadcast back to x_whitened
      x_whitened = control_flow_ops.cond(
          x_rank <= sigma_rank - 1,
          lambda: array_ops.reshape(x_whitened_broadcast, x_shape),
          _unbroadcast_x_higher_rank_than_sigma)

      x_whitened = array_ops.expand_dims(x_whitened, -1)
      # Reshape x_whitened to contain row vectors
      # Returns a batchwise scalar
      x_whitened_norm = math_ops.batch_matmul(
          x_whitened, x_whitened, adj_x=True)
      x_whitened_norm = control_flow_ops.cond(
          x_rank <= sigma_rank - 1,
          lambda: array_ops.squeeze(x_whitened_norm, [-2, -1]),
          lambda: array_ops.squeeze(x_whitened_norm, [-1]))

      log_two_pi = constant_op.constant(math.log(2 * math.pi), dtype=self.dtype)
      k = math_ops.cast(self._k, self.dtype)
      log_pdf_value = (
          -math_ops.log(self._sigma_det) - k * log_two_pi - x_whitened_norm) / 2
      final_shaped_value = control_flow_ops.cond(
          x_rank <= sigma_rank - 1,
          lambda: log_pdf_value,
          lambda: array_ops.squeeze(log_pdf_value, [-1]))

      output_static_shape = x_centered.get_shape()[:-1]
      final_shaped_value.set_shape(output_static_shape)
      return final_shaped_value
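
Behind the shape gymnastics, the quantity computed above is the standard multivariate normal log density `-0.5 * (k * log(2*pi) + log|Sigma| + ||L^{-1}(x - mu)||^2)` with `Sigma = L L^T`. A NumPy/SciPy cross-check of that formula, illustrative only:

import numpy as np
from scipy.linalg import solve_triangular
from scipy.stats import multivariate_normal

mu = np.array([1.0, -2.0])
sigma = np.array([[2.0, 0.3],
                  [0.3, 0.5]])
x = np.array([0.5, -1.5])

chol = np.linalg.cholesky(sigma)
z = solve_triangular(chol, x - mu, lower=True)    # whiten: L^{-1} (x - mu)
log_det = 2.0 * np.sum(np.log(np.diag(chol)))     # log|Sigma| from the Cholesky factor
log_pdf = -0.5 * (len(mu) * np.log(2 * np.pi) + log_det + z @ z)

assert np.isclose(log_pdf, multivariate_normal(mean=mu, cov=sigma).logpdf(x))
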
 def _to_dense(self):
   sqrt = self.sqrt_to_dense()
   return math_ops.batch_matmul(sqrt, sqrt, adj_y=True)
Example #51
  def __init__(self, mu, sigma=None, sigma_chol=None, name=None):
    """Multivariate Normal distributions on `R^k`.

    User must provide means `mu`, which are tensors of rank `N+1` (`N >= 0`)
    with the last dimension having length `k`.

    User must provide exactly one of `sigma` (the covariance matrices) or
    `sigma_chol` (the cholesky decompositions of the covariance matrices).
    `sigma` or `sigma_chol` must be of rank `N+2`.  The last two dimensions
    must both have length `k`.  The first `N` dimensions correspond to batch
    indices.

    If `sigma_chol` is not provided, the batch cholesky factorization of `sigma`
    is calculated for you.

    The shapes of `mu` and `sigma` must match for the first `N` dimensions.

    Regardless of which parameter is provided, the covariance matrices must all
    be **positive definite** (an error is raised if one of them is not).

    Args:
      mu: (N+1)-D.  `float` or `double` tensor, the means of the distributions.
      sigma: (N+2)-D.  (optional) `float` or `double` tensor, the covariances
        of the distribution(s).  The first `N+1` dimensions must match
        those of `mu`.  Must be batch-positive-definite.
      sigma_chol: (N+2)-D.  (optional) `float` or `double` tensor, a
        lower-triangular factorization of `sigma`
        (`sigma = sigma_chol . sigma_chol^*`).  The first `N+1` dimensions
        must match those of `mu`.  The tensor itself need not be batch
        lower triangular: we ignore the upper triangular part.  However,
        the batch diagonals must be positive (i.e., sigma_chol must be
        batch-positive-definite).
      name: The name to give Ops created by the initializer.

    Raises:
      ValueError: if neither sigma nor sigma_chol is provided.
      TypeError: if mu and sigma (resp. sigma_chol) are different dtypes.
    """
    if (sigma is None) == (sigma_chol is None):
      raise ValueError("Exactly one of sigma and sigma_chol must be provided")

    with ops.op_scope([mu, sigma, sigma_chol], name, "MultivariateNormal"):
      sigma_or_half = sigma_chol if sigma is None else sigma

      mu = ops.convert_to_tensor(mu)
      sigma_or_half = ops.convert_to_tensor(sigma_or_half)

      contrib_tensor_util.assert_same_float_dtype((mu, sigma_or_half))

      with ops.control_dependencies([
          _assert_compatible_shapes(mu, sigma_or_half)]):
        mu = array_ops.identity(mu, name="mu")

        # Store the dimensionality of the MVNs
        self._k = array_ops.gather(array_ops.shape(mu), array_ops.rank(mu) - 1)

        if sigma_chol is not None:
          # Ensure we only keep the lower triangular part.
          sigma_chol = array_ops.batch_matrix_band_part(
              sigma_chol, num_lower=-1, num_upper=0)
          sigma_det = _determinant_from_sigma_chol(sigma_chol)
          with ops.control_dependencies([
              _assert_batch_positive_definite(sigma_chol)]):
            self._sigma = math_ops.batch_matmul(
                sigma_chol, sigma_chol, adj_y=True, name="sigma")
            self._sigma_chol = array_ops.identity(sigma_chol, "sigma_chol")
            self._sigma_det = array_ops.identity(sigma_det, "sigma_det")
            self._mu = array_ops.identity(mu, "mu")
        else:  # sigma is not None
          sigma_chol = linalg_ops.batch_cholesky(sigma)
          sigma_det = _determinant_from_sigma_chol(sigma_chol)
          # batch_cholesky checks for PSD; so we can just use it here.
          with ops.control_dependencies([sigma_chol]):
            self._sigma = array_ops.identity(sigma, "sigma")
            self._sigma_chol = array_ops.identity(sigma_chol, "sigma_chol")
            self._sigma_det = array_ops.identity(sigma_det, "sigma_det")
            self._mu = array_ops.identity(mu, "mu")
Example #52
    def sample_n(self, n, seed=None, name="sample"):
        # pylint: disable=line-too-long
        """Generate `n` samples.

    Complexity: O(nbk^3)

    The sampling procedure is based on the [Bartlett decomposition](
    https://en.wikipedia.org/wiki/Wishart_distribution#Bartlett_decomposition)
    and [using a Gamma distribution to generate Chi2 random variates](
    https://en.wikipedia.org/wiki/Chi-squared_distribution#Gamma.2C_exponential.2C_and_related_distributions).

    Args:
      n: `Scalar` `Tensor` of type `int32` or `int64`, the number of
        observations to sample.
      seed: Python integer; random number generator seed.
      name: The name of this op.

    Returns:
      samples: a `Tensor` of shape `(n,) + self.batch_shape + self.event_shape`
          with values of type `self.dtype`.
    """
        with ops.name_scope(self.name):
            with ops.name_scope(name, values=[n] + list(self.inputs.values())):
                n = ops.convert_to_tensor(n, name="n")
                if n.dtype != dtypes.int32:
                    raise TypeError("n.dtype=%s which is not int32" % n.dtype)
                batch_shape = self.batch_shape()
                event_shape = self.event_shape()
                batch_ndims = array_ops.shape(batch_shape)[0]

                ndims = batch_ndims + 3  # sample_ndims=1, event_ndims=2
                shape = array_ops.concat(0, ((n,), batch_shape, event_shape))

                # Complexity: O(nbk^2)
                x = random_ops.random_normal(
                    shape=shape, mean=0.0, stddev=1.0, dtype=self.dtype, seed=seed)

                # Complexity: O(nbk)
                # This parametrization is equivalent to Chi2, i.e.,
                # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2)
                g = random_ops.random_gamma(
                    shape=(n,),
                    alpha=self._multi_gamma_sequence(0.5 * self.df, self.dimension),
                    beta=0.5,
                    dtype=self.dtype,
                    seed=seed,
                )

                # Complexity: O(nbk^2)
                x = array_ops.batch_matrix_band_part(x, -1, 0)  # Tri-lower.

                # Complexity: O(nbk)
                x = array_ops.batch_matrix_set_diag(x, math_ops.sqrt(g))

                # Make batch-op ready.
                # Complexity: O(nbk^2)
                perm = array_ops.concat(0, (math_ops.range(1, ndims), (0,)))
                x = array_ops.transpose(x, perm)
                shape = array_ops.concat(0, (batch_shape, (event_shape[0], -1)))
                x = array_ops.reshape(x, shape)

                # Complexity: O(nbM) where M is the complexity of the operator solving a
                # vector system.  E.g., for OperatorPDDiag, each matmul is O(k^2), so
                # this complexity is O(nbk^2). For OperatorPDCholesky, each matmul is
                # O(k^3) so this step has complexity O(nbk^3).
                x = self.scale_operator_pd.sqrt_matmul(x)

                # Undo make batch-op ready.
                # Complexity: O(nbk^2)
                shape = array_ops.concat(0, (batch_shape, event_shape, (n,)))
                x = array_ops.reshape(x, shape)
                perm = array_ops.concat(0, ((ndims - 1,), math_ops.range(0, ndims - 1)))
                x = array_ops.transpose(x, perm)

                if not self.cholesky_input_output_matrices:
                    # Complexity: O(nbk^3)
                    x = math_ops.batch_matmul(x, x, adj_y=True)

                # Set shape hints.
                if self.scale_operator_pd.get_shape().ndims is not None:
                    x.set_shape(tensor_shape.TensorShape(
                        [tensor_util.constant_value(n)] +
                        self.scale_operator_pd.get_shape().as_list()))
                elif x.get_shape().ndims is not None:
                    x.get_shape()[0].merge_with(
                        tensor_shape.Dimension(tensor_util.constant_value(n)))

                return x