Example #1
  def __call__(self, x: tf.Tensor, sample_axis: int = "default", **kwargs):
    """
    :param sample_axis: Specify an axis of inputs x as corresponding 1-to-1 with
           sample-specific slices of weight tensor w when computing tensor dot
           products.
    """
    if sample_axis == "default":
      sample_axis = self.sample_axis

    feat = x if self.basis is None else self.basis(x, **kwargs)
    if sample_axis is None:
      batch_axes = None
    else:
      assert len(self.sample_shape), "Received sample_axis but self.weights has" \
                                      " no dedicated axis for samples; this" \
                                      " usually implies that sample_shape=[]."

      ndims_x = len(get_inducing_shape(x) if
                    isinstance(x, InducingVariables) else x.shape)

      batch_axes = [-3, normalize_axis(sample_axis, ndims_x) - ndims_x]

    # Batch-axes notwithstanding, shape(vals) = [N, S] (after move_axis)
    vals = move_axis(batch_tensordot(self.weights,
                                     feat,
                                     axes=[-1, -1],
                                     batch_axes=batch_axes),
                     len(self.sample_shape),  # axis=-2 houses scalar output 1
                     -1)

    return vals if self.mean_function is None else vals + self.mean_function(x)
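The batch_axes pairing above amounts to a per-sample contraction: slice s of the weights only ever meets slice s of the inputs. A minimal einsum sketch of that semantics with illustrative shapes (S samples, N inputs, D basis features; this is not the repo's batch_tensordot itself):

import tensorflow as tf

S, N, D = 4, 8, 16
w = tf.random.normal([S, 1, D])         # sample axis at -3, features last
x = tf.random.normal([S, N, D])         # sample_axis=0 lines up with w's samples
vals = tf.einsum('sod,snd->sno', w, x)  # [S, N, 1]: per-sample dot products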
Example #2
def _linear_fallback(Z: TensorLike,
                     u: TensorLike,
                     f: TensorLike,
                     *,
                     L: TensorLike = None,
                     diag: TensorLike = None,
                     basis: AbstractBasis = None,
                     **kwargs):

    u_shape = tuple(u.shape)
    f_shape = tuple(f.shape)
    assert u_shape[-1] == 1, "Received multiple output features"
    assert u_shape == f_shape[-len(u_shape):], "Incompatible shapes detected"

    # Prepare diagonal term
    if diag is None:  # used by <GPflow.conditionals>
        diag = default_jitter()
    if isinstance(diag, float):
        diag = tf.convert_to_tensor(diag, dtype=f.dtype)
    diag = tf.expand_dims(diag, axis=-1)  # [M, 1] or [1, 1] or [1]

    # Extract "features" of Z
    if basis is None:
        if isinstance(Z, inducing_variables.InducingVariables):
            feat = inducing_to_tensor(Z)  # [M, D]
        else:
            feat = Z
    else:
        feat = basis(Z)  # [M, D] (maybe a different "D" than above)

    # Compute error term and matrix square root $Cov(u, u)^{1/2}$
    err = swap_axes(u - f, -3, -1)  # [1, M, S]
    err -= tf.sqrt(diag) * tf.random.normal(err.shape, dtype=err.dtype)
    M, D = feat.shape[-2:]
    if L is None:
        if D < M:
            feat_iDiag = feat * tf.math.reciprocal(diag)
            S = tf.matmul(feat_iDiag, feat, transpose_a=True)  # [D, D]
            L = tf.linalg.cholesky(S + tf.eye(S.shape[-1], dtype=S.dtype))
        else:
            K = tf.matmul(feat, feat, transpose_b=True)  # [M, M]
            K = tf.linalg.set_diag(K, tf.linalg.diag_part(K) + diag[..., 0])
            L = tf.linalg.cholesky(K)
    else:
        assert L.shape[-1] == min(M, D)  # TODO: improve me

    # Solve for $Cov(u, u)^{-1}(u - f(Z))$
    if D < M:
        feat_iDiag = feat * tf.math.reciprocal(diag)
        weights = tf.linalg.adjoint(
            tf.linalg.cholesky_solve(
                L, tf.matmul(feat_iDiag, err, transpose_a=True)))
    else:
        iK_err = tf.linalg.cholesky_solve(L, err)  # [S, M, 1]
        weights = tf.matmul(iK_err, feat, transpose_a=True)  # [S, 1, D]

    return DenseSampler(basis=basis,
                        weights=move_axis(weights, -2, -3),
                        **kwargs)
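For D < M the function solves a D x D system instead of the M x M one; the two branches agree by the push-through identity Phi^T (Phi Phi^T + diag)^{-1} = (Phi^T diag^{-1} Phi + I)^{-1} Phi^T diag^{-1}. A minimal numerical check of that equivalence (illustrative shapes, float64 for stable factorizations):

import tensorflow as tf

M, D = 100, 8
phi = tf.random.normal([M, D], dtype=tf.float64)
diag = tf.fill([M, 1], tf.constant(1e-2, tf.float64))
err = tf.random.normal([M, 1], dtype=tf.float64)

phi_iDiag = phi / diag  # primal (D x D) path, as in the D < M branch
L = tf.linalg.cholesky(tf.matmul(phi_iDiag, phi, transpose_a=True)
                       + tf.eye(D, dtype=tf.float64))
w_primal = tf.linalg.cholesky_solve(L, tf.matmul(phi_iDiag, err, transpose_a=True))

K = tf.matmul(phi, phi, transpose_b=True)  # dual (M x M) path, as in the else branch
K = tf.linalg.set_diag(K, tf.linalg.diag_part(K) + diag[:, 0])
w_dual = tf.matmul(phi, tf.linalg.solve(K, err), transpose_a=True)

assert tf.reduce_max(tf.abs(w_primal - w_dual)) < 1e-8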
Example #3
def _Kuu_depthwise_conv2d(feat: DepthwiseInducingImages,
                          kern: DepthwiseConv2d,
                          jitter: float = 0.0):

  # Prepare scaled inducing patches; shape(Zp) = [channels_in, M, patch_len]
  Zp = move_axis(kern.kernel.scale(feat.as_patches), -2, 0)
  r2 = square_distance(Zp, None)
  _Kuu = tf.reduce_mean(kern.kernel.K_r2(r2), axis=0)  # [M, M]
  return tf.linalg.set_diag(_Kuu, tf.linalg.diag_part(_Kuu) + jitter)
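A standalone sketch of the computation above, with a hypothetical RBF standing in for kern.kernel.K_r2 and illustrative shapes: per-channel squared distances between inducing patches, then the kernel matrix averaged over input channels.

import tensorflow as tf

C, M, P = 3, 5, 9  # channels_in, num inducing patches, patch_len
Zp = tf.random.normal([C, M, P], dtype=tf.float64)
r2 = tf.reduce_sum(tf.math.squared_difference(Zp[:, :, None], Zp[:, None]),
                   axis=-1)                      # [C, M, M]
Kuu = tf.reduce_mean(tf.exp(-0.5 * r2), axis=0)  # [M, M], mean over channels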
Example #4
  def get_patches(self, X: TensorType, full_spatial: bool = False):
    """
    Returns the patches used by a 2d depthwise convolution.
    """
    patches = super().get_patches(X, full_spatial=full_spatial)
    channels_in = X.shape[-3 if self.data_format == "NCHW" else -1]
    depthwise_patches = tf.reshape(patches,
                                   list(patches.shape[:-1]) + [-1, channels_in])
    return move_axis(depthwise_patches, -2, -1)
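The reshape works because TensorFlow flattens each extracted patch with the channel dimension innermost, so splitting the last axis as [-1, channels_in] recovers per-channel patches. A standalone sketch of that convention (illustrative shapes; not the class's own get_patches):

import tensorflow as tf

imgs = tf.random.normal([2, 28, 28, 3])  # NHWC
patches = tf.image.extract_patches(imgs, sizes=[1, 5, 5, 1],
                                   strides=[1, 1, 1, 1], rates=[1, 1, 1, 1],
                                   padding="VALID")                    # [2, 24, 24, 75]
per_channel = tf.reshape(patches, list(patches.shape[:-1]) + [25, 3])  # [..., h*w, C]
per_channel = tf.transpose(per_channel, [0, 1, 2, 4, 3])               # [..., C, h*w]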
Example #5
  def K(self, X: tf.Tensor, X2: tf.Tensor = None, full_spatial: bool = False):
    P = self.get_patches(X, full_spatial=full_spatial)
    P2 = P if X2 is None else self.get_patches(X2, full_spatial=full_spatial)

    # TODO: Temporary hack, use of self.kernel should be deprecated
    K = move_axis(
          tf.linalg.diag_part(
            move_axis(self.kernel.K(P, P2), P.shape.ndims - 2, -2)), -1, 0)

    if full_spatial:
      return K  # [channels_in, N, H1, W1, N2, H2, W2]

    # At this point, shape(K) = [channels_in, N, num_patches, N2, num_patches]
    if self.weights is None:
      return tf.reduce_mean(K, axis=[0, -3, -1])

    K = batch_tensordot(K, self.weights, axes=[-1, 0], batch_axes=[0, 1])
    K = batch_tensordot(K, self.weights, axes=[-2, 0], batch_axes=[0, 1])
    return tf.reduce_mean(K, axis=0)
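A hedged einsum rendering of the weighted branch, assuming shape(K) = [C, N, P, N2, P2] as noted above and per-channel patch weights of shape [P, C]:

import tensorflow as tf

C, N, P = 3, 4, 9
K5 = tf.random.normal([C, N, P, N, P], dtype=tf.float64)
w = tf.random.normal([P, C], dtype=tf.float64)
Kw = tf.einsum('cnpmq,qc->cnpm', K5, w)  # contract P2, batched over channels
Kw = tf.einsum('cnpm,pc->cnm', Kw, w)    # contract P likewise
Kw = tf.reduce_mean(Kw, axis=0)          # [N, N2], mean over channels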
Example #6
def _Kfu_depthwise_conv2d(feat: DepthwiseInducingImages,
                          kern: DepthwiseConv2d,
                          Xnew: tf.Tensor,
                          full_spatial: bool = False):

    if not isinstance(kern.kernel, kernels.Stationary):
        return _Kfu_depthwise_conv2d_fallback(feat, kern, Xnew, full_spatial)

    # Compute (squared) Mahalanobis distances between patches
    patch_shape = list(kern.patch_shape)
    channels_in = Xnew.shape[-3 if kern.data_format == "NCHW" else -1]
    channels_out = len(feat) * channels_in
    precis = tf.square(tf.math.reciprocal(kern.kernel.lengthscales))

    # Construct lengthscale filters [h, w, channels_in, 1]
    if kern.kernel.ard:  # notice the transpose!
        assert tuple(precis.shape) == (channels_in,
                                       tf.reduce_prod(patch_shape))
        filters = tf.reshape(tf.transpose(precis),
                             patch_shape + [channels_in, 1])
    else:
        filters = tf.fill(patch_shape + [channels_in, 1], precis)

    ZZ = tf.nn.depthwise_conv2d(input=tf.square(feat.as_images),
                                filter=filters,
                                strides=[1, 1, 1, 1],
                                padding="VALID")  # [M, 1, 1, channels_in]

    r2 = tf.reshape(move_axis(ZZ, 0, -1), [1, 1, 1, channels_out])

    X = tf.reshape(Xnew, [-1] + list(Xnew.shape)[-3:])  # stack as 4d images
    r2 += tf.repeat(kern.convolve(tf.square(X), filters), len(feat), axis=-1)

    filters *= feat.as_filters  # [h, w, channels_in, M]
    r2 -= 2 * kern.convolve(X, filters)  # [N, height_out, width_out, chan_out]

    Kxz = kern.kernel.K_r2(r2)
    if full_spatial:
        Kxz = tf.reduce_mean(
            tf.reshape(Kxz, list(Kxz.shape[:-1]) + [channels_in, -1]),
            axis=-2)  # average over input channels
    else:
        Kxz = tf.reshape(Kxz,
                         list(Kxz.shape[:-3]) + [-1, len(feat)])  # [N, P, M]
        if kern.weights is None:
            Kxz = tf.reduce_mean(Kxz, axis=-2)
        else:
            div = tf.cast(1 / channels_in, Kxz.dtype)
            Kxz = div * tf.tensordot(Kxz, tf.reshape(kern.weights, [-1]),
                                     [-2, -1])

    # Undo stacking of Xnew as 4d images X
    return tf.reshape(Kxz, list(Xnew.shape[:-3]) + list(Kxz.shape[1:]))
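The function evaluates ||x - z||^2 entirely with convolutions: the ||x||^2 term via a convolution of the squared image with the lengthscale filters, the ||z||^2 term via a depthwise convolution over squared inducing images, and the cross term via a convolution with the inducing patches as filters. A single-channel, unit-lengthscale sketch of the identity (illustrative shapes):

import tensorflow as tf

X = tf.random.normal([1, 8, 8, 1], dtype=tf.float64)
z = tf.random.normal([3, 3, 1, 1], dtype=tf.float64)  # one inducing patch as a filter
xx = tf.nn.conv2d(tf.square(X), tf.ones_like(z), strides=1, padding="VALID")
xz = tf.nn.conv2d(X, z, strides=1, padding="VALID")
r2 = xx - 2.0 * xz + tf.reduce_sum(tf.square(z))  # [1, 6, 6, 1] squared distances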
Example #7
def _exact_independent(kern: kernels.MultioutputKernel,
                       Z: TensorLike,
                       u: TensorLike,
                       f: TensorLike,
                       *,
                       L: TensorLike = None,
                       diag: TensorLike = None,
                       basis: AbstractBasis = None,
                       multioutput_axis: int = 0,
                       **kwargs):
    """
  Return (independent) pathwise updates for each of the latent prior processes
  $f$ subject to the condition $p(f | u) = N(f | u, diag)$ on $f = f(Z)$.
  """
    u_shape = tuple(u.shape)
    f_shape = tuple(f.shape)
    assert u_shape[-1] == kern.num_latent_gps, "Num. outputs != num. latent GPs"
    assert u_shape == f_shape[-len(u_shape):], "Incompatible shapes detected"
    if basis is None:  # finite-dimensional basis used to express the update
        basis = kernel_basis(kern, centers=Z)

    # Prepare diagonal term
    if diag is None:  # used by <GPflow.conditionals>
        diag = default_jitter()
    if isinstance(diag, float):
        diag = tf.convert_to_tensor(diag, dtype=f.dtype)
    diag = tf.expand_dims(diag, axis=-1)  # ([L] or []) + ([M] or []) + [1]

    # Compute error term and matrix square root $Cov(u, u)^{1/2}$
    err = swap_axes(u - f, -3, -1)  # [L, M, S]
    err -= tf.sqrt(diag) * tf.random.normal(err.shape, dtype=err.dtype)
    if L is None:
        if isinstance(Z, inducing_variables.InducingVariables):
            K = covariances.Kuu(Z, kern, jitter=0.0)
        else:
            K = kern(Z, full_cov=True, full_output_cov=False)
        K = tf.linalg.set_diag(K, tf.linalg.diag_part(K) + diag[..., 0])
        L = tf.linalg.cholesky(K)

    # Solve for $Cov(u, u)^{-1}(u - f(Z))$
    weights = move_axis(tf.linalg.cholesky_solve(L, err), -1, -3)  # [S, L, M]
    return MultioutputDenseSampler(basis=basis,
                                   weights=weights,
                                   multioutput_axis=multioutput_axis,
                                   **kwargs)
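A scalar sketch of the pathwise (Matheron-style) update these weights realize, f(.) |-> f(.) + k(., Z) K(Z, Z)^{-1} (u - f(Z) - eps) with eps ~ N(0, diag); the kernel here is a hypothetical RBF and the shapes are illustrative:

import tensorflow as tf

M = 6
Z = tf.random.normal([M, 1], dtype=tf.float64)
K7 = tf.exp(-0.5 * tf.square(Z - tf.transpose(Z)))  # RBF Gram matrix at Z
L = tf.linalg.cholesky(K7 + 1e-6 * tf.eye(M, dtype=tf.float64))
u = tf.random.normal([M, 1], dtype=tf.float64)      # values conditioned on
f = tf.matmul(L, tf.random.normal([M, 1], dtype=tf.float64))  # prior draw at Z
w = tf.linalg.cholesky_solve(L, u - f)              # the 'weights' computed above
update_at_Z = tf.matmul(K7, w)                      # ~ (u - f) up to the jitter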
Example #8
    def __call__(self, x: TensorType, multioutput_axis: int = None, **kwargs):
        self._maybe_initialize(x, **kwargs)
        if isinstance(x, InducingVariables):  # TODO: Allow this behavior?
            x = inducing_to_tensor(x)

        # Compute (batch) tensor dot product
        batch_axes = None if multioutput_axis is None else [0, multioutput_axis]
        proj = move_axis(
            batch_tensordot(self.weights,
                            x,
                            axes=[-1, -1],
                            batch_axes=batch_axes), 1, -1)

        ndims = proj.shape.ndims
        feat = tf.cos(proj + expand_to(self.biases, axis=1, ndims=ndims))
        return expand_to(self.output_scale, axis=1,
                         ndims=ndims) * feat  # [L, N, B]
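The features above follow the standard random Fourier feature form phi(x) = scale * cos(Wx + b). A plain sketch of that construction for an assumed RBF kernel, independent of the repo's initializer helpers:

import math
import tensorflow as tf

N, D, B = 10, 3, 256
x = tf.random.normal([N, D], dtype=tf.float64)
W = tf.random.normal([B, D], dtype=tf.float64)  # RBF spectral frequencies
b = tf.random.uniform([B], 0.0, 2.0 * math.pi, dtype=tf.float64)
feat = math.sqrt(2.0 / B) * tf.cos(tf.matmul(x, W, transpose_b=True) + b)  # [N, B]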
Example #9
def _Kfu_depthwise_conv2d_fallback(feat: DepthwiseInducingImages,
                                   kern: DepthwiseConv2d,
                                   Xnew: tf.Tensor,
                                   full_spatial: bool = False):

    Zp = feat.as_patches  # [M, channels_in, patch_len]
    Xp = kern.get_patches(Xnew, full_spatial=full_spatial)
    r2 = tf.reduce_sum(
        tf.math.squared_difference(  # compute square distances
            tf.expand_dims(kern.kernel.scale(Xp), -Zp.shape.ndims),
            kern.kernel.scale(Zp)),
        axis=-1)

    Kxz = kern.kernel.K_r2(r2)
    if full_spatial:  # convert to 4D image format as [N, H, W, channels_in * M]
        return tf.reshape(move_axis(Kxz, -1, -2), list(Kxz.shape[:-2]) + [-1])

    if kern.weights is None:  # reduce over channels and patches
        return tf.reduce_mean(Kxz, axis=[-3, -1])

    return tf.tensordot(kern.weights, Kxz, axes=[(0, 1), (-3, -1)])
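Unlike the convolutional path, the fallback forms squared distances by broadcasting patches directly. A single-channel sketch of that expansion (illustrative shapes):

import tensorflow as tf

N, P, M, D = 4, 9, 5, 7  # images, patches per image, inducing patches, patch_len
Xp = tf.random.normal([N, P, D], dtype=tf.float64)
Zp = tf.random.normal([M, D], dtype=tf.float64)
r2 = tf.reduce_sum(tf.math.squared_difference(Xp[:, :, None], Zp),
                   axis=-1)  # [N, P, M]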
Example #10
    def initialize(self, x, dtype: Any = None):
        if isinstance(x, inducing_variables.InducingImages):
            x = x.as_images

        if dtype is None:
            dtype = x.dtype

        channels_out = self.kernel.channels_in * self.num_bases
        self._biases = bias_initializer(self.kernel.kernel,
                                        channels_out,
                                        dtype=dtype)

        patch_size = self.kernel.patch_shape[0] * self.kernel.patch_shape[1]
        batch_shape = [self.kernel.channels_in, self.num_bases]
        weights = weight_initializer(self.kernel.kernel,
                                     patch_size,
                                     batch_shape=batch_shape,
                                     dtype=dtype)

        self._filters = tf.reshape(move_axis(weights, -1, 0),
                                   self.kernel.patch_shape + batch_shape)
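A minimal sketch of the filter construction at the end of initialize: per-channel weight vectors of length h*w are rearranged into a depthwise filter bank (illustrative shapes, assuming weight_initializer returns batch_shape + [patch_size]):

import tensorflow as tf

C, B, h, w = 3, 16, 5, 5  # channels_in, num_bases, patch height/width
weights = tf.random.normal([C, B, h * w], dtype=tf.float64)
filters = tf.reshape(tf.transpose(weights, [2, 0, 1]), [h, w, C, B])  # [h, w, C, B]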
Example #11
  def __call__(self, x: tf.Tensor, sample_axis: int = "default", **kwargs):
    """
    :param sample_axis: Specify an axis of inputs x as corresponding 1-to-1 with
           sample-specific slices of weight tensor w when computing tensor dot
           products.

    TODO: Improve hard-coding of multioutput-/sample-axis of weights.
    """
    if sample_axis == "default":
      sample_axis = self.sample_axis

    if self.multioutput_axis is None:
      batch_w = []  # batch axes for w
      batch_x = []  # batch axes for x
    else:
      batch_w = [-2]
      batch_x = [self.multioutput_axis]

    if sample_axis is not None:
      assert len(self.sample_shape), "Received sample_axis but self.weights has" \
                                      " no dedicated axis for samples; this" \
                                      " usually implies that sample_shape=[]."
      batch_w.append(-3)

      # TODO: If basis(x) grows the rank of x, it should only do so from the
      #       left, such that the negative i-th axis (i > 1) remains the same.
      ndims_x = len(get_inducing_shape(x) if
                    isinstance(x, InducingVariables) else x.shape)
      batch_x.append(normalize_axis(sample_axis, ndims_x) - ndims_x)

    feat = x if self.basis is None else self.basis(x, **kwargs)
    vals = move_axis(batch_tensordot(self.weights,  # output features go last
                                     feat,
                                     axes=[-1, -1],
                                     batch_axes=[batch_w, batch_x]),
                     len(self.sample_shape),  # axis=-2 houses multioutputs L
                     -1)

    return vals if self.mean_function is None else vals + self.mean_function(x)
Example #12
  @property
  def as_filters(self) -> tf.Tensor:
    # images [M, h, w, channels_in] -> filters [h, w, channels_in, M]
    return move_axis(self.as_images, 0, -1)