Example #1
  def _inverse(self, x):
    x -= self.loc
    x, sample_shape = self.shaper.make_batch_of_event_sample_matrices(x)
    x = linalg_ops.batch_matrix_triangular_solve(self.scale, x)
    x = self.shaper.undo_make_batch_of_event_sample_matrices(x, sample_shape)
    return x
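These examples come from the pre-1.0 TensorFlow API; in later releases the batch_* linear-algebra ops were folded into their unbatched names, and today's tf.linalg.triangular_solve broadcasts over leading batch dimensions itself. As a minimal NumPy/SciPy sketch of what this _inverse computes for a single event, assuming a lower-triangular scale (all names and values below are illustrative, not from the source):

import numpy as np
from scipy.linalg import solve_triangular

loc = np.array([1.0, -2.0])
scale = np.array([[2.0, 0.0],
                  [0.5, 1.5]])  # lower-triangular scale factor

# Forward map: y = scale @ x + loc.  The inverse undoes it with one
# triangular solve instead of forming scale^{-1} explicitly.
y = np.array([3.0, 1.0])
x = solve_triangular(scale, y - loc, lower=True)
assert np.allclose(scale @ x + loc, y)  # round trip recovers y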
Example #2
  def _x_whitened_if_should_flip(self, x):
    # Tensor to use if x.shape = [M1,...,Mm] + chol.shape[:-1],
    # which is common if x was sampled.
    x_flipped = self._flip_front_dims_to_back(x)

    # batch version of: L^{-1} x
    x_whitened_expanded = linalg_ops.batch_matrix_triangular_solve(
        self._chol, x_flipped)

    return self._unfip_back_dims_to_front(
        x_whitened_expanded,
        array_ops.shape(x),
        x.get_shape())
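The comment "batch version of: L^{-1} x" is the heart of this helper: applying the inverse Cholesky factor whitens samples, turning covariance sigma = L L^T into the identity. A self-contained sketch of that effect (values assumed for illustration):

import numpy as np
from scipy.linalg import cholesky, solve_triangular

rng = np.random.default_rng(0)
sigma = np.array([[4.0, 1.0],
                  [1.0, 3.0]])
chol = cholesky(sigma, lower=True)  # sigma = chol @ chol.T

# Draw correlated samples, then whiten them with a triangular solve.
x = rng.multivariate_normal(np.zeros(2), sigma, size=100_000)  # [N, 2]
x_whitened = solve_triangular(chol, x.T, lower=True).T         # L^{-1} x

print(np.cov(x_whitened.T))  # approximately the 2x2 identity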
Example #3
  def _x_whitened_if_no_flip(self, x):
    """x_whitened in the event of no flip."""
    # Tensors to use if x and chol have same shape, or a shape that must be
    # broadcast to match.
    chol_bcast, x_bcast = self._get_chol_and_x_compatible_shape(x)

    # batch version of: L^{-1} x
    # Note that here x_bcast has trailing dims of (k, 1), for "1" system of k
    # linear equations.  This is the form used by the solver.
    x_whitened_expanded = linalg_ops.batch_matrix_triangular_solve(
        chol_bcast, x_bcast)

    x_whitened = array_ops.squeeze(x_whitened_expanded, squeeze_dims=[-1])
    return x_whitened
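The expand/squeeze pair around the solve mirrors how the op is specified: the last two dimensions of each operand form the matrix system, so a single right-hand-side vector needs a trailing size-1 dimension. A small illustration (values assumed):

import numpy as np
from scipy.linalg import solve_triangular

chol = np.array([[2.0, 0.0, 0.0],
                 [1.0, 1.0, 0.0],
                 [0.5, 0.3, 1.5]])
x = np.array([1.0, 2.0, 3.0])  # event vector, shape [k]

x_expanded = x[:, np.newaxis]  # shape [k, 1]: "1" system of k equations
solved = solve_triangular(chol, x_expanded, lower=True)
x_whitened = np.squeeze(solved, axis=-1)  # back to shape [k]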
Example #4
def _BatchMatrixTriangularSolveGrad(op, grad):
  """Gradient for BatchMatrixTriangularSolve."""
  a = op.inputs[0]
  adjoint_a = op.get_attr("adjoint")
  lower_a = op.get_attr("lower")
  c = op.outputs[0]
  grad_b = linalg_ops.batch_matrix_triangular_solve(a,
                                                    grad,
                                                    lower=lower_a,
                                                    adjoint=not adjoint_a)
  if adjoint_a:
    grad_a = -math_ops.batch_matmul(c, grad_b, adj_y=True)
  else:
    grad_a = -math_ops.batch_matmul(grad_b, c, adj_y=True)
  if lower_a:
    grad_a = array_ops.batch_matrix_band_part(grad_a, -1, 0)
  else:
    grad_a = array_ops.batch_matrix_band_part(grad_a, 0, -1)
  return (grad_a, grad_b)
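The backward pass follows from differentiating op(A) C = B: with upstream gradient G, grad_b solves the transposed system, and grad_a is the triangular part of -grad_b C^T (operands swapped when adjoint is set). A finite-difference sanity check for the lower, non-adjoint case (a sketch, not TensorFlow's own test):

import numpy as np
from scipy.linalg import solve_triangular

rng = np.random.default_rng(0)
a = np.tril(rng.normal(size=(3, 3))) + 3.0 * np.eye(3)  # well-conditioned, lower
b = rng.normal(size=(3, 2))
g = rng.normal(size=(3, 2))  # upstream gradient dLoss/dC for Loss = sum(G * C)

c = solve_triangular(a, b, lower=True)                  # C = A^{-1} B
grad_b = solve_triangular(a, g, trans='T', lower=True)  # A^{-T} G
grad_a = np.tril(-grad_b @ c.T)

# Perturb one lower-triangular entry of A and compare against the formula.
eps, (i, j) = 1e-6, (2, 0)
a_pert = a.copy()
a_pert[i, j] += eps
c_pert = solve_triangular(a_pert, b, lower=True)
fd = np.sum(g * (c_pert - c)) / eps
assert np.isclose(fd, grad_a[i, j], atol=1e-4)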
Example #5
  def log_pdf(self, x, name=None):
    """Log pdf of observations `x` given these Multivariate Normals.

    Args:
      x: tensor of dtype `dtype`, must be broadcastable with `mu`.
      name: The name to give this op.

    Returns:
      log_pdf: tensor of dtype `dtype`, the log-PDFs of `x`.
    """
    with ops.op_scope(
        [self._mu, self._sigma_chol, x], name, "MultivariateNormalLogPdf"):
      x = ops.convert_to_tensor(x)
      contrib_tensor_util.assert_same_float_dtype((self._mu, x))

      x_centered = x - self.mu

      x_rank = array_ops.rank(x_centered)
      sigma_rank = array_ops.rank(self._sigma_chol)

      x_rank_vec = array_ops.pack([x_rank])
      sigma_rank_vec = array_ops.pack([sigma_rank])
      x_shape = array_ops.shape(x_centered)

      # sigma_chol is shaped [D, E, F, ..., k, k]
      # x_centered shape is one of:
      #   [D, E, F, ..., k], or [F, ..., k], or
      #   [A, B, C, D, E, F, ..., k]
      # and we need to convert x_centered to shape:
      #   [D, E, F, ..., k, A*B*C] (or 1 if A, B, C don't exist)
      # then transpose and reshape x_whitened back to one of the shapes:
      #   [D, E, F, ..., k], or [1, 1, F, ..., k], or
      #   [A, B, C, D, E, F, ..., k]

      # This helper handles the case where rank(x_centered) < rank(sigma)
      def _broadcast_x_not_higher_rank_than_sigma():
        return array_ops.reshape(
            x_centered,
            array_ops.concat(
                # Reshape to ones(deficient x rank) + x_shape + [1]
                0, (array_ops.ones(array_ops.pack([sigma_rank - x_rank - 1]),
                                   dtype=x_rank.dtype),
                    x_shape,
                    [1])))

      # These helpers handle the case where rank(x_centered) >= rank(sigma)
      def _broadcast_x_higher_rank_than_sigma():
        x_shape_left = array_ops.slice(
            x_shape, [0], sigma_rank_vec - 1)
        x_shape_right = array_ops.slice(
            x_shape, sigma_rank_vec - 1, x_rank_vec - 1)
        x_shape_perm = array_ops.concat(
            0, (math_ops.range(sigma_rank - 1, x_rank),
                math_ops.range(0, sigma_rank - 1)))
        return array_ops.reshape(
            # Convert to [D, E, F, ..., k, B, C]
            array_ops.transpose(
                x_centered, perm=x_shape_perm),
            # Reshape to [D, E, F, ..., k, B*C]
            array_ops.concat(
                0, (x_shape_right,
                    array_ops.pack([
                        math_ops.reduce_prod(x_shape_left, 0)]))))

      def _unbroadcast_x_higher_rank_than_sigma():
        x_shape_left = array_ops.slice(
            x_shape, [0], sigma_rank_vec - 1)
        x_shape_right = array_ops.slice(
            x_shape, sigma_rank_vec - 1, x_rank_vec - 1)
        x_shape_perm = array_ops.concat(
            0, (math_ops.range(sigma_rank - 1, x_rank),
                math_ops.range(0, sigma_rank - 1)))
        return array_ops.transpose(
            # [D, E, F, ..., k, B, C] => [B, C, D, E, F, ..., k]
            array_ops.reshape(
                # convert to [D, E, F, ..., k, B, C]
                x_whitened_broadcast,
                array_ops.concat(0, (x_shape_right, x_shape_left))),
            perm=x_shape_perm)

      # Step 1: reshape x_centered
      x_centered_broadcast = control_flow_ops.cond(
          # x_centered == [D, E, F, ..., k] => [D, E, F, ..., k, 1]
          # or         == [F, ..., k] => [1, 1, F, ..., k, 1]
          x_rank <= sigma_rank - 1,
          _broadcast_x_not_higher_rank_than_sigma,
          # x_centered == [B, C, D, E, F, ..., k] => [D, E, F, ..., k, B*C]
          _broadcast_x_higher_rank_than_sigma)

      x_whitened_broadcast = linalg_ops.batch_matrix_triangular_solve(
          self._sigma_chol, x_centered_broadcast)

      # Reshape x_whitened_broadcast back to x_whitened
      x_whitened = control_flow_ops.cond(
          x_rank <= sigma_rank - 1,
          lambda: array_ops.reshape(x_whitened_broadcast, x_shape),
          _unbroadcast_x_higher_rank_than_sigma)

      x_whitened = array_ops.expand_dims(x_whitened, -1)
      # x_whitened now holds column vectors of shape [..., k, 1]; the
      # adjoint batch_matmul below computes the batchwise scalar x^T x.
      x_whitened_norm = math_ops.batch_matmul(
          x_whitened, x_whitened, adj_x=True)
      x_whitened_norm = control_flow_ops.cond(
          x_rank <= sigma_rank - 1,
          lambda: array_ops.squeeze(x_whitened_norm, [-2, -1]),
          lambda: array_ops.squeeze(x_whitened_norm, [-1]))

      log_two_pi = constant_op.constant(math.log(2 * math.pi), dtype=self.dtype)
      k = math_ops.cast(self._k, self.dtype)
      log_pdf_value = (
          -math_ops.log(self._sigma_det) - k * log_two_pi
          - x_whitened_norm) / 2
      final_shaped_value = control_flow_ops.cond(
          x_rank <= sigma_rank - 1,
          lambda: log_pdf_value,
          lambda: array_ops.squeeze(log_pdf_value, [-1]))

      output_static_shape = x_centered.get_shape()[:-1]
      final_shaped_value.set_shape(output_static_shape)
      return final_shaped_value
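With sigma = L L^T, the quadratic form (x - mu)^T sigma^{-1} (x - mu) equals ||L^{-1}(x - mu)||^2, which is exactly the x_whitened_norm assembled above, so the density reduces to -(log|sigma| + k log(2 pi) + ||L^{-1}(x - mu)||^2) / 2. A quick cross-check of that formula against scipy.stats (values assumed, not part of the source):

import numpy as np
from scipy.linalg import cholesky, solve_triangular
from scipy.stats import multivariate_normal

mu = np.array([0.5, -1.0])
sigma = np.array([[2.0, 0.3],
                  [0.3, 1.0]])
x = np.array([1.0, 0.0])

chol = cholesky(sigma, lower=True)
x_whitened = solve_triangular(chol, x - mu, lower=True)
k = len(mu)
log_pdf = -(np.linalg.slogdet(sigma)[1] + k * np.log(2.0 * np.pi)
            + x_whitened @ x_whitened) / 2.0

assert np.isclose(log_pdf, multivariate_normal(mu, sigma).logpdf(x))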
Example #6
  def _batch_sqrt_solve(self, rhs):
    return linalg_ops.batch_matrix_triangular_solve(self._chol, rhs, lower=True)