Example #1
def _variance(self):
  scale = self.alpha_sum * math_ops.sqrt(1.0 + self.alpha_sum)
  alpha = self.alpha / scale
  outer_prod = -math_ops.batch_matmul(
      array_ops.expand_dims(alpha, dim=-1),  # column
      array_ops.expand_dims(alpha, dim=-2))  # row
  return array_ops.batch_matrix_set_diag(
      outer_prod, alpha * (self.alpha_sum / scale - alpha))
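The method above assembles the Dirichlet covariance in two steps: the off-diagonal entries come from a negated outer product of the scaled parameters, and `batch_matrix_set_diag` then overwrites the diagonal with the marginal variances. A minimal NumPy sketch of the same construction (the values are illustrative):

```python
import numpy as np

alpha = np.array([2.0, 3.0, 5.0])
alpha_sum = alpha.sum()
scale = alpha_sum * np.sqrt(1.0 + alpha_sum)
a = alpha / scale

# Off-diagonal: Cov(X_i, X_j) = -alpha_i * alpha_j / (alpha_sum^2 * (alpha_sum + 1))
cov = -np.outer(a, a)
# Diagonal: Var(X_i) = alpha_i * (alpha_sum - alpha_i) / (alpha_sum^2 * (alpha_sum + 1))
np.fill_diagonal(cov, a * (alpha_sum / scale - a))
```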
Example #2
def _variance(self):
  p = self.p * array_ops.expand_dims(array_ops.ones_like(self.n), -1)
  outer_prod = math_ops.batch_matmul(
      array_ops.expand_dims(self._mean_val, -1),
      array_ops.expand_dims(p, -2))
  return array_ops.batch_matrix_set_diag(
      -outer_prod, self._mean_val - self._mean_val * p)
Example #3
def batch_matrix_diag_transform(matrix, transform=None, name=None):
  """Transform diagonal of [batch-]matrix, leave rest of matrix unchanged.

  Create a trainable covariance defined by a Cholesky factor:

  ```python
  # Transform network layer into 2 x 2 array.
  matrix_values = tf.contrib.layers.fully_connected(activations, 4)
  matrix = tf.reshape(matrix_values, (batch_size, 2, 2))

  # Make the diagonal positive.  If the upper triangle was zero, this would be a
  # valid Cholesky factor.
  chol = batch_matrix_diag_transform(matrix, transform=tf.nn.softplus)

  # OperatorPDCholesky ignores the upper triangle.
  operator = OperatorPDCholesky(chol)
  ```

  Example of heteroskedastic 2-D linear regression.

  ```python
  # Get a trainable Cholesky factor.
  matrix_values = tf.contrib.layers.fully_connected(activations, 4)
  matrix = tf.reshape(matrix_values, (batch_size, 2, 2))
  chol = batch_matrix_diag_transform(matrix, transform=tf.nn.softplus)

  # Get a trainable mean.
  mu = tf.contrib.layers.fully_connected(activations, 2)

  # This is a fully trainable multivariate normal!
  dist = tf.contrib.distributions.MVNCholesky(mu, chol)

  # Standard log loss.  Minimizing this will "train" mu and chol, and then dist
  # will be a distribution predicting labels as multivariate Gaussians.
  loss = -1 * tf.reduce_mean(dist.log_pdf(labels))
  ```

  Args:
    matrix:  Rank `R` `Tensor`, `R >= 2`, where the last two dimensions are
      equal.
    transform:  Element-wise function mapping `Tensors` to `Tensors`.  To
      be applied to the diagonal of `matrix`.  If `None`, `matrix` is returned
      unchanged.  Defaults to `None`.
    name:  A name to give created ops.
      Defaults to "batch_matrix_diag_transform".

  Returns:
    A `Tensor` with same shape and `dtype` as `matrix`.
  """
  with ops.name_scope(name, "batch_matrix_diag_transform", [matrix]):
    matrix = ops.convert_to_tensor(matrix, name="matrix")
    if transform is None:
      return matrix
    # Replace the diag with transformed diag.
    diag = array_ops.batch_matrix_diag_part(matrix)
    transformed_diag = transform(diag)
    transformed_mat = array_ops.batch_matrix_set_diag(matrix, transformed_diag)

  return transformed_mat
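For intuition, the body of `batch_matrix_diag_transform` amounts to a diagonal-only update. A self-contained NumPy sketch of the same pattern, assuming a batch of square matrices and a softplus transform (the helper names are illustrative, not part of the TF API):

```python
import numpy as np

def softplus(x):
  return np.log1p(np.exp(x))

def diag_transform(matrix, transform=None):
  """Apply `transform` to the diagonal of each matrix in a batch."""
  if transform is None:
    return matrix
  out = matrix.copy()
  idx = np.arange(matrix.shape[-1])
  out[..., idx, idx] = transform(matrix[..., idx, idx])
  return out

# Lower triangle with a positive diagonal: a valid Cholesky factor.
batch = np.random.randn(8, 2, 2)
chol = np.tril(diag_transform(batch, softplus))
```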
Example #4
  def _sample_n(self, n, seed):
    batch_shape = self.batch_shape()
    event_shape = self.event_shape()
    batch_ndims = array_ops.shape(batch_shape)[0]

    ndims = batch_ndims + 3  # sample_ndims=1, event_ndims=2
    shape = array_ops.concat(0, ((n,), batch_shape, event_shape))

    # Complexity: O(nbk^2)
    x = random_ops.random_normal(shape=shape,
                                 mean=0.,
                                 stddev=1.,
                                 dtype=self.dtype,
                                 seed=seed)

    # Complexity: O(nbk)
    # This parametrization is equivalent to Chi2, i.e.,
    # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2)
    g = random_ops.random_gamma(shape=(n,),
                                alpha=self._multi_gamma_sequence(
                                    0.5 * self.df, self.dimension),
                                beta=0.5,
                                dtype=self.dtype,
                                seed=seed)

    # Complexity: O(nbk^2)
    x = array_ops.batch_matrix_band_part(x, -1, 0)  # Tri-lower.

    # Complexity: O(nbk)
    x = array_ops.batch_matrix_set_diag(x, math_ops.sqrt(g))

    # Make batch-op ready.
    # Complexity: O(nbk^2)
    perm = array_ops.concat(0, (math_ops.range(1, ndims), (0,)))
    x = array_ops.transpose(x, perm)
    shape = array_ops.concat(0, (batch_shape, (event_shape[0], -1)))
    x = array_ops.reshape(x, shape)

    # Complexity: O(nbM) where M is the complexity of the operator solving a
    # vector system.  E.g., for OperatorPDDiag, each matmul is O(k^2), so
    # this complexity is O(nbk^2). For OperatorPDCholesky, each matmul is
    # O(k^3) so this step has complexity O(nbk^3).
    x = self.scale_operator_pd.sqrt_matmul(x)

    # Undo make batch-op ready.
    # Complexity: O(nbk^2)
    shape = array_ops.concat(0, (batch_shape, event_shape, (n,)))
    x = array_ops.reshape(x, shape)
    perm = array_ops.concat(0, ((ndims-1,), math_ops.range(0, ndims-1)))
    x = array_ops.transpose(x, perm)

    if not self.cholesky_input_output_matrices:
      # Complexity: O(nbk^3)
      x = math_ops.batch_matmul(x, x, adj_y=True)

    return x
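This `_sample_n` is the Bartlett decomposition in code: fill a strictly lower triangle with standard normals, place square roots of Chi2 draws (generated here as Gamma variates) on the diagonal, then multiply by the square root of the scale operator. A single-matrix NumPy sketch of the same recipe, assuming the scale matrix is given by its Cholesky factor `L` (names are illustrative):

```python
import numpy as np

def sample_wishart(df, L, rng):
  """One draw from Wishart(df, L @ L.T) via the Bartlett decomposition."""
  k = L.shape[0]
  A = np.tril(rng.standard_normal((k, k)), k=-1)  # strictly lower: N(0, 1)
  chi2 = rng.chisquare(df - np.arange(k))         # Chi2(df - i) == Gamma((df - i)/2, rate=1/2)
  np.fill_diagonal(A, np.sqrt(chi2))
  LA = L @ A
  return LA @ LA.T

rng = np.random.default_rng(0)
L = np.linalg.cholesky(np.array([[2.0, 0.3], [0.3, 1.0]]))
W = sample_wishart(df=5.0, L=L, rng=rng)
```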
Example #5
def _variance(self):
  alpha_sum = array_ops.expand_dims(self.alpha_sum, -1)
  normalized_alpha = self.alpha / alpha_sum
  variance = -math_ops.batch_matmul(
      array_ops.expand_dims(normalized_alpha, -1),
      array_ops.expand_dims(normalized_alpha, -2))
  variance = array_ops.batch_matrix_set_diag(
      variance, normalized_alpha * (1. - normalized_alpha))
  shared_factor = (self.n * (alpha_sum + self.n) /
                   (alpha_sum + 1) * array_ops.ones_like(self.alpha))
  variance *= array_ops.expand_dims(shared_factor, -1)
  return variance
Example #6
def _BatchMatrixSetDiagGrad(op, grad):
  # The diag input's shape is the matrix shape without its last dimension;
  # merge every piece of static shape information before falling back to
  # dynamic shapes.
  diag_shape = op.inputs[1].get_shape()
  diag_shape = diag_shape.merge_with(op.inputs[0].get_shape()[:-1])
  diag_shape = diag_shape.merge_with(grad.get_shape()[:-1])
  if diag_shape.is_fully_defined():
    diag_shape = diag_shape.as_list()
  else:
    diag_shape = array_ops.shape(grad)
    diag_shape = array_ops.slice(diag_shape, [0], [array_ops.rank(grad) - 1])
  # Gradient w.r.t. the matrix input: the upstream gradient with its diagonal
  # zeroed, since the forward op overwrote those entries.
  grad_input = array_ops.batch_matrix_set_diag(
      grad, array_ops.zeros(diag_shape, dtype=grad.dtype))
  # Gradient w.r.t. the diag input: the diagonal of the upstream gradient.
  grad_diag = array_ops.batch_matrix_diag_part(grad)
  return (grad_input, grad_diag)
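The gradient splits cleanly because the forward op overwrites the diagonal: the diagonal of the upstream gradient flows to the `diag` input, and the upstream gradient with its diagonal zeroed flows to the matrix input. A quick NumPy check of that decomposition (illustrative):

```python
import numpy as np

grad = np.random.randn(3, 4, 4)          # upstream gradient for a batch of 4x4 matrices
idx = np.arange(grad.shape[-1])

grad_diag = grad[..., idx, idx].copy()   # gradient w.r.t. the diag input
grad_input = grad.copy()
grad_input[..., idx, idx] = 0.0          # gradient w.r.t. the matrix input

# The two pieces partition the upstream gradient exactly.
recon = grad_input.copy()
recon[..., idx, idx] = grad_diag
assert np.allclose(recon, grad)
```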
Example #7
  def sample_n(self, n, seed=None, name='sample'):
    # pylint: disable=line-too-long
    """Generate `n` samples.

    Complexity: O(nbk^3)

    The sampling procedure is based on the [Bartlett decomposition](
    https://en.wikipedia.org/wiki/Wishart_distribution#Bartlett_decomposition)
    and [using a Gamma distribution to generate Chi2 random variates](
    https://en.wikipedia.org/wiki/Chi-squared_distribution#Gamma.2C_exponential.2C_and_related_distributions).

    Args:
      n: `Scalar` `Tensor` of type `int32` or `int64`, the number of
        observations to sample.
      seed: Python integer; random number generator seed.
      name: The name of this op.

    Returns:
      samples: a `Tensor` of shape `(n,) + self.batch_shape + self.event_shape`
          with values of type `self.dtype`.
    """
    with ops.name_scope(self.name):
      with ops.name_scope(name, values=[n] + list(self.inputs.values())):
        n = ops.convert_to_tensor(n, name='n')
        if n.dtype != dtypes.int32:
          raise TypeError('n.dtype=%s which is not int32' % n.dtype)
        batch_shape = self.batch_shape()
        event_shape = self.event_shape()
        batch_ndims = array_ops.shape(batch_shape)[0]

        ndims = batch_ndims + 3  # sample_ndims=1, event_ndims=2
        shape = array_ops.concat(0, ((n,), batch_shape, event_shape))

        # Complexity: O(nbk^2)
        x = random_ops.random_normal(shape=shape,
                                     mean=0.,
                                     stddev=1.,
                                     dtype=self.dtype,
                                     seed=seed)

        # Complexity: O(nbk)
        # This parametrization is equivalent to Chi2, i.e.,
        # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2)
        g = random_ops.random_gamma(shape=(n,),
                                    alpha=self._multi_gamma_sequence(
                                        0.5 * self.df, self.dimension),
                                    beta=0.5,
                                    dtype=self.dtype,
                                    seed=seed)

        # Complexity: O(nbk^2)
        x = array_ops.batch_matrix_band_part(x, -1, 0)  # Tri-lower.

        # Complexity: O(nbk)
        x = array_ops.batch_matrix_set_diag(x, math_ops.sqrt(g))

        # Make batch-op ready.
        # Complexity: O(nbk^2)
        perm = array_ops.concat(0, (math_ops.range(1, ndims), (0,)))
        x = array_ops.transpose(x, perm)
        shape = array_ops.concat(0, (batch_shape, (event_shape[0], -1)))
        x = array_ops.reshape(x, shape)

        # Complexity: O(nbM) where M is the complexity of the operator solving a
        # vector system.  E.g., for OperatorPDDiag, each matmul is O(k^2), so
        # this complexity is O(nbk^2). For OperatorPDCholesky, each matmul is
        # O(k^3) so this step has complexity O(nbk^3).
        x = self.scale_operator_pd.sqrt_matmul(x)

        # Undo make batch-op ready.
        # Complexity: O(nbk^2)
        shape = array_ops.concat(0, (batch_shape, event_shape, (n,)))
        x = array_ops.reshape(x, shape)
        perm = array_ops.concat(0, ((ndims-1,), math_ops.range(0, ndims-1)))
        x = array_ops.transpose(x, perm)

        if not self.cholesky_input_output_matrices:
          # Complexity: O(nbk^3)
          x = math_ops.batch_matmul(x, x, adj_y=True)

        # Set shape hints.
        if self.scale_operator_pd.get_shape().ndims is not None:
          x.set_shape(tensor_shape.TensorShape(
              [tensor_util.constant_value(n)] +
              self.scale_operator_pd.get_shape().as_list()))
        elif x.get_shape().ndims is not None:
          x.get_shape()[0].merge_with(
              tensor_shape.TensorDimension(tensor_util.constant_value(n)))

        return x
Example #8
def _add_to_tensor(self, mat):
  mat_diag = array_ops.batch_matrix_diag_part(mat)
  new_diag = math_ops.square(self._diag) + mat_diag
  return array_ops.batch_matrix_set_diag(mat, new_diag)
Example #9
def _add_to_tensor(self, mat):
  # Add to a tensor in O(k) time!
  mat_diag = array_ops.batch_matrix_diag_part(mat)
  new_diag = constant_op.constant(1, dtype=self.dtype) + mat_diag
  return array_ops.batch_matrix_set_diag(mat, new_diag)
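Because only the diagonal changes, adding the identity to a `k x k` matrix touches `k` entries rather than all `k^2`, which is what the O(k) comment refers to. A minimal NumPy sketch of the trick (illustrative):

```python
import numpy as np

mat = np.random.randn(5, 3, 3)   # batch of 3x3 matrices
idx = np.arange(mat.shape[-1])

out = mat.copy()
out[..., idx, idx] += 1.0        # add the identity by touching only the diagonal

assert np.allclose(out, mat + np.eye(3))
```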