Example No. 1
    def _sample_n(self, n, seed=None):
        # Generate samples using:
        #   mu + sigma * sgn(U - 0.5) * sqrt(X^2 + Y^2 + Z^2),
        # where U ~ Uniform(0, 1) and X, Y, Z ~ N(0, 1).
        normal_seed, rademacher_seed = samplers.split_seed(
            seed, salt='DoublesidedMaxwell')

        loc = tf.convert_to_tensor(self.loc)
        scale = tf.convert_to_tensor(self.scale)
        shape = prefer_static.pad(self._batch_shape_tensor(loc=loc,
                                                           scale=scale),
                                  paddings=[[1, 0]],
                                  constant_values=n)

        # Generate one-sided Maxwell variables by using 3 Gaussian variates
        norm_rvs = samplers.normal(shape=prefer_static.pad(shape,
                                                           paddings=[[0, 1]],
                                                           constant_values=3),
                                   dtype=self.dtype,
                                   seed=normal_seed)
        maxwell_rvs = tf.norm(norm_rvs, axis=-1)

        # Generate random signs for the symmetric variates.
        random_sign = tfp_math.random_rademacher(shape, seed=rademacher_seed)
        sampled = random_sign * maxwell_rvs * scale + loc
        return sampled
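
For intuition, the sampling identity in the comment can be checked with plain NumPy. This is a standalone sketch assuming mu = 0 and sigma = 1; it does not use the TFP seed-splitting machinery above.

import numpy as np

# Standalone sketch of: mu + sigma * sgn(U - 0.5) * sqrt(X^2 + Y^2 + Z^2).
rng = np.random.default_rng(0)
u = rng.uniform(size=100_000)
xyz = rng.normal(size=(100_000, 3))
samples = np.sign(u - 0.5) * np.linalg.norm(xyz, axis=-1)
# For mu=0, sigma=1 the double-sided Maxwell has mean 0 and variance 3.
print(samples.mean(), samples.var())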
Example No. 2
  def test_batching(self, input_batch_shape, kernel_batch_shape):
    input_shape = (12, 12, 2)
    filter_shape = (3, 3)
    channels_out = 4
    strides = 2
    dilations = (1, 1)
    padding = 'SAME'

    x, k = _make_input_and_kernel(
        self.make_input,
        input_batch_shape=input_batch_shape,
        input_shape=input_shape,
        kernel_batch_shape=kernel_batch_shape,
        filter_shape=filter_shape,
        channels_out=channels_out,
        dtype=self.dtype)

    conv_fn = self.make_conv_fn(filter_shape, strides, padding, dilations)
    y_batched = conv_fn(x, k)

    broadcast_batch_shape = ps.broadcast_shape(
        input_batch_shape, kernel_batch_shape)
    broadcasted_input = tf.broadcast_to(
        x, shape=ps.concat([broadcast_batch_shape, input_shape], axis=0))
    broadcasted_kernel = tf.broadcast_to(
        k, shape=ps.concat([broadcast_batch_shape, ps.shape(k)[-2:]], axis=0))

    flat_y = tf.reshape(
        y_batched,
        shape=ps.pad(
            ps.shape(y_batched)[-3:], paddings=[[1, 0]], constant_values=-1))
    flat_x = tf.reshape(
        broadcasted_input,
        shape=ps.pad(input_shape, paddings=[[1, 0]], constant_values=-1))
    flat_tf_kernel = tf.einsum(
        '...ij->...ji',
        tf.reshape(
            broadcasted_kernel,
            shape=ps.concat(
                [(-1,), filter_shape, (input_shape[-1], channels_out)],
                axis=0)))

    rank = 2
    output_shape, strides_ = convolution_util._get_output_shape(
        rank=rank, strides=(strides,) * rank, padding=padding,
        dilations=dilations, input_shape=input_shape, output_size=channels_out,
        filter_shape=filter_shape)

    y_expected = tf.vectorized_map(
        lambda args: tf.nn.conv2d_transpose(  # pylint: disable=g-long-lambda
            args[0][tf.newaxis],
            args[1],
            output_shape=ps.concat([[1], output_shape], axis=0),
            strides=strides_,
            padding=padding),
        elems=(flat_x, flat_tf_kernel))

    [y_actual_, y_expected_] = self.evaluate(
        [flat_y, tf.squeeze(y_expected, axis=1)])
    self.assertAllClose(y_expected_, y_actual_, rtol=1e-5, atol=0)
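
For reference, the batch-shape broadcast the test relies on can be reproduced with NumPy alone; the shapes below are illustrative, not the ones parameterizing the test.

import numpy as np

# E.g. an input batch shape (2, 1) and a kernel batch shape (1, 3)
# broadcast to a common batch shape of (2, 3).
print(np.broadcast_shapes((2, 1), (1, 3)))   # (2, 3)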
Example No. 3
 def _prepare_for_underlying(self, x):
     batch_ndims = ps.rank_from_shape(self.distribution.batch_shape_tensor,
                                      self.distribution.batch_shape)
     extra_sample_ndims = ps.rank_from_shape(self.sample_shape)
     event_ndims = ps.rank_from_shape(self.distribution.event_shape_tensor,
                                      self.distribution.event_shape)
     ndims = ps.rank(x)
     # (1) Expand x's dims.
     d = ndims - batch_ndims - extra_sample_ndims - event_ndims
     x = tf.reshape(x,
                    shape=ps.pad(ps.shape(x),
                                 paddings=[[ps.maximum(0, -d), 0]],
                                 constant_values=1))
     ndims = ps.rank(x)
     sample_ndims = ps.maximum(0, d)
     # (2) Transpose x's dims.
     sample_dims = ps.range(0, sample_ndims)
     batch_dims = ps.range(sample_ndims, sample_ndims + batch_ndims)
     extra_sample_dims = ps.range(
         sample_ndims + batch_ndims,
         sample_ndims + batch_ndims + extra_sample_ndims)
     event_dims = ps.range(sample_ndims + batch_ndims + extra_sample_ndims,
                           ndims)
     perm = ps.concat(
         [sample_dims, extra_sample_dims, batch_dims, event_dims], axis=0)
     x = tf.transpose(x, perm=perm)
     return x, (sample_ndims, extra_sample_ndims, batch_ndims)
Example No. 4
def _find_bins(x, edges, axis, dtype=tf.int64, name=None):
    """Like `tfp.stats.find_bins` but correctly handles quantiles axis arg."""
    with tf.name_scope(name or 'find_bins'):
        # We can't do this:
        #   return tf.cast(quantiles_lib.find_bins(x, edges=edges), dtype=tf.int64)
        # because it doesn't seem to correctly handle axis!=-1. This is a bug in TFP
        # and should be fixed. Furthermore, the following is probably more efficient
        # than tfp.stats.find_bins anyway.
        num_buckets = ps.size0(edges) - 1
        # First, we need to have `keepdims=True` semantics for edges.
        axis = axis % ps.rank(x)
        edges = tf.expand_dims(edges, axis + 1)
        # We now find the bucket which is the "first larger", then subtract one
        # to get the bucket which is the "last smaller". Care must be taken for the
        # max element.
        pred = x < edges
        # The following is equivalent to:
        #    tf.argmin(tf.cast(~pred, dtype), axis=0)
        # yet behaves consistently across TPU/GPU/CPU.
        # As a bonus, we can also leverage the `sorted=True` behavior.
        _, bucket_larger = tf.math.top_k(
            tf.cast(
                tf.transpose(
                    pred,
                    ps.pad(ps.range(1, ps.rank(pred)), paddings=[[0, 1]])),
                dtype),
            k=1,
            sorted=True)
        bucket_larger = bucket_larger[..., 0]

        bucket_larger = tf.where(
            pred[-1],  # == ~tf.math.reduce_all(pred, axis=0)
            tf.cast(bucket_larger, dtype),
            tf.cast(num_buckets, dtype))
        return bucket_larger - 1
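
For the 1-D, axis=-1 case, the lookup above matches a searchsorted-and-clamp recipe. The NumPy sketch below is only an analogue for intuition, not the code path the helper uses.

import numpy as np

edges = np.array([0., 1., 2., 3.])          # 3 buckets
x = np.array([0.5, 1.0, 2.7, 3.0])
bins = np.searchsorted(edges, x, side='right') - 1
bins = np.minimum(bins, len(edges) - 2)     # clamp the max element into the last bucket
print(bins)                                 # [0 1 2 2]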
Example No. 5
def _get_output_shape(rank,
                      strides,
                      padding,
                      dilations,
                      input_shape,
                      output_size,
                      filter_shape,
                      output_padding=None):
    """Compute the `output_shape` and `strides` arg used by `conv_transpose`."""
    if output_padding is None:
        output_padding = (None, ) * rank
    else:
        output_padding = utils.prepare_tuple_argument(
            output_padding, n=rank, arg_name='output_padding')
        for stride, out_pad in zip(strides, output_padding):
            if out_pad >= stride:
                raise ValueError('Stride {} must be greater than output '
                                 'padding {}.'.format(stride, out_pad))
    event_shape = []
    for i in range(-rank, 0):
        event_shape.append(
            _deconv_output_length(input_shape[i - 1],
                                  filter_size=filter_shape[i],
                                  padding=padding,
                                  output_padding=output_padding[i],
                                  stride=strides[i],
                                  dilation=dilations[i]))
    event_shape.append(output_size)
    batch_shape = input_shape[:-rank - 1]
    output_shape = ps.concat([batch_shape, event_shape], axis=0)
    strides = ps.pad(strides, paddings=[[1, 1]], constant_values=1)
    return output_shape, strides
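
The final ps.pad call turns the per-spatial-dimension strides into the length-4, NHWC-style strides vector that conv_transpose expects; a plain NumPy illustration:

import numpy as np

# (s_h, s_w) = (2, 2) becomes (1, 2, 2, 1): batch and channel strides of 1.
print(np.pad([2, 2], pad_width=[[1, 1]], constant_values=1))   # [1 2 2 1]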
Example No. 6
def expand_dims(x, axis, name=None):
    """Like `tf.expand_dims` but accepts a vector of axes to expand."""
    with tf.name_scope(name or 'expand_dims'):
        x = tf.convert_to_tensor(x, name='x')
        axis = tf.convert_to_tensor(axis, dtype_hint=tf.int32, name='axis')
        nx = prefer_static.rank(x)
        na = prefer_static.size(axis)
        is_neg_axis = axis < 0
        k = prefer_static.reduce_sum(
            prefer_static.cast(is_neg_axis, axis.dtype))
        axis = prefer_static.where(is_neg_axis, axis + nx, axis)
        axis = prefer_static.sort(axis)
        axis_neg, axis_pos = prefer_static.split(axis, [k, -1])
        idx = prefer_static.argsort(
            prefer_static.concat([axis_pos, prefer_static.range(nx), axis_neg],
                                 axis=0),
            stable=True)
        shape = prefer_static.pad(prefer_static.shape(x),
                                  paddings=[[na - k, k]],
                                  constant_values=1)
        shape = prefer_static.gather(shape, idx)
        return tf.reshape(x, shape)
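
As a shape-level reference point (assuming the intended semantics mirror NumPy's), np.expand_dims already accepts a tuple of axes:

import numpy as np

x = np.zeros((3, 4))
print(np.expand_dims(x, axis=(0, -1)).shape)   # (1, 3, 4, 1)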
Example No. 7
def _rightmost_expand_to_rank(tensor, new_rank):
    """Expands `tensor`'s rank by `new_rank - tensor.rank` rightmost dims."""
    return tf.reshape(
        tensor,
        shape=prefer_static.pad(
            prefer_static.shape(tensor),
            paddings=[[0, max(0, new_rank - prefer_static.rank(tensor))]],
            constant_values=1))
Example No. 8
def left_justified_expand_dims_to(x, rank, name=None):
    """Right pads `x` with `rank - rank(x)` ones."""
    with tf.name_scope(name or 'left_justified_expand_dims_to'):
        rank = tf.convert_to_tensor(rank, dtype=tf.int32)
        expand_ndims = prefer_static.maximum(rank - prefer_static.rank(x), 0)
        expand_shape = prefer_static.pad(prefer_static.shape(x),
                                         paddings=[[0, expand_ndims]],
                                         constant_values=1)
        return prefer_static.reshape(x, expand_shape)
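
A shape-only illustration with no TFP dependency: right-padding the shape with ones up to the requested rank is the same as appending singleton dimensions.

import numpy as np

x = np.zeros((5, 3))
y = x.reshape(x.shape + (1,) * 2)   # expand to rank 4
print(y.shape)                      # (5, 3, 1, 1)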
Example No. 9
def im2row(x, block_shape, slice_step=(1, 1), padding='VALID', name=None):
    """Rearrange image blocks into rows.

  This function can be used to implement 2D convolution as a `matmul`, e.g.,

  `tf.nn.conv2d(x, k) = tf.matmul(
      tf.experimental.nn.util.im2row(x), tf.reshape(k, shape=[-1, out_size]))`.

  Args:
    x: Rank 3 (or more) Tensor representing 2D images.
    block_shape: Length-2 vector representing the block or "filter" shape.
    slice_step: Length-2 vector specifying the convolution stride length.
      Default value: `(1, 1)`.
    padding: One of `'VALID'` or `'SAME'` (case insensitive).
      Default value: `'VALID'`.
    name: Python `str` used to describe ops created by this function.
      Default value: `None` (i.e., `'im2row'`).

  Returns:
    im2row_x: batch of matrices representing subblock copies of `x`.
      Same batch shape as `x` but with rightmost shape:
      `batch_shape + [oh * ow, block_shape[0] * block_shape[1] * channels]`,
      where `oh = (h - block_shape[0] + 1) // slice_step[0]` and
      `ow = (w - block_shape[1] + 1) // slice_step[1]` when `padding = 'VALID'`
      and `oh = h` and `ow = w` when `padding = 'SAME'`.
    shape: shape `Tensor` equivalent to:
      `batch_shape + [oh, ow, block_shape[0] * block_shape[1] * channels]` where
      `oh, ow` are defined as above.
  """
    with tf.name_scope(name or 'im2row'):
        padding = _validate_padding(padding)
        if padding == 'VALID':
            pass  # Do nothing.
        elif padding == 'SAME':
            raise NotImplementedError(
                'Argument padding="SAME" not implemented.')
            # TODO(jvdillon): See if the following works:
            # fh, fw = block_shape
            # o = 1 if data_format == 'NHWC' else 0
            # n = ps.maximum(0, ps.rank(x) - 3)
            # paddings = ps.pad(
            #     [[0, fh - 1], [0, fw - 1]],
            #     paddings=[[n + 1 - o, o], [0, 0]],
            #     constant_values=0)
            # x = tf.pad(x, paddings=paddings, constant_values=0)
            # padding = 'VALID'
        else:
            assert False  # Can't be here.
        x_shape = ps.shape(x)
        idx, s = im2row_index(x_shape,
                              block_shape=block_shape,
                              slice_step=slice_step)
        flat_shape = ps.pad(x_shape[:-3],
                            paddings=[[0, 1]],
                            constant_values=-1)
        x = tf.gather(tf.reshape(x, flat_shape), idx, axis=-1)  # == np.take
        return tf.reshape(x, s)
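
The "convolution as matmul" identity from the docstring can be sanity-checked in NumPy. im2row_np below is a hypothetical stride-1, 'VALID'-padding stand-in for the helper above, not its actual implementation.

import numpy as np

def im2row_np(x, block_shape):
    """x: [h, w, c]; returns [oh * ow, fh * fw * c] (stride 1, 'VALID')."""
    fh, fw = block_shape
    win = np.lib.stride_tricks.sliding_window_view(x, (fh, fw), axis=(0, 1))
    win = np.transpose(win, (0, 1, 3, 4, 2))       # [oh, ow, fh, fw, c]
    oh, ow = win.shape[:2]
    return win.reshape(oh * ow, fh * fw * x.shape[-1])

x = np.random.randn(5, 5, 2)
k = np.random.randn(3, 3, 2, 4)                    # [fh, fw, c_in, c_out]
y = (im2row_np(x, (3, 3)) @ k.reshape(-1, 4)).reshape(3, 3, 4)

# Direct (naive) 'VALID' correlation for comparison.
y_direct = np.zeros((3, 3, 4))
for i in range(3):
    for j in range(3):
        y_direct[i, j] = np.tensordot(x[i:i + 3, j:j + 3], k, axes=3)
np.testing.assert_allclose(y, y_direct, atol=1e-10)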
Example No. 10
def pad_shape_with_ones(x, ndims, start=-1):
  """Maybe add `ndims` ones to `x.shape` starting at `start`.

  If `ndims` is zero, this is a no-op; otherwise, we will create and return a
  new `Tensor` whose shape is that of `x` with `ndims` ones concatenated on the
  right side. If the shape of `x` is known statically, the shape of the return
  value will be as well.

  Args:
    x: The `Tensor` we'll return a reshaping of.
    ndims: Python `integer` number of ones to pad onto `x.shape`.
    start: Python `integer` specifying where to start padding with ones. Must
      be a negative integer. For instance, a value of `-1` means to pad at the
      end of the shape. Default value: `-1`.
  Returns:
    If `ndims` is zero, `x`; otherwise, a `Tensor` whose shape is that of `x`
    with `ndims` ones concatenated on the right side. If possible, returns a
    `Tensor` whose shape is known statically.
  Raises:
    ValueError: if `ndims` is not a Python `integer` greater than or equal to
    zero.
  """
  if not (isinstance(ndims, int) and ndims >= 0):
    raise ValueError(
        '`ndims` must be a Python `integer` greater than or equal to zero. '
        'Got: {}'.format(ndims))
  if not (isinstance(start, int) and start <= -1):
    raise ValueError(
        '`start` must be a Python `integer` less than zero. Got: {}'
        .format(start))
  if ndims == 0:
    return x
  x = tf.convert_to_tensor(value=x)
  original_shape = x.shape
  rank = ps.rank(x)
  first_shape = ps.shape(x)[:rank + start + 1]
  second_shape = ps.shape(x)[rank + start + 1:]
  new_shape = ps.pad(first_shape, paddings=[[0, ndims]], constant_values=1)
  new_shape = ps.concat([new_shape, second_shape], axis=0)
  x = tf.reshape(x, new_shape)
  if start == -1:
    tensorshape_util.set_shape(
        x, tensorshape_util.concatenate(original_shape, [1] * ndims))
  elif tensorshape_util.rank(original_shape) is not None:
    original_ndims = tensorshape_util.rank(original_shape)
    new_shape = tensorshape_util.concatenate(
        original_shape[:original_ndims + start + 1],
        tensorshape_util.concatenate(
            [1] * ndims,
            original_shape[original_ndims + start + 1:]))
    tensorshape_util.set_shape(x, new_shape)
  return x
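
A brief usage sketch, assuming the helper above is importable as pad_shape_with_ones; the resulting shapes are shown in the comments.

import tensorflow as tf

x = tf.zeros([4, 3])
y = pad_shape_with_ones(x, ndims=2, start=-1)   # shape [4, 3, 1, 1]
z = pad_shape_with_ones(x, ndims=1, start=-2)   # shape [4, 1, 3]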
Example No. 11
def batchify_op(op, op_min_input_ndims, x, *other_op_args):
    """Reshape `op` input `x` to be a vec of `op_min_input_ndims`-rank tensors."""
    if x.shape.rank == op_min_input_ndims + 1:
        # Input is already a vector of `op_min_input_ndims`-rank tensors.
        return op(x, *other_op_args)
    batch_shape, op_shape = ps.split(
        ps.shape(x), num_or_size_splits=[-1, op_min_input_ndims])
    flat_shape = ps.pad(op_shape, paddings=[[1, 0]], constant_values=-1)
    y = tf.reshape(x, flat_shape)
    y = op(y, *other_op_args)
    unflat_shape = ps.concat([
        batch_shape,
        ps.shape(y)[1:],
    ], axis=0)
    y = tf.reshape(y, unflat_shape)
    return y
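
A hedged usage sketch, assuming batchify_op and ps from above are importable; the op only needs a single flat batch dimension in front of its rank-2 inputs.

import tensorflow as tf

x = tf.random.normal([2, 3, 4, 4])                   # two leading batch dims
y = batchify_op(tf.linalg.matrix_transpose, 2, x)    # flattened to [6, 4, 4] internally
print(y.shape)                                       # (2, 3, 4, 4)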
Example No. 12
    def _inverse(self, y):
        ndims = ps.rank(y)
        shifted_y = ps.pad(
            ps.slice(
                y, ps.zeros(ndims, dtype=tf.int32),
                ps.shape(y) -
                ps.one_hot(ndims + self.axis, ndims, dtype=tf.int32)
            ),  # Remove the last entry of y in the chosen dimension.
            paddings=ps.one_hot(
                ps.one_hot(ndims + self.axis, ndims, on_value=0, off_value=-1),
                2,
                dtype=tf.int32
            )  # Insert zeros at the beginning of the chosen dimension.
        )

        return y - shifted_y
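
In effect this inverts a cumulative sum along the chosen axis by subtracting a zero-padded, shifted copy of y; a minimal NumPy illustration for the 1-D case:

import numpy as np

y = np.cumsum([1., 2., 3., 4.])          # [ 1.  3.  6. 10.]
shifted_y = np.pad(y[:-1], (1, 0))       # [ 0.  1.  3.  6.]
print(y - shifted_y)                     # [1. 2. 3. 4.]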
Example No. 13
 def _prepare_for_underlying(self, x):
     batch_ndims = ps.rank_from_shape(self.distribution.batch_shape_tensor,
                                      self.distribution.batch_shape)
     extra_sample_ndims = ps.rank_from_shape(self.sample_shape)
     event_ndims = ps.rank_from_shape(self.distribution.event_shape_tensor,
                                      self.distribution.event_shape)
     ndims = ps.rank(x)
     # (1) Expand x's dims.
     d = ndims - batch_ndims - extra_sample_ndims - event_ndims
     x = tf.reshape(x,
                    shape=ps.pad(ps.shape(x),
                                 paddings=[[ps.maximum(0, -d), 0]],
                                 constant_values=1))
     sample_ndims = ps.maximum(0, d)
     x = tf.transpose(x,
                      perm=ps.invert_permutation(
                          self._sampling_permutation(sample_ndims)))
     return x, (sample_ndims, extra_sample_ndims, batch_ndims)
Example No. 14
def _compute_fans_from_shape(shape, batch_ndims=0):
  """Extracts `fan_in, fan_out` from specified shape `Tensor`."""
  # Ensure shape is a vector of length >=2.
  num_pad = prefer_static.maximum(0, 2 - prefer_static.size(shape))
  shape = prefer_static.pad(
      shape, paddings=[[0, num_pad]], constant_values=1)
  (
      batch_shape,  # pylint: disable=unused-variable
      extra_shape,
      fan_in,
      fan_out,
  ) = prefer_static.split(shape, [batch_ndims, -1, 1, 1])
  # The following logic is primarily intended for convolutional layers which
  # have spatial semantics in addition to input/output channels.
  receptive_field_size = prefer_static.reduce_prod(extra_shape)
  fan_in = fan_in[0] * receptive_field_size
  fan_out = fan_out[0] * receptive_field_size
  return fan_in, fan_out
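
Worked numbers for the fan computation above, assuming a [fh, fw, c_in, c_out] kernel shape and batch_ndims=0 (plain Python, no TFP dependency):

shape = [3, 3, 16, 32]                       # fh, fw, c_in, c_out
receptive_field_size = 3 * 3                 # product of the "extra" spatial dims
fan_in = shape[-2] * receptive_field_size    # 16 * 9 = 144
fan_out = shape[-1] * receptive_field_size   # 32 * 9 = 288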
Example No. 15
  def expand_ends(x, broadcast=False):
    """Expand x so it can bcast w/ tensors of output shape."""
    # Assume out_shape = A + x.shape + B, and rank(A) = axis.
    # Expand with singletons with same rank as A, B.
    expanded_shape = ps.pad(
        tensor=ps.shape(x),
        paddings=[[axis, ps.size(y_ref_shape_right)]],
        constant_values=1)
    x_expanded = tf.reshape(x, expanded_shape)

    if broadcast:
      out_shape = ps.concat(
          (y_ref_shape_left, ps.shape(x), y_ref_shape_right), axis=0)
      x_expanded = _broadcast_with(x_expanded, out_shape)
    return x_expanded
Example No. 16
 def _log_prob(self, x, **kwargs):
     batch_ndims = ps.rank_from_shape(self.distribution.batch_shape_tensor,
                                      self.distribution.batch_shape)
     extra_sample_ndims = ps.rank_from_shape(self.sample_shape)
     event_ndims = ps.rank_from_shape(self.distribution.event_shape_tensor,
                                      self.distribution.event_shape)
     ndims = ps.rank(x)
     # (1) Expand x's dims.
     d = ndims - batch_ndims - extra_sample_ndims - event_ndims
     x = tf.reshape(x,
                    shape=ps.pad(ps.shape(x),
                                 paddings=[[ps.maximum(0, -d), 0]],
                                 constant_values=1))
     ndims = ps.rank(x)
     sample_ndims = ps.maximum(0, d)
     # (2) Transpose x's dims.
     sample_dims = ps.range(0, sample_ndims)
     batch_dims = ps.range(sample_ndims, sample_ndims + batch_ndims)
     extra_sample_dims = ps.range(
         sample_ndims + batch_ndims,
         sample_ndims + batch_ndims + extra_sample_ndims)
     event_dims = ps.range(sample_ndims + batch_ndims + extra_sample_ndims,
                           ndims)
     perm = ps.concat(
         [sample_dims, extra_sample_dims, batch_dims, event_dims], axis=0)
     x = tf.transpose(a=x, perm=perm)
     # (3) Compute x's log_prob.
     lp = self.distribution.log_prob(x, **kwargs)
     # (4) Ensure lp is fully broadcast in the sample dims, i.e. ensure lp has
     #     full sample shape in the sample axes, before we reduce.
     bcast_lp_shape = ps.broadcast_shape(
         ps.shape(lp),
         ps.concat([
             ps.ones([sample_ndims], tf.int32),
             ps.reshape(self.sample_shape, shape=[-1]),
             ps.ones([batch_ndims], tf.int32)
         ],
                   axis=0))
     lp = tf.broadcast_to(lp, bcast_lp_shape)
     # (5) Make the final reduction in x.
     axis = ps.range(sample_ndims, sample_ndims + extra_sample_ndims)
     return tf.reduce_sum(lp, axis=axis)
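
Shape-wise this appears to be the log_prob path of a Sample-style meta-distribution; a small sketch through the public tfd.Sample API (the shapes in the comments are the point, not the values):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
d = tfd.Sample(tfd.Normal(loc=tf.zeros([3]), scale=1.), sample_shape=[2])
x = d.sample([5])      # shape [5, 3, 2]: sample, batch, extra-sample dims
lp = d.log_prob(x)     # shape [5, 3]: the extra-sample dims are reduced away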
Example No. 17
def _sub_diag(nonmatrix):
    """Get the first sub-diagonal of a shape [N, N, ...] 'non matrix'."""
    with tf.name_scope('sub_matrix'):
        # TODO(b/143702351) Once array_ops.matrix_diag_part_v3 is ready and exposed,
        # replace the call to matrix_diag_part_v2 below with tf.linalg.matrix_diag.
        # We can also stop special casing for matrix_dim < 2 at that point.
        # Until then, an OpError is raised for 1x1 matrices without static shape.
        # In fact, non-static shape breaks matrix_diag_part_v2, so we must raise
        # this message now.
        # See http://b/138403336 for the TF issue tracker.
        if not tensorshape_util.is_fully_defined(nonmatrix.shape[:2]):
            raise ValueError(
                '`inverse_temperatures` did not have statically defined shape, '
                'which breaks tracking of is_swap_{proposed,accepted}.  '
                'Please provide an `inverse_temperatures` with statically known '
                'shape.')

        # The sub-diagonal of a 1x1 matrix is not defined (it raises an
        # exception), so in this special case return an empty result.
        # TODO(b/143702351) Remove this special case handling once
        # matrix_diag_part_v3 is ready.
        matrix_dim = ps.size0(nonmatrix)
        if matrix_dim is not None and matrix_dim < 2:
            # Shape is [..., 0], so returned tensor is empty, thus contains no
            # values...and therefore the fact that we use 'ones' doesn't matter.
            shape = ps.pad(ps.shape(nonmatrix)[2:],
                           paddings=[[0, 1]],
                           constant_values=0)
            matrix_sub_diag = tf.cast(tf.ones(shape), nonmatrix.dtype)

        else:
            # Get first sub-diagonal.  `padding_value` is not used (since matrix is
            # square), but is required for the API since this is raw gen_array_ops.
            matrix_sub_diag = tf.raw_ops.MatrixDiagPartV2(
                input=distribution_util.rotate_transpose(nonmatrix, shift=-2),
                k=ps.convert_to_shape_tensor(-1, dtype=tf.int32),
                padding_value=tf.cast(0.0, dtype=nonmatrix.dtype))

        return distribution_util.rotate_transpose(matrix_sub_diag, shift=1)
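
For a plain [N, N] input the helper returns the first sub-diagonal; a NumPy one-liner shows the intended entries:

import numpy as np

m = np.arange(9).reshape(3, 3)
print(np.diagonal(m, offset=-1))   # [3 7]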
Example No. 18
def _uniform_correlation_like_matrix(num_rows, batch_shape, dtype, seed):
    """Returns a uniformly random `Tensor` of "correlation-like" matrices.

  A "correlation-like" matrix is a symmetric square matrix with all entries
  between -1 and 1 (inclusive) and 1s on the main diagonal.  Of these,
  the ones that are positive semi-definite are exactly the correlation
  matrices.

  Args:
    num_rows: Python `int` dimension of the correlation-like matrices.
    batch_shape: `Tensor` or Python `tuple` of `int` shape of the
      batch to return.
    dtype: `dtype` of the `Tensor` to return.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.

  Returns:
    matrices: A `Tensor` of shape `batch_shape + [num_rows, num_rows]`
      and dtype `dtype`.  Each entry is in [-1, 1], and each matrix
      along the bottom two dimensions is symmetric and has 1s on the
      main diagonal.
  """
    num_entries = num_rows * (num_rows + 1) // 2
    ones = tf.ones(shape=[num_entries], dtype=dtype)
    # It seems wasteful to generate random values for the diagonal since
    # I am going to throw them away, but `fill_triangular` fills the
    # diagonal, so I probably need them.
    # It's not impossible that it would be more efficient to just fill
    # the whole matrix with random values instead of messing with
    # `fill_triangular`.  Then would need to filter almost half out with
    # `matrix_band_part`.
    unifs = uniform.Uniform(-ones, ones).sample(batch_shape, seed=seed)
    tril = fill_triangular(unifs)
    symmetric = tril + tf.linalg.matrix_transpose(tril)
    diagonal_ones = tf.ones(prefer_static.pad(batch_shape,
                                              paddings=[[0, 1]],
                                              constant_values=num_rows),
                            dtype=dtype)
    return tf.linalg.set_diag(symmetric, diagonal_ones)
Example No. 19
                def loop_body(i, outputs):
                    subkernel_ind = kernels_ind.read(i)
                    fh_, fw_ = ps.unstack(ps.shape(subkernel_ind), num=2)
                    eh = ex_h + fh_ - 1
                    ew = ex_w + fw_ - 1

                    subkernel_ind = ps.reshape(
                        ps.reshape(subkernel_ind * c_in,
                                   shape=[-1])[:, tf.newaxis] + ps.range(c_in),
                        shape=[-1])

                    k = tf.gather(kernel, subkernel_ind, axis=-2)
                    ind, shape = im2row_index([eh, ew, c_in],
                                              block_shape=(fh_, fw_),
                                              slice_step=(1, 1),
                                              dilations=dilations)
                    x_i = x_pad[..., :eh, :ew, :]
                    x_i_shape = ps.shape(x_i)
                    flat_shape = ps.pad(x_i_shape[:-3],
                                        paddings=[[0, 1]],
                                        constant_values=-1)
                    flat_x = tf.reshape(x_i, flat_shape)
                    x_ = tf.gather(flat_x, ind, axis=-1)
                    im_x = tf.reshape(
                        x_, ps.concat([x_i_shape[:-3], shape], axis=0))
                    outputs = outputs.write(
                        i,
                        tf.matmul(
                            im_x,
                            tf.reshape(
                                k,
                                ps.concat([
                                    kernel_batch, [1, fh_ * fw_ * c_in, c_out]
                                ],
                                          axis=0))))
                    return i + 1, outputs
Example No. 20
 def _log_prob(self, x, **kwargs):
     batch_ndims = prefer_static.rank_from_shape(
         self.distribution.batch_shape_tensor,
         self.distribution.batch_shape)
     extra_sample_ndims = prefer_static.rank_from_shape(self.sample_shape)
     event_ndims = prefer_static.rank_from_shape(
         self.distribution.event_shape_tensor,
         self.distribution.event_shape)
     ndims = prefer_static.rank(x)
     # (1) Expand x's dims.
     d = ndims - batch_ndims - extra_sample_ndims - event_ndims
     x = tf.reshape(x,
                    shape=prefer_static.pad(
                        prefer_static.shape(x),
                        paddings=[[prefer_static.maximum(0, -d), 0]],
                        constant_values=1))
     ndims = prefer_static.rank(x)
     sample_ndims = prefer_static.maximum(0, d)
     # (2) Transpose x's dims.
     sample_dims = prefer_static.range(0, sample_ndims)
     batch_dims = prefer_static.range(sample_ndims,
                                      sample_ndims + batch_ndims)
     extra_sample_dims = prefer_static.range(
         sample_ndims + batch_ndims,
         sample_ndims + batch_ndims + extra_sample_ndims)
     event_dims = prefer_static.range(
         sample_ndims + batch_ndims + extra_sample_ndims, ndims)
     perm = prefer_static.concat(
         [sample_dims, extra_sample_dims, batch_dims, event_dims], axis=0)
     x = tf.transpose(a=x, perm=perm)
     # (3) Compute x's log_prob.
     lp = self.distribution.log_prob(x, **kwargs)
     # (4) Make the final reduction in x.
     axis = prefer_static.range(sample_ndims,
                                sample_ndims + extra_sample_ndims)
     return tf.reduce_sum(lp, axis=axis)
Example No. 21
 def _transpose_around_bijector_fn(self,
                                   bijector_fn,
                                   arg,
                                   src_event_ndims,
                                   dest_event_ndims=None,
                                   fn_reduces_event=False,
                                   **kwargs):
   # This function moves the axes corresponding to `self.sample_shape` to the
   # left of the batch shape, then applies `bijector_fn`, then moves the axes
   # corresponding to `self.sample_shape` back to the event part of the shape.
   #
   # `src_event_ndims` and `dest_event_ndims` indicate the expected event rank
   # (omitting `self.sample_shape`) before and after applying `bijector_fn`.
   #
   # This function arose because forward and inverse ended up being quite
   # similar. It was then only a small generalization to also support {F/I}LDJ.
   batch_ndims = ps.rank_from_shape(self.distribution.batch_shape_tensor,
                                    self.distribution.batch_shape)
   extra_sample_ndims = ps.rank_from_shape(self.sample_shape)
   arg_ndims = ps.rank(arg)
   # (1) Expand arg's dims.
   d = arg_ndims - batch_ndims - extra_sample_ndims - src_event_ndims
   arg = tf.reshape(
       arg,
       shape=ps.pad(
           ps.shape(arg),
           paddings=[[ps.maximum(0, -d), 0]],
           constant_values=1))
   arg_ndims = ps.rank(arg)
   sample_ndims = ps.maximum(0, d)
   # (2) Transpose arg's dims.
   sample_dims = ps.range(0, sample_ndims)
   batch_dims = ps.range(sample_ndims, sample_ndims + batch_ndims)
   extra_sample_dims = ps.range(
       sample_ndims + batch_ndims,
       sample_ndims + batch_ndims + extra_sample_ndims)
   event_dims = ps.range(
       sample_ndims + batch_ndims + extra_sample_ndims,
       arg_ndims)
   perm = ps.concat(
       [sample_dims, extra_sample_dims, batch_dims, event_dims], axis=0)
   arg = tf.transpose(arg, perm=perm)
   # (3) Apply underlying bijector.
   result = bijector_fn(arg, **kwargs)
   # (4) Transpose sample_shape from the sample to the event shape.
   result_ndims = ps.rank(result)
   if fn_reduces_event:
     dest_event_ndims = 0
   d = result_ndims - batch_ndims - extra_sample_ndims - dest_event_ndims
   if fn_reduces_event:
     # In some cases, fn may reduce event too far, i.e. ildj may return a
     # scalar `0.`, which won't work with the transpose we do below.
     result = tf.reshape(
         result,
         shape=ps.pad(
             ps.shape(result),
             paddings=[[ps.maximum(0, -d), 0]],
             constant_values=1))
     result_ndims = ps.rank(result)
   sample_ndims = ps.maximum(0, d)
   sample_dims = ps.range(0, sample_ndims)
   extra_sample_dims = ps.range(sample_ndims,
                                sample_ndims + extra_sample_ndims)
   batch_dims = ps.range(sample_ndims + extra_sample_ndims,
                         sample_ndims + extra_sample_ndims + batch_ndims)
   event_dims = ps.range(sample_ndims + extra_sample_ndims + batch_ndims,
                         result_ndims)
   perm = ps.concat(
       [sample_dims, batch_dims, extra_sample_dims, event_dims], axis=0)
   return tf.transpose(result, perm=perm)
Example No. 22
 def test_num_paddings_dynamic(self):
     n = tf1.placeholder_with_default(2, shape=None)
     x = ps.pad([2, 3], paddings=[[0, n]], constant_values=1)
     if not ps.is_numpy(x):
         x = self.evaluate(x)
     self.assertAllEqual([2, 3, 1, 1], x)
Example No. 23
 def test_num_paddings_static(self):
     n = 2
     x = ps.pad([2, 3], paddings=[[0, n]], constant_values=1)
     self.assertAllEqual([2, 3, 1, 1], x)
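
Both tests exercise the same basic behavior: on a 1-D shape vector, ps.pad acts like np.pad with constant padding.

import numpy as np

print(np.pad([2, 3], pad_width=[[0, 2]], constant_values=1))   # [2 3 1 1]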