def build_linear_operator_zeros(shape, dtype=None, seed=None, name=None):
    """Build an instance of `LinearOperatorZeros`.

  Args:
    shape: Shape of the `LinearOperator`, equal to `[b0, ..., bn, h, w]`, where
      `b0...bn` are batch dimensions and `h` and `w` are the height and width of
      the matrix represented by the `LinearOperator`.
    dtype: `tf.dtype` of the `LinearOperator`.
    seed: Python integer to seed the random number generator.
    name: str, name for `tf.name_scope`.

  Returns:
    operator: Instance of `tf.linalg.LinearOperatorZeros`.
  """
    del seed  # Unused.
    with tf.name_scope(name or 'build_linear_operator_zeros'):
        batch_shape, rows, cols = ps.split(shape,
                                           num_or_size_splits=[-1, 1, 1])
        num_rows, num_cols = rows[0], cols[0]
        is_square = num_rows == num_cols
        return tf.linalg.LinearOperatorZeros(num_rows,
                                             num_cols,
                                             batch_shape=batch_shape,
                                             is_square=is_square,
                                             is_self_adjoint=is_square,
                                             dtype=dtype)
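A minimal usage sketch (an assumption for illustration, not part of the builder itself), assuming `tf` is TensorFlow and the function above is in scope:

import tensorflow as tf

operator = build_linear_operator_zeros([3, 4, 4], dtype=tf.float32)  # one batch dim of size 3, 4x4 matrices
print(operator.shape)            # (3, 4, 4)
print(operator.is_self_adjoint)  # True, because the matrix is square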
Example #2
def expand_dims(x, axis, name=None):
    """Like `tf.expand_dims` but accepts a vector of axes to expand."""
    with tf.name_scope(name or 'expand_dims'):
        x = tf.convert_to_tensor(x, name='x')
        axis = tf.convert_to_tensor(axis, dtype_hint=tf.int32, name='axis')
        nx = prefer_static.rank(x)
        na = prefer_static.size(axis)
        is_neg_axis = axis < 0
        k = prefer_static.reduce_sum(
            prefer_static.cast(is_neg_axis, axis.dtype))
        axis = prefer_static.where(is_neg_axis, axis + nx, axis)
        axis = prefer_static.sort(axis)
        axis_neg, axis_pos = prefer_static.split(axis, [k, -1])
        idx = prefer_static.argsort(prefer_static.concat([
            axis_pos,
            prefer_static.range(nx),
            axis_neg,
        ],
                                                         axis=0),
                                    stable=True)
        shape = prefer_static.pad(prefer_static.shape(x),
                                  paddings=[[na - k, k]],
                                  constant_values=1)
        shape = prefer_static.gather(shape, idx)
        return tf.reshape(x, shape)
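A quick eager-mode illustration, assuming `expand_dims` above is in scope; the expected shapes follow from tracing the gather/reshape logic:

import tensorflow as tf

x = tf.zeros([2, 3])
print(expand_dims(x, axis=[-1]).shape)    # (2, 3, 1), same as tf.expand_dims(x, -1)
print(expand_dims(x, axis=[0, 1]).shape)  # (1, 2, 1, 3): a new size-1 axis before each of dims 0 and 1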
Example #3
  def _broadcast(self, x, sample_shape):
    """Broadcasts x to target batch_shape.

    Specifically, x is broadcasted to have shape
    `sample_shape + target_batch + ndims`
    where target_batch == self.batch_shape except along the concatenation axis.

    Args:
      x: tf.Tensor with shape sample_shape + batch_shape + ndims.
      sample_shape: sample_shape of the input tensor.

    Returns:
      Broadcasted tensor.
    """
    x_shape = ps.shape(x)

    batch_shape = self._calculate_batch_shape()
    sample_batch_ndims = ps.shape(sample_shape)[0] + ps.shape(batch_shape)[0]
    rest_tensor = ps.split(x_shape,
                           [sample_batch_ndims, ps.rank(x)-sample_batch_ndims])
    target_shape = ps.concat([sample_shape,
                              batch_shape[:self._axis],
                              [x_shape[self._axis + ps.shape(sample_shape)[0]]],
                              batch_shape[self._axis+1:],
                              rest_tensor[1]], axis=0)
    return tf.broadcast_to(x, target_shape)
def _linear_operator_zeros(shape, dtype=None, name=None):
    """Build an instance of `LinearOperatorZeros`.

  Args:
    shape: Shape of the `LinearOperator`, equal to `[b0, ..., bn, h, w]`, where
      `b0...bn` are batch dimensions and `h` and `w` are the height and width of
      the matrix represented by the `LinearOperator`.
    dtype: `tf.dtype` of the `LinearOperator`.
    name: str, name for `tf.name_scope`.
  Yields:
    *parameters: sequence of `trainable_state_util.Parameter` namedtuples.
      These are intended to be consumed by
      `trainable_state_util.as_stateful_builder` and
      `trainable_state_util.as_stateless_builder` to define stateful and
      stateless variants respectively.
  """
    with tf.name_scope(name or 'linear_operator_zeros'):
        batch_shape, rows, cols = ps.split(shape,
                                           num_or_size_splits=[-1, 1, 1])
        num_rows, num_cols = rows[0], cols[0]
        is_square = num_rows == num_cols
        return tf.linalg.LinearOperatorZeros(num_rows,
                                             num_cols,
                                             batch_shape=batch_shape,
                                             is_square=is_square,
                                             is_self_adjoint=is_square,
                                             dtype=dtype)
        # Tell Python that this fn is really a (trivial) generator.
        yield  # pylint: disable=unreachable
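The unreachable `yield` is what makes Python treat the function as a (trivial) generator, as the comment says; a standalone illustration of that language rule:

def _trivial_generator():
    return 'value'
    yield  # never runs, but its mere presence makes this a generator function

gen = _trivial_generator()   # calling it builds a generator; no body code runs yet
try:
    next(gen)
except StopIteration as stop:
    print(stop.value)        # 'value' -- the return value travels in StopIteration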
Example #5
def _im2row_index(input_shape,
                  block_shape,
                  slice_step=(1, 1),
                  data_format='NHWC',
                  padding='VALID',
                  dtype=tf.int64,
                  name=None):
    """Computes indexes into a flattened image for building `im2col`."""
    with tf.name_scope(name or 'im2row_index'):
        # 1) Process input arguments.
        batch_shape, s3, s2, s1 = prefer_static.split(
            prefer_static.cast(input_shape, tf.int32),
            num_or_size_splits=[-1, 1, 1, 1])
        fh, fw = _split_pair(block_shape)
        sh, sw = _split_pair(slice_step)
        data_format = _validate_data_format(data_format)
        padding = _validate_padding(padding)

        # 2) Assemble all block start positions as indexes into the flattened image.
        if data_format == 'NHWC':
            h, w, c = s3[0], s2[0], s1[0]
            # start_idx.shape = [fh, fw, c]
            start_idx = _cartesian_add([
                prefer_static.range(c * w * fh, delta=c * w, dtype=dtype),
                prefer_static.range(c * fw, delta=c, dtype=dtype),
                prefer_static.range(c, delta=1, dtype=dtype),
            ])
        elif data_format == 'NCHW':
            c, h, w = s3[0], s2[0], s1[0]
            # start_idx.shape = [c, fh, fw]
            start_idx = _cartesian_add([
                prefer_static.range(w * h * c, delta=w * h, dtype=dtype),
                prefer_static.range(w * fh, delta=w, dtype=dtype),
                prefer_static.range(fw, delta=1, dtype=dtype),
            ])
        else:
            assert False  # Can't be here.

        # 3) Assemble all block offsets (into flattened image).
        if padding == 'VALID':
            eh = h - fh + 1  # extent height
            ew = w - fw + 1  # extent width
            # offset_idx.shape = [eh // sh, ew // sw]
            offset_idx = _cartesian_add([
                prefer_static.range(w * eh, delta=w * sh, dtype=dtype),
                prefer_static.range(ew, delta=sw, dtype=dtype),
            ])
            if data_format == 'NHWC':
                offset_idx *= c
            oh = eh // sh  # out height
            ow = ew // sw  # out width
        else:
            assert False  # Can't be here.

        # 4) Combine block start/offset pairs.
        # shape = [(eh // sh) * (ew // sw), fh * fw * c]
        idx = _cartesian_add([offset_idx, start_idx])
        new_shape = [oh, ow, fh * fw * c]
        new_shape = prefer_static.concat([batch_shape, new_shape], axis=0)
        return idx, new_shape
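For intuition, the im2row idea can be reproduced directly in NumPy (a conceptual sketch only; it does not use `_cartesian_add` or the helper above): gather from the flattened image so that each row is one flattened block.

import numpy as np

h, w, c, fh, fw = 4, 4, 2, 2, 2              # small NHWC image, 2x2 blocks, VALID padding
img = np.arange(h * w * c).reshape(h, w, c)
flat = img.reshape(-1)

rows = []
for i in range(h - fh + 1):
    for j in range(w - fw + 1):
        # Element (i, j, k) of the image sits at index i*w*c + j*c + k in `flat`.
        idx = [(i + di) * w * c + (j + dj) * c + k
               for di in range(fh) for dj in range(fw) for k in range(c)]
        rows.append(flat[idx])
im2row = np.stack(rows)                       # shape [(h-fh+1)*(w-fw+1), fh*fw*c]

assert np.array_equal(im2row[0], img[0:2, 0:2, :].reshape(-1))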
Example #6
 def expand_dims_(x):
     """Implementation of `expand_dims`."""
     with tf.name_scope(name or 'expand_dims'):
         x = tf.convert_to_tensor(x, name='x')
         new_axis = tf.convert_to_tensor(axis,
                                         dtype_hint=tf.int32,
                                         name='axis')
         nx = prefer_static.rank(x)
         na = prefer_static.size(new_axis)
         is_neg_axis = new_axis < 0
         k = prefer_static.reduce_sum(
             prefer_static.cast(is_neg_axis, new_axis.dtype))
         new_axis = prefer_static.where(is_neg_axis, new_axis + nx,
                                        new_axis)
         new_axis = prefer_static.sort(new_axis)
         axis_neg, axis_pos = prefer_static.split(new_axis, [k, -1])
         idx = prefer_static.argsort(prefer_static.concat([
             axis_pos,
             prefer_static.range(nx),
             axis_neg,
         ],
                                                          axis=0),
                                     stable=True)
         shape = prefer_static.pad(prefer_static.shape(x),
                                   paddings=[[na - k, k]],
                                   constant_values=1)
         shape = prefer_static.gather(shape, idx)
         return tf.reshape(x, shape)
Example #7
def batchify_op(op, op_min_input_ndims, x, *other_op_args):
    """Reshape `op` input `x` to be a vec of `op_min_input_ndims`-rank tensors."""
    if x.shape.rank == op_min_input_ndims + 1:
        # Input is already a vector of `op_min_input_ndims`-rank tensors.
        return op(x, *other_op_args)
    batch_shape, op_shape = ps.split(
        ps.shape(x), num_or_size_splits=[-1, op_min_input_ndims])
    flat_shape = ps.pad(op_shape, paddings=[[1, 0]], constant_values=-1)
    y = tf.reshape(x, flat_shape)
    y = op(y, *other_op_args)
    unflat_shape = ps.concat([
        batch_shape,
        ps.shape(y)[1:],
    ], axis=0)
    y = tf.reshape(y, unflat_shape)
    return y
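A hedged usage sketch, assuming `tf`, `ps` and `batchify_op` above are in scope; the matrix-squaring op is an arbitrary example of an op written for a single leading batch dimension:

import tensorflow as tf

x = tf.random.normal([2, 3, 4, 4])   # two batch dims over 4x4 matrices
square = lambda m: tf.matmul(m, m)   # expects at most one leading batch dim
y = batchify_op(square, 2, x)        # op_min_input_ndims=2
print(y.shape)                       # (2, 3, 4, 4)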
def _trainable_linear_operator_tril(shape,
                                    scale_initializer=1e-2,
                                    diag_bijector=None,
                                    dtype=None,
                                    name=None):
    """Build a trainable `LinearOperatorLowerTriangular` instance.

  Args:
    shape: Shape of the `LinearOperator`, equal to `[b0, ..., bn, d]`, where
      `b0...bn` are batch dimensions and `d` is the length of the diagonal.
    scale_initializer: Variables are initialized with samples from
      `Normal(0, scale_initializer)`.
    diag_bijector: Bijector to apply to the diagonal of the operator.
    dtype: `tf.dtype` of the `LinearOperator`.
    name: str, name for `tf.name_scope`.
  Yields:
    *parameters: sequence of `trainable_state_util.Parameter` namedtuples.
      These are intended to be consumed by
      `trainable_state_util.as_stateful_builder` and
      `trainable_state_util.as_stateless_builder` to define stateful and
      stateless variants respectively.
  """
    with tf.name_scope(name or 'trainable_linear_operator_tril'):
        if dtype is None:
            dtype = dtype_util.common_dtype([scale_initializer],
                                            dtype_hint=tf.float32)

        scale_initializer = tf.convert_to_tensor(scale_initializer,
                                                 dtype=dtype)
        diag_bijector = diag_bijector or _DefaultScaleDiagonal()
        batch_shape, dim = ps.split(shape, num_or_size_splits=[-1, 1])

        scale_tril_bijector = fill_scale_tril.FillScaleTriL(
            diag_bijector, diag_shift=tf.zeros([], dtype=dtype))
        scale_tril = yield trainable_state_util.Parameter(
            init_fn=lambda seed: scale_tril_bijector(  # pylint: disable=g-long-lambda
                samplers.normal(mean=0.,
                                stddev=scale_initializer,
                                shape=ps.concat(
                                    [batch_shape, dim * (dim + 1) // 2],
                                    axis=0),
                                seed=seed,
                                dtype=dtype)),
            name='scale_tril',
            constraining_bijector=scale_tril_bijector)
        return tf.linalg.LinearOperatorLowerTriangular(tril=scale_tril,
                                                       is_non_singular=True)
Example #9
 def __init__(self, input_shape, block_size=2, validate_args=False, name=None):
   parameters = dict(locals())
   self._block_size = block_size
   _, h, w, c = prefer_static.split(input_shape, [-1, 1, 1, 1])
   h, w, c = h[0], w[0], c[0]
   n = self._block_size
   b = [
       reshape.Reshape(
           event_shape_out=[h * n, w * n, c // n**2],
           event_shape_in=[h, n, w, n, c // n**2]),
       transpose.Transpose(perm=[0, 3, 1, 4, 2]),
       reshape.Reshape(
           event_shape_in=[h, w, c],
           event_shape_out=[h, w, c // n**2, n, n]),
   ]
   super(Expand, self).__init__(b, name=name or 'Expand',
                                parameters=parameters)
def build_trainable_linear_operator_tril(shape,
                                         scale_initializer=1e-2,
                                         diag_bijector=None,
                                         dtype=None,
                                         seed=None,
                                         name=None):
    """Build a trainable `LinearOperatorLowerTriangular` instance.

  Args:
    shape: Shape of the `LinearOperator`, equal to `[b0, ..., bn, d]`, where
      `b0...bn` are batch dimensions and `d` is the length of the diagonal.
    scale_initializer: Variables are initialized with samples from
      `Normal(0, scale_initializer)`.
    diag_bijector: Bijector to apply to the diagonal of the operator.
    dtype: `tf.dtype` of the `LinearOperator`.
    seed: Python integer to seed the random number generator.
    name: str, name for `tf.name_scope`.

  Returns:
    operator: Trainable instance of `tf.linalg.LinearOperatorLowerTriangular`.
  """
    with tf.name_scope(name or 'build_trainable_linear_operator_tril'):
        if dtype is None:
            dtype = dtype_util.common_dtype([scale_initializer],
                                            dtype_hint=tf.float32)

        scale_initializer = tf.convert_to_tensor(scale_initializer,
                                                 dtype=dtype)
        diag_bijector = diag_bijector or _DefaultScaleDiagonal()
        batch_shape, dim = ps.split(shape, num_or_size_splits=[-1, 1])

        scale_tril_bijector = fill_scale_tril.FillScaleTriL(
            diag_bijector, diag_shift=tf.zeros([], dtype=dtype))
        flat_initial_scale = samplers.normal(
            mean=0.,
            stddev=scale_initializer,
            shape=ps.concat([batch_shape, dim * (dim + 1) // 2], axis=0),
            seed=seed,
            dtype=dtype)
        return tf.linalg.LinearOperatorLowerTriangular(
            tril=tfp_util.TransformedVariable(
                scale_tril_bijector.forward(flat_initial_scale),
                bijector=scale_tril_bijector,
                name='tril'),
            is_non_singular=True)
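A minimal usage sketch, assuming the TFP-internal modules used above (`ps`, `samplers`, `fill_scale_tril`, `tfp_util`) are importable; the shape is illustrative:

op = build_trainable_linear_operator_tril([2, 4], seed=42)
print(op.shape)  # (2, 4, 4): a batch of two trainable 4x4 lower-triangular operators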
Example #11
def _compute_fans_from_shape(shape, batch_ndims=0):
  """Extracts `fan_in, fan_out` from specified shape `Tensor`."""
  # Ensure shape is a vector of length >=2.
  num_pad = prefer_static.maximum(0, 2 - prefer_static.size(shape))
  shape = prefer_static.pad(
      shape, paddings=[[0, num_pad]], constant_values=1)
  (
      batch_shape,  # pylint: disable=unused-variable
      extra_shape,
      fan_in,
      fan_out,
  ) = prefer_static.split(shape, [batch_ndims, -1, 1, 1])
  # The following logic is primarily intended for convolutional layers which
  # have spatial semantics in addition to input/output channels.
  receptive_field_size = prefer_static.reduce_prod(extra_shape)
  fan_in = fan_in[0] * receptive_field_size
  fan_out = fan_out[0] * receptive_field_size
  return fan_in, fan_out
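A tiny worked example in plain NumPy (the `[3, 3, 8, 16]` kernel shape is an assumed illustration of a `spatial + [in, out]` layout): with `batch_ndims=0` the receptive field is 3*3 = 9, so `fan_in = 9 * 8 = 72` and `fan_out = 9 * 16 = 144`.

import numpy as np

shape = np.array([3, 3, 8, 16])        # [spatial..., in_channels, out_channels]
receptive_field = np.prod(shape[:-2])  # 9
fan_in = shape[-2] * receptive_field   # 72
fan_out = shape[-1] * receptive_field  # 144
print(fan_in, fan_out)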
Example #12
def convolution_batch(x,
                      kernel,
                      rank,
                      strides,
                      padding,
                      data_format=None,
                      dilations=None,
                      name=None):
    """Like `tf.nn.conv2d` except applies batch of kernels to batch of `x`."""
    if rank != 2:
        raise NotImplementedError(
            'Argument `rank` currently only supports `2`; '
            'saw "{}".'.format(rank))
    if data_format is not None and data_format.upper() != 'NHWBC':
        raise ValueError(
            'Argument `data_format` currently only supports "NHWBC"; '
            'saw "{}".'.format(data_format))
    with tf.name_scope(name or 'conv2d_nhwbc'):
        # Prepare arguments.
        [
            rank,
            _,  # strides
            padding,
            dilations,
            data_format,
        ] = prepare_conv_args(rank, strides, padding, dilations)
        strides = prepare_strides(strides, rank + 2, arg_name='strides')

        dtype = dtype_util.common_dtype([x, kernel], dtype_hint=tf.float32)
        x = tf.convert_to_tensor(x, dtype=dtype, name='x')
        kernel = tf.convert_to_tensor(kernel, dtype=dtype, name='kernel')

        # Step 1: Transpose and double flatten kernel.
        # kernel.shape = B + F + [c, c']. Eg: [b, fh, fw, c, c']
        kernel_shape = prefer_static.shape(kernel)
        kernel_batch_shape, kernel_event_shape = prefer_static.split(
            kernel_shape, num_or_size_splits=[-1, rank + 2])
        kernel_batch_size = prefer_static.reduce_prod(kernel_batch_shape)
        kernel_ndims = prefer_static.rank(kernel)
        kernel_batch_ndims = kernel_ndims - rank - 2
        perm = prefer_static.concat([
            prefer_static.range(kernel_batch_ndims, kernel_batch_ndims + rank),
            prefer_static.range(0, kernel_batch_ndims),
            prefer_static.range(kernel_batch_ndims + rank, kernel_ndims),
        ],
                                    axis=0)  # Eg, [1, 2, 0, 3, 4]
        kernel = tf.transpose(kernel, perm=perm)  # F + B + [c, c']
        kernel = tf.reshape(kernel,
                            shape=prefer_static.concat([
                                kernel_event_shape[:rank],
                                [
                                    kernel_batch_size * kernel_event_shape[-2],
                                    kernel_event_shape[-1]
                                ],
                            ],
                                                       axis=0))  # F + [bc, c']

        # Step 2: Double flatten x.
        # x.shape = N + D + B + [c]
        x_shape = prefer_static.shape(x)
        [
            x_sample_shape,
            x_rank_shape,
            x_batch_shape,
            x_channel_shape,
        ] = prefer_static.split(
            x_shape, num_or_size_splits=[-1, rank, kernel_batch_ndims, 1])
        x = tf.reshape(
            x,  # N + D + B + [c]
            shape=prefer_static.concat([
                [prefer_static.reduce_prod(x_sample_shape)],
                x_rank_shape,
                [
                    prefer_static.reduce_prod(x_batch_shape) *
                    prefer_static.reduce_prod(x_channel_shape)
                ],
            ],
                                       axis=0))  # [n] + D + [bc]

        # Step 3: Apply convolution.
        y = tf.nn.depthwise_conv2d(x,
                                   kernel,
                                   strides=strides,
                                   padding=padding,
                                   data_format='NHWC',
                                   dilations=dilations)
        #  SAME: y.shape = [n, h,      w,      bcc']
        # VALID: y.shape = [n, h-fh+1, w-fw+1, bcc']

        # Step 4: Reshape/reduce for output.
        y_shape = prefer_static.shape(y)
        y = tf.reshape(y,
                       shape=prefer_static.concat(
                           [
                               x_sample_shape,
                               y_shape[1:-1],
                               kernel_batch_shape,
                               kernel_event_shape[-2:],
                           ],
                           axis=0))  # N + D' + B + [c, c']
        y = tf.reduce_sum(y, axis=-2)  # N + D' + B + [c']

        return y
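A compressed sketch of the same depthwise-convolution trick for the simplest case (one sample dimension, one kernel batch dimension); the sizes are arbitrary and this is not a drop-in replacement for `convolution_batch`:

import tensorflow as tf

n, h, w, b, c, c_out, fh, fw = 2, 5, 5, 3, 4, 6, 3, 3
x = tf.random.normal([n, h, w, b, c])             # NHWBC input
kernel = tf.random.normal([b, fh, fw, c, c_out])  # one kernel per batch member

k = tf.transpose(kernel, [1, 2, 0, 3, 4])         # F + B + [c, c']
k = tf.reshape(k, [fh, fw, b * c, c_out])         # F + [bc, c']
x_flat = tf.reshape(x, [n, h, w, b * c])          # [n] + D + [bc]
y = tf.nn.depthwise_conv2d(x_flat, k, strides=[1, 1, 1, 1], padding='SAME')
y = tf.reshape(y, [n, h, w, b, c, c_out])         # N + D' + B + [c, c']
y = tf.reduce_sum(y, axis=-2)                     # N + D' + B + [c']
print(y.shape)                                    # (2, 5, 5, 3, 6)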
Example #13
    def op(x, kernel):
        input_dtype = dtype_util.common_dtype([x, kernel],
                                              dtype_hint=tf.float32)
        x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
        kernel = tf.convert_to_tensor(kernel, dtype=input_dtype, name='kernel')

        batch_shape, event_shape = ps.split(ps.shape(x),
                                            num_or_size_splits=[-1, 3])
        xh, xw, c_in = ps.unstack(event_shape, num=3)
        fh, fw = filter_shape

        assertions = _maybe_validate_input_shapes(ps.shape(kernel),
                                                  channels_in=c_in,
                                                  filter_height=fh,
                                                  filter_width=fw,
                                                  validate_args=validate_args)

        with tf.control_dependencies(assertions):
            if tf.get_static_value(ps.rank(kernel)) == 2:
                flat_x = tf.reshape(x,
                                    shape=ps.concat([[-1], event_shape],
                                                    axis=0))
                flat_y = tf.nn.conv2d(flat_x,
                                      filters=tf.reshape(
                                          kernel, shape=[fh, fw, c_in, -1]),
                                      strides=strides,
                                      padding=padding,
                                      data_format='NHWC',
                                      dilations=dilations)
                output_shape = ps.shape(flat_y)[-3:]
                return tf.reshape(flat_y,
                                  shape=ps.concat([batch_shape, output_shape],
                                                  axis=0))

            pad_values = [
                _get_conv_padding(xdim,
                                  filter_dim=k,
                                  stride=s,
                                  dilation=d,
                                  padding=padding)
                for (xdim, k, s,
                     d) in zip((xh, xw), filter_shape, strides, dilations)
            ]

            idx, shape = im2row_index(
                (xh + sum(pad_values[0]), xw + sum(pad_values[1]), c_in),
                block_shape=filter_shape,
                slice_step=strides,
                dilations=dilations,
                dtype=dtype)

            if padding == 'SAME':
                n = ps.maximum(0, ps.rank(x) - 3)
                paddings = ps.pad(pad_values,
                                  paddings=[[n, 1], [0, 0]],
                                  constant_values=0)
                x = tf.pad(x, paddings=paddings, constant_values=0)

            flat_shape = ps.pad(batch_shape,
                                paddings=[[0, 1]],
                                constant_values=-1)
            flat_x = tf.gather(tf.reshape(x, shape=flat_shape),
                               indices=idx,
                               axis=-1)
            im_x = tf.reshape(flat_x,
                              shape=ps.concat([batch_shape, shape], axis=0))
            return tf.matmul(im_x, kernel[..., tf.newaxis, :, :])
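For the unbatched-kernel case this is the standard im2row-plus-matmul formulation of convolution. A small sanity sketch of that equivalence using `tf.image.extract_patches` (used here only for illustration; it is not what the code above calls):

import tensorflow as tf

x = tf.random.normal([1, 6, 6, 3])
kernel = tf.random.normal([2, 2, 3, 5])  # [fh, fw, c_in, c_out]

patches = tf.image.extract_patches(x, sizes=[1, 2, 2, 1], strides=[1, 1, 1, 1],
                                   rates=[1, 1, 1, 1], padding='VALID')
y_rows = tf.einsum('bijk,ko->bijo', patches, tf.reshape(kernel, [2 * 2 * 3, 5]))
y_conv = tf.nn.conv2d(x, kernel, strides=1, padding='VALID')
print(float(tf.reduce_max(tf.abs(y_rows - y_conv))))  # ~0, up to float round-off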
Example #14
def _resample_using_log_points(log_probs, sample_shape, log_points, name=None):
    """Resample from `log_probs` using supplied points in interval `[0, 1]`."""

    # We divide up the unit interval [0, 1] according to the provided
    # probability distributions using `cumulative_logsumexp`.
    # At the end of each division we place a 'marker'.
    # We use points on the unit interval supplied by caller.
    # We sort the combination of points and markers. The number
    # of points between the markers defining a division gives the number
    # of samples we require in that division.
    # For example, suppose `probs` is `[0.2, 0.3, 0.5]`.
    # We divide up `[0, 1]` using 3 markers:
    #
    #     |     |          |
    # 0.  0.2   0.5        1.0  <- markers
    #
    # Suppose we are given four points: [0.1, 0.25, 0.9, 0.75]
    # After sorting the combination we get:
    #
    # 0.1  0.25     0.75 0.9    <- points
    #  *  | *   |    *    *|
    # 0.   0.2 0.5         1.0  <- markers
    #
    # We have one sample in the first category, one in the second and
    # two in the last.
    #
    # All of these computations are carried out in batched form.

    with tf.name_scope(name or 'resample_using_log_points') as name:
        points_shape = ps.shape(log_points)
        batch_shape, [num_markers] = ps.split(ps.shape(log_probs),
                                              num_or_size_splits=[-1, 1])

        # `working_shape` specifies the total number of events
        # we will be generating.
        working_shape = ps.concat([sample_shape, batch_shape], axis=0)
        # `markers_shape` is the shape of the markers we temporarily insert.
        markers_shape = ps.concat([working_shape, [num_markers]], axis=0)

        markers = ps.concat([
            tf.ones(markers_shape, dtype=tf.int32),
            tf.zeros(points_shape, dtype=tf.int32)
        ],
                            axis=-1)
        log_marker_positions = tf.broadcast_to(
            log_cumsum_exp(log_probs, axis=-1), markers_shape)
        log_markers_and_points = ps.concat([log_marker_positions, log_points],
                                           axis=-1)
        # Stable sort is used to ensure that no points get sorted between
        # markers that have zero distance between them. This ensures that
        # there will never be a sample drawn whose probability is intended
        # to be zero even when a point falls on the edge of the
        # corresponding zero-width bucket.
        indices = tf.argsort(log_markers_and_points, axis=-1, stable=True)
        sorted_markers = tf.gather_nd(
            markers,
            indices[..., tf.newaxis],
            batch_dims=(ps.rank_from_shape(sample_shape) +
                        ps.rank_from_shape(batch_shape)))
        markers_and_samples = ps.cast(tf.cumsum(sorted_markers, axis=-1),
                                      dtype=tf.int32)
        markers_and_samples = tf.math.minimum(markers_and_samples,
                                              num_markers - np.int32(1))

        # Collect up samples, omitting markers.
        samples_mask = tf.equal(sorted_markers, 0)

        # The following block of code is equivalent to
        # `samples = markers_and_samples[samples_mask]` however boolean mask
        # indices are not supported by XLA.
        # Instead we use `argsort` to pick out the top `num_samples`
        # elements of `markers_and_samples` when sorted using `samples_mask`
        # as key.
        num_samples = points_shape[-1]
        sample_locations = tf.argsort(ps.cast(samples_mask, dtype=tf.int32),
                                      direction='DESCENDING',
                                      stable=True)
        samples = tf.gather_nd(markers_and_samples,
                               sample_locations[..., :num_samples, tf.newaxis],
                               batch_dims=(ps.rank_from_shape(sample_shape) +
                                           ps.rank_from_shape(batch_shape)))

        return tf.reshape(samples, points_shape)
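The worked example in the comment can be reproduced with `np.searchsorted`, which is the unbatched, non-XLA version of the same idea (a conceptual check only):

import numpy as np

probs = np.array([0.2, 0.3, 0.5])
points = np.array([0.1, 0.25, 0.9, 0.75])
print(np.searchsorted(np.cumsum(probs), points))  # [0 1 2 2]: one sample in bucket 0, one in 1, two in 2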
Example #15
  def __init__(self,
               bijectors,
               block_sizes=None,
               validate_args=False,
               maybe_changes_size=True,
               name=None):
    """Creates the bijector.

    Args:
      bijectors: A non-empty list of bijectors.
      block_sizes: A 1-D integer `Tensor` with each element signifying the
        length of the block of the input vector to pass to the corresponding
        bijector. The length of `block_sizes` must be equal to the length of
        `bijectors`. If left as None, a vector of 1's is used.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
      maybe_changes_size: Python `bool` indicating that this bijector might
        change the event size. If this is known to be false and set
        appropriately, then this will lead to improved static shape inference
        when the block sizes are not statically known.
      name: Python `str`, name given to ops managed by this object. Default:
        E.g., `Blockwise([Exp(), Softplus()]).name ==
        'blockwise_of_exp_and_softplus'`.

    Raises:
      NotImplementedError: If there is a bijector with `event_ndims` > 1.
      ValueError: If `bijectors` list is empty.
      ValueError: If the size of `block_sizes` does not equal the length of
        `bijectors`, or if `block_sizes` is not a vector.
    """
    parameters = dict(locals())
    if not name:
      name = 'blockwise_of_' + '_and_'.join([b.name for b in bijectors])
      name = name.replace('/', '')

    with tf.name_scope(name) as name:
      for b in bijectors:
        if (nest.is_nested(b.forward_min_event_ndims)
            or nest.is_nested(b.inverse_min_event_ndims)):
          raise ValueError('Bijectors must all be single-part.')
        elif isinstance(b.forward_min_event_ndims, int):
          if b.forward_min_event_ndims != b.inverse_min_event_ndims:
            raise ValueError('Rank-changing bijectors are not supported.')
          elif b.forward_min_event_ndims > 1:
            raise ValueError('Only scalar and vector event-shape '
                             'bijectors are supported at this time.')

      b_joint = joint_map.JointMap(list(bijectors), name='jointmap')

      block_sizes = (
          np.ones(len(bijectors), dtype=np.int32)
          if block_sizes is None else
          _validate_block_sizes(block_sizes, bijectors, validate_args))
      b_split = split.Split(
          block_sizes, name='split', validate_args=validate_args)

      if maybe_changes_size:
        i_block_sizes = _validate_block_sizes(
            ps.concat(b_joint.forward_event_shape_tensor(
                ps.split(block_sizes, len(bijectors))), axis=0),
            bijectors, validate_args)
        maybe_changes_size = not tf.get_static_value(
            ps.reduce_all(block_sizes == i_block_sizes))
      b_concat = invert.Invert(
          (split.Split(i_block_sizes, name='isplit')
           if maybe_changes_size else b_split),
          name='concat')

      self._maybe_changes_size = maybe_changes_size
      super(Blockwise, self).__init__(
          bijectors=[b_concat, b_joint, b_split],
          validate_args=validate_args,
          parameters=parameters,
          name=name)
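This matches the public `tfp.bijectors.Blockwise` constructor; assuming that correspondence, a typical call looks like the following (printed values are approximate):

import tensorflow as tf
import tensorflow_probability as tfp
tfb = tfp.bijectors

b = tfb.Blockwise([tfb.Exp(), tfb.Softplus()], block_sizes=[2, 3])
y = b.forward(tf.zeros([5]))     # Exp on the first 2 entries, Softplus on the last 3
print(y.numpy().round(4))        # [1. 1. 0.6931 0.6931 0.6931]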
Example #16
def _replace_event_shape_in_shape_tensor(input_shape, event_shape_in,
                                         event_shape_out, validate_args):
    """Replaces the rightmost dims in a `Tensor` representing a shape.

  Args:
    input_shape: a rank-1 `Tensor` of integers
    event_shape_in: the event shape expected to be present in rightmost dims
      of `input_shape`.
    event_shape_out: the event shape with which to replace `event_shape_in` in
      the rightmost dims of `input_shape`.
    validate_args: Python `bool` indicating whether arguments should
      be checked for correctness.

  Returns:
    output_shape: A rank-1 integer `Tensor` with the same contents as
      `input_shape` except for the event dims, which are replaced with
      `event_shape_out`.
  """
    output_tensorshape, is_validated = _replace_event_shape_in_tensorshape(
        tensorshape_util.constant_value_as_shape(input_shape), event_shape_in,
        event_shape_out)

    if (tensorshape_util.is_fully_defined(output_tensorshape)
            and (is_validated or not validate_args)):
        output_shape = ps.convert_to_shape_tensor(
            tensorshape_util.as_list(output_tensorshape),
            name='output_shape',
            dtype_hint=tf.int32)
        return output_shape, output_tensorshape

    event_shape_in_ndims = (
        ps.size(event_shape_in)
        if tensorshape_util.num_elements(event_shape_in.shape) is None else
        tensorshape_util.num_elements(event_shape_in.shape))
    input_non_event_shape, input_event_shape = ps.split(
        input_shape, num_or_size_splits=[-1, event_shape_in_ndims])

    additional_assertions = []
    if is_validated:
        pass
    elif validate_args:
        # Check that `input_event_shape` and `event_shape_in` are compatible in the
        # sense that they have equal entries in any position that isn't a `-1` in
        # `event_shape_in`. Note that our validations at construction time ensure
        # there is at most one such entry in `event_shape_in`.
        mask = event_shape_in >= 0
        explicit_input_event_shape = tf.boolean_mask(input_event_shape,
                                                     mask=mask)
        explicit_event_shape_in = tf.boolean_mask(event_shape_in, mask=mask)
        additional_assertions.append(
            assert_util.assert_equal(
                explicit_input_event_shape,
                explicit_event_shape_in,
                message='Input `event_shape` does not match `event_shape_in`.')
        )
        # We don't explicitly additionally verify
        # `tf.size(input_shape) > tf.size(event_shape_in)` since `tf.split`
        # already makes this assertion.

    with tf.control_dependencies(additional_assertions):
        output_shape = ps.concat([input_non_event_shape, event_shape_out],
                                 axis=0,
                                 name='output_shape')

    return output_shape, output_tensorshape
Example #17
def _split_pair(x):
    """Splits a length two vector into two scalars."""
    x = prefer_static.cast(x, dtype=tf.int32)
    a, b = prefer_static.split(x, num_or_size_splits=[1, 1])
    return a[0], b[0]
Example #18
def im2row_index(input_shape,
                 block_shape,
                 rank=2,
                 slice_step=(1, 1),
                 dilations=(1, 1),
                 dtype=tf.int32,
                 transpose=False,
                 validate_args=False,
                 name=None):
    """Computes indexes into a flattened image for building `im2row`."""
    with tf.name_scope(name or 'im2row_index'):
        if tf.get_static_value(rank) != 2:
            raise NotImplementedError(
                'Argument `rank` currently only supports `2`; '
                'saw "{}".'.format(rank))
        fh, fw = prepare_tuple_argument(block_shape,
                                        n=rank,
                                        arg_name='block_shape',
                                        validate_args=validate_args)
        sh, sw = prepare_tuple_argument(slice_step,
                                        n=rank,
                                        arg_name='slice_step',
                                        validate_args=validate_args)
        dh, dw = prepare_tuple_argument(dilations,
                                        n=rank,
                                        arg_name='dilations',
                                        validate_args=validate_args)

        # 1) Process input arguments.
        batch_shape, h, w, c = ps.split(ps.reshape(ps.cast(input_shape,
                                                           dtype=dtype),
                                                   shape=[-1]),
                                        num_or_size_splits=[-1, 1, 1, 1])
        h, w, c = h[0], w[0], c[0]

        tot_fh = dh * (fh - 1) + 1
        tot_fw = dw * (fw - 1) + 1

        # 2) Assemble all block start positions as indexes into the flattened image.
        # start_idx.shape = [fh, fw, c]
        if transpose:
            last_element = lambda size, step: size - (size - 1) % step - 1
            w_step = c * dw
            h_step = c * w * dh
            last_w = last_element(c * tot_fw, w_step)
            last_h = last_element(c * w * tot_fh, h_step)
            start_idx = cartesian_add([
                ps.range(last_h, -1, delta=-h_step, dtype=dtype),
                ps.range(last_w, -1, delta=-w_step, dtype=dtype),
                ps.range(c, delta=1, dtype=dtype),
            ])
        else:
            start_idx = cartesian_add([
                ps.range(c * w * tot_fh, delta=c * w * dh, dtype=dtype),
                ps.range(c * tot_fw, delta=c * dw, dtype=dtype),
                ps.range(c, delta=1, dtype=dtype),
            ])

        # 3) Assemble all block offsets (into flattened image).
        eh = h - tot_fh + 1
        ew = w - tot_fw + 1

        offset_idx = cartesian_add([
            ps.range(w * eh, delta=w * sh, dtype=dtype),
            ps.range(ew, delta=sw, dtype=dtype),
        ])

        offset_idx = offset_idx * c
        oh = (eh - 1) // sh + 1  # out height
        ow = (ew - 1) // sw + 1  # out width

        # 4) Combine block start/offset pairs.
        # shape = [oh * ow, fh * fw * c]
        idx = cartesian_add([offset_idx, start_idx])
        new_shape = ps.concat(
            [batch_shape,
             ps.convert_to_shape_tensor([oh, ow, fh * fw * c])],
            axis=0)
        return idx, new_shape
Example #19
        def op(x, kernel):
            input_dtype = dtype_util.common_dtype([x, kernel],
                                                  dtype_hint=tf.float32)
            x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
            kernel = tf.convert_to_tensor(kernel,
                                          dtype=input_dtype,
                                          name='kernel')

            batch_shape, event_shape = ps.split(ps.shape(x),
                                                num_or_size_splits=[-1, 3])
            xh, xw, c_in = ps.unstack(event_shape, num=3)

            kernel_shape = ps.shape(kernel)
            c_out = kernel_shape[-1]
            kernel_batch = kernel_shape[:-2]
            assertions = _maybe_validate_input_shapes(
                kernel_shape,
                channels_in=c_in,
                filter_height=fh,
                filter_width=fw,
                validate_args=validate_args)

            with tf.control_dependencies(assertions):
                # If the kernel does not have batch shape, fall back to
                # `conv2d_transpose` (unless dilations > 1, which is not implemented in
                # `conv2d_transpose`).
                if (tf.get_static_value(ps.rank(kernel)) == 2
                        and all(d == 1 for d in dilations)):
                    return _call_conv2d_transpose(x, kernel, filter_shape,
                                                  strides, padding, dilations,
                                                  c_out, batch_shape,
                                                  event_shape)

                n = ps.maximum(0, ps.rank(x) - 3)
                paddings = ps.pad(padding_vals,
                                  paddings=[[n, 1], [0, 0]],
                                  constant_values=0)
                x_pad = tf.pad(x, paddings=paddings, constant_values=0)

                ex_h = xh + tf.reduce_sum(padding_vals[0]) - sub_fh + 1
                ex_w = xw + tf.reduce_sum(padding_vals[1]) - sub_fw + 1

                def loop_body(i, outputs):
                    subkernel_ind = kernels_ind.read(i)
                    fh_, fw_ = ps.unstack(ps.shape(subkernel_ind), num=2)
                    eh = ex_h + fh_ - 1
                    ew = ex_w + fw_ - 1

                    subkernel_ind = ps.reshape(ps.reshape(
                        subkernel_ind * c_in, shape=[-1])[:, tf.newaxis] +
                                               ps.range(c_in),
                                               shape=[-1])

                    k = tf.gather(kernel, subkernel_ind, axis=-2)
                    ind, shape = im2row_index([eh, ew, c_in],
                                              block_shape=(fh_, fw_),
                                              slice_step=(1, 1),
                                              dilations=dilations)
                    x_i = x_pad[..., :eh, :ew, :]
                    x_i_shape = ps.shape(x_i)
                    flat_shape = ps.pad(x_i_shape[:-3],
                                        paddings=[[0, 1]],
                                        constant_values=-1)
                    flat_x = tf.reshape(x_i, flat_shape)
                    x_ = tf.gather(flat_x, ind, axis=-1)
                    im_x = tf.reshape(
                        x_, ps.concat([x_i_shape[:-3], shape], axis=0))
                    outputs = outputs.write(
                        i,
                        tf.matmul(
                            im_x,
                            tf.reshape(
                                k,
                                ps.concat([
                                    kernel_batch, [1, fh_ * fw_ * c_in, c_out]
                                ],
                                          axis=0))))
                    return i + 1, outputs

                outputs = tf.TensorArray(dtype=input_dtype,
                                         infer_shape=False,
                                         size=1,
                                         dynamic_size=True)

                _, outputs = tf.while_loop(lambda i, _: i < sh * sw, loop_body,
                                           [0, outputs])

                y = outputs.concat()

                m = tf.reduce_prod(ps.shape(y)[:-3])
                y_ = tf.reshape(y,
                                shape=ps.concat([[m], ps.shape(y)[-3:]],
                                                axis=0))
                y2 = tf.batch_to_space(y_,
                                       strides,
                                       crops=tf.zeros([2, 2], dtype=tf.int64))
                broadcast_batch_shape = ps.broadcast_shape(
                    batch_shape, kernel_batch)
                y2 = tf.reshape(
                    y2,
                    ps.concat([broadcast_batch_shape,
                               ps.shape(y2)[-3:]],
                              axis=0))

                if padding == 'VALID':
                    out_height = fh + sh * (xh - 1)
                    out_width = fw + sw * (xw - 1)
                elif padding == 'SAME':
                    out_height = xh * sh
                    out_width = xw * sw

                return y2[..., truncate_top:truncate_top + out_height,
                          truncate_left:truncate_left + out_width, :]
Example #20
        def op(x, kernel):
            input_dtype = dtype_util.common_dtype([x, kernel],
                                                  dtype_hint=tf.float32)
            x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
            kernel = tf.convert_to_tensor(kernel,
                                          dtype=input_dtype,
                                          name='kernel')

            batch_shape, event_shape = ps.split(ps.shape(x),
                                                num_or_size_splits=[-1, 3])
            xh, xw, c_in = ps.unstack(event_shape, num=3)

            kernel_shape = ps.shape(kernel)
            c_out = kernel_shape[-1]
            kernel_batch = kernel_shape[:-2]
            assertions = _maybe_validate_input_shapes(
                kernel_shape,
                channels_in=c_in,
                filter_height=fh,
                filter_width=fw,
                validate_args=validate_args)

            with tf.control_dependencies(assertions):

                # If the kernel does not have batch shape, fall back to
                # `conv2d_transpose` (unless dilations > 1, which is not implemented in
                # `conv2d_transpose`).
                if (tf.get_static_value(ps.rank(kernel)) == 2
                        and all(d == 1 for d in dilations)):
                    return _call_conv2d_transpose(x,
                                                  kernel=kernel,
                                                  filter_shape=filter_shape,
                                                  strides=(strides, ) * rank,
                                                  padding=padding,
                                                  dilations=dilations,
                                                  c_out=c_out,
                                                  batch_shape=batch_shape,
                                                  event_shape=event_shape)

                n = ps.maximum(0, ps.rank(x) - 3)
                paddings = ps.pad(padding_vals,
                                  paddings=[[n, 1], [0, 0]],
                                  constant_values=0)

                x_pad = tf.pad(x, paddings=paddings, constant_values=0)
                x_pad_shape = ps.shape(x_pad)[:-3]
                flat_shape = ps.pad(x_pad_shape,
                                    paddings=[[0, 1]],
                                    constant_values=-1)
                flat_x = tf.reshape(x_pad, shape=flat_shape)

                idx, s = im2row_index(
                    (xh + tf.reduce_sum(padding_vals[0]),
                     xw + tf.reduce_sum(padding_vals[1]), c_in),
                    block_shape=(sub_fh, sub_fw),
                    slice_step=(1, 1),
                    dilations=dilations)

                x_ = tf.gather(flat_x, indices=idx, axis=-1)
                im_x = tf.reshape(x_,
                                  shape=ps.concat([x_pad_shape, s], axis=0))

                # Add channels to subkernel indices
                idx_event = event_ind * [[c_in, 1]]
                idx_event_channels = (idx_event[tf.newaxis] + tf.stack(
                    [ps.range(c_in),
                     tf.zeros(
                         (c_in, ), dtype=dtype)], axis=-1)[:, tf.newaxis, :])
                idx_event = tf.squeeze(tf.batch_to_space(idx_event_channels,
                                                         block_shape=[c_in],
                                                         crops=[[0, 0]]),
                                       axis=0)
                idx_event_broadcast = tf.broadcast_to(
                    idx_event,
                    shape=ps.concat(
                        [kernel_batch, ps.shape(idx_event)], axis=0))

                # Add cartesian product of batch indices, since scatter_nd can only be
                # applied to leading dimensions.
                idx_batch = tf.stack(tf.meshgrid(*[
                    ps.range(b_, delta=1, dtype=dtype)
                    for b_ in tf.unstack(kernel_batch)
                ],
                                                 indexing='ij'),
                                     axis=ps.size(kernel_batch))

                idx_batch = tf.cast(idx_batch,
                                    dtype=dtype)  # empty tensor is float

                idx_batch_broadcast = idx_batch[..., tf.newaxis, :] + tf.zeros(
                    (ps.shape(idx_event)[0], 1), dtype=dtype)
                idx_kernel = tf.concat(
                    [idx_batch_broadcast, idx_event_broadcast], axis=-1)

                kernel_mat = tf.scatter_nd(
                    idx_kernel,
                    updates=kernel,
                    shape=ps.cast(ps.concat([
                        kernel_batch,
                        [sub_fh * sub_fw * c_in, strides**2, c_out]
                    ],
                                            axis=0),
                                  dtype=dtype))

                kernel_mat = tf.reshape(
                    kernel_mat,
                    shape=ps.concat(
                        [ps.shape(kernel_mat)[:-2], [strides**2 * c_out]],
                        axis=0))

                kernel_mat = kernel_mat[..., tf.newaxis, :, :]
                out = tf.matmul(im_x, kernel_mat)
                broadcast_batch_shape = ps.broadcast_shape(
                    batch_shape, kernel_batch)

                if strides > 1:
                    tot_size = tf.reduce_prod(broadcast_batch_shape)
                    flat_out = tf.reshape(out,
                                          shape=ps.concat([[tot_size],
                                                           ps.shape(out)[-3:]],
                                                          axis=0))
                    out = tf.nn.depth_to_space(flat_out, block_size=strides)

                if padding == 'VALID':
                    out_height = fh + strides * (xh - 1)
                    out_width = fw + strides * (xw - 1)
                elif padding == 'SAME':
                    out_height = xh * strides
                    out_width = xw * strides

                out = out[..., truncate_top:truncate_top + out_height,
                          truncate_left:truncate_left + out_width, :]
                out = tf.reshape(
                    out,
                    shape=ps.concat([
                        broadcast_batch_shape, [out_height, out_width, c_out]
                    ],
                                    axis=0))
                return out
Example #21
        def op(x, kernel):
            input_dtype = dtype_util.common_dtype([x, kernel],
                                                  dtype_hint=tf.float32)
            x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
            kernel = tf.convert_to_tensor(kernel,
                                          dtype=input_dtype,
                                          name='kernel')

            batch_shape, event_shape = ps.split(ps.shape(x),
                                                num_or_size_splits=[-1, 3])
            xh, xw, c_in = ps.unstack(event_shape, num=3)
            kernel_shape = ps.shape(kernel)
            assertions = _maybe_validate_input_shapes(
                kernel_shape,
                channels_in=c_in,
                filter_height=fh,
                filter_width=fw,
                validate_args=validate_args)

            with tf.control_dependencies(assertions):
                # If the kernel does not have batch shape, fall back to
                # `conv2d_transpose` (unless dilations > 1, which is not implemented in
                # `conv2d_transpose`).
                if (tf.get_static_value(ps.rank(kernel)) == 2
                        and all(d == 1 for d in dilations)):
                    return _call_conv2d_transpose(x, kernel, filter_shape,
                                                  strides, padding, dilations,
                                                  kernel_shape[-1],
                                                  batch_shape, event_shape)

                idx, shape = im2row_index((xh * sh + sum(pad_values[0]),
                                           xw * sw + sum(pad_values[1]), c_in),
                                          block_shape=filter_shape,
                                          slice_step=(1, 1),
                                          dilations=dilations,
                                          dtype=dtype,
                                          transpose=True)

                n = ps.maximum(0, ps.rank(x) - 3)
                paddings = ps.pad(pad_values,
                                  paddings=[[n, 1], [0, 0]],
                                  constant_values=0)

                # Interleave the rows and columns of the input with `strides - 1`
                # rows and columns of zeros.
                x_half_dilated = tf.concat([
                    tf.zeros(ps.concat([batch_shape, (xh * xw, sw - 1, c_in)],
                                       axis=0),
                             dtype=input_dtype),
                    tf.reshape(x,
                               shape=ps.concat(
                                   [batch_shape, (xh * xw, 1, c_in)], axis=0))
                ],
                                           axis=-2)
                y = tf.reshape(x_half_dilated,
                               shape=ps.concat(
                                   [batch_shape, (xh, 1, xw * sw, c_in)],
                                   axis=0))

                x = tf.reshape(tf.concat([
                    tf.zeros(ps.concat(
                        [batch_shape, (xh, sh - 1, xw * sw, c_in)], axis=0),
                             dtype=input_dtype), y
                ],
                                         axis=-3),
                               shape=ps.concat(
                                   [batch_shape, (xh * sh, xw * sw, c_in)],
                                   axis=0))
                x_pad = tf.pad(x, paddings=paddings, constant_values=0)
                flat_shape = ps.pad(batch_shape,
                                    paddings=[[0, 1]],
                                    constant_values=-1)
                flat_x = tf.gather(tf.reshape(x_pad, shape=flat_shape),
                                   indices=idx,
                                   axis=-1)
                im_x = tf.reshape(flat_x,
                                  shape=ps.concat([batch_shape, shape],
                                                  axis=0))
                return tf.matmul(im_x, kernel[..., tf.newaxis, :, :])
def _convolution_batch_nhwbc(
    x, kernel, rank, strides, padding, dilations, name):
  """Specialization of batch conv to NHWBC data format."""
  with tf.name_scope(name or 'conv2d_nhwbc'):
    # Prepare arguments.
    [
        _,  # filter shape
        rank,
        _,  # strides
        padding,
        dilations,
    ] = convolution_util.prepare_conv_args(1, rank, strides, padding, dilations)
    strides = prepare_strides(strides, rank + 2, arg_name='strides')

    dtype = dtype_util.common_dtype([x, kernel], dtype_hint=tf.float32)
    x = tf.convert_to_tensor(x, dtype=dtype, name='x')
    kernel = tf.convert_to_tensor(kernel, dtype=dtype, name='kernel')

    # Step 1: Transpose and double flatten kernel.
    # kernel.shape = B + F + [c, c']. Eg: [b, fh, fw, c, c']
    kernel_shape = ps.shape(kernel)
    kernel_batch_shape, kernel_event_shape = ps.split(
        kernel_shape,
        num_or_size_splits=[-1, rank + 2])
    kernel_batch_size = ps.reduce_prod(kernel_batch_shape)
    kernel_ndims = ps.rank(kernel)
    kernel_batch_ndims = kernel_ndims - rank - 2
    perm = ps.concat([
        ps.range(kernel_batch_ndims, kernel_batch_ndims + rank),
        ps.range(0, kernel_batch_ndims),
        ps.range(kernel_batch_ndims + rank, kernel_ndims),
    ], axis=0)  # Eg, [1, 2, 0, 3, 4]
    kernel = tf.transpose(kernel, perm=perm)  # F + B + [c, c']
    kernel = tf.reshape(
        kernel,
        shape=ps.concat([
            kernel_event_shape[:rank],
            [kernel_batch_size * kernel_event_shape[-2],
             kernel_event_shape[-1]],
        ], axis=0))  # F + [bc, c']

    # Step 2: Double flatten x.
    # x.shape = N + D + B + [c]
    x_shape = ps.shape(x)
    [
        x_sample_shape,
        x_rank_shape,
        x_batch_shape,
        x_channel_shape,
    ] = ps.split(
        x_shape,
        num_or_size_splits=[-1, rank, kernel_batch_ndims, 1])
    x = tf.reshape(
        x,  # N + D + B + [c]
        shape=ps.concat([
            [ps.reduce_prod(x_sample_shape)],
            x_rank_shape,
            [ps.reduce_prod(x_batch_shape) *
             ps.reduce_prod(x_channel_shape)],
        ], axis=0))  # [n] + D + [bc]

    # Step 3: Apply convolution.
    y = tf.nn.depthwise_conv2d(
        x, kernel,
        strides=strides,
        padding=padding,
        data_format='NHWC',
        dilations=dilations)
    #  SAME: y.shape = [n, h,      w,      bcc']
    # VALID: y.shape = [n, h-fh+1, w-fw+1, bcc']

    # Step 4: Reshape/reduce for output.
    y_shape = ps.shape(y)
    y = tf.reshape(
        y,
        shape=ps.concat([
            x_sample_shape,
            y_shape[1:-1],
            kernel_batch_shape,
            kernel_event_shape[-2:],
        ], axis=0))  # N + D' + B + [c, c']
    y = tf.reduce_sum(y, axis=-2)  # N + D' + B + [c']

    return y