Example #1
        def loop_body(i_, event_ind):
            i = i_ // strides
            j = i_ % strides

            i_ind = ps.range(i * fw,
                             ps.maximum(i, fh) * fw,
                             delta=strides * fw,
                             dtype=dtype)
            j_ind = ps.range(j, ps.maximum(j, fw), delta=strides, dtype=dtype)

            nc = cartesian_add([i_ind, j_ind])
            ind = ps.reverse(ps.reshape(nc, shape=[-1]), axis=[0])

            k = ps.reshape(cartesian_add([
                ps.range(ps.shape(nc)[0] * sub_fw, delta=sub_fw, dtype=dtype),
                ps.range(ps.shape(nc)[1], dtype=dtype)
            ]), shape=[-1])
            last_j = strides - (fw - j - 1) % strides - 1
            last_i = strides - (fh - i - 1) % strides - 1
            kernel_ind = ps.stack(
                [k, ps.ones_like(k) * last_i * strides + last_j], axis=1)
            event_ind = ps.tensor_scatter_nd_update(
                event_ind, ind[..., tf.newaxis], kernel_ind)

            return i_ + 1, event_ind
Example #2
def _log_average_probs_process_args(logits, validate_args, sample_axis,
                                    event_axis):
    """Processes args for `log_average_probs`."""
    rank = ps.rank(logits)
    if sample_axis is None or validate_args:
        event_axis = ps.reshape(ps.non_negative_axis(event_axis, rank),
                                shape=[-1])
    if sample_axis is None:
        sample_axis = ps.setdiff1d(ps.range(rank), event_axis)
    elif validate_args:
        sample_axis = ps.reshape(ps.non_negative_axis(sample_axis, rank),
                                 shape=[-1])
    return sample_axis, event_axis
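A quick NumPy sketch of the axis bookkeeping above, with made-up values (`np.arange` and `%` stand in for `ps.range` and `ps.non_negative_axis`):

```python
import numpy as np

# Hypothetical: logits of rank 3, event_axis=-1, sample_axis left as None.
rank = 3
event_axis = np.reshape(np.array([-1]) % rank, [-1])      # -> [2]
sample_axis = np.setdiff1d(np.arange(rank), event_axis)   # -> [0, 1]
print(event_axis, sample_axis)
```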
Example #3
 def _forward_event_shape_tensor(self, input_shape, is_inverse=False):
   ndims = ps.size(input_shape)
   indices = ps.reshape(ps.add(self.axis, ndims), shape=[-1, 1])
   extra_sizes = ps.reduce_sum(self.paddings, axis=-1)
   update_fn = (ps.tensor_scatter_nd_sub if is_inverse else
                ps.tensor_scatter_nd_add)
   return update_fn(ps.identity(input_shape), indices, extra_sizes)
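The same shape arithmetic in plain NumPy, for concreteness (made-up values: event shape `[5, 6]`, `axis=-2`, `paddings=[[1, 2]]`; this mirrors the scatter-add rather than calling the bijector):

```python
import numpy as np

input_shape = np.array([5, 6])
axis, paddings = np.array([-2]), np.array([[1, 2]])
indices = np.reshape(axis + input_shape.size, [-1, 1])  # [[0]]
extra_sizes = paddings.sum(axis=-1)                     # [3]: 1 left + 2 right
forward_shape = input_shape.copy()
forward_shape[indices[:, 0]] += extra_sizes             # the scatter-add step
print(forward_shape)                                    # [8 6]
```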
Example #4
 def _sample_n(self, n, seed, **kwargs):
     sample_shape = ps.reshape(self.sample_shape, shape=[-1])
     x = self.distribution.sample(ps.concat([[n], sample_shape], axis=0),
                                  seed=seed,
                                  **kwargs)
     return tf.transpose(a=x,
                         perm=self._sampling_permutation(sample_ndims=1))
Example #5
 def _sample_n(self, n, seed, **kwargs):
     sample_shape = prefer_static.reshape(self.sample_shape, shape=[-1])
     fake_sample_ndims = prefer_static.rank_from_shape(sample_shape)
     event_ndims = prefer_static.rank_from_shape(
         self.distribution.event_shape_tensor,
         self.distribution.event_shape)
     batch_ndims = prefer_static.rank_from_shape(
         self.distribution.batch_shape_tensor,
         self.distribution.batch_shape)
     perm = prefer_static.concat([
         [0],
         prefer_static.range(1 + fake_sample_ndims,
                             1 + fake_sample_ndims + batch_ndims,
                             dtype=tf.int32),
         prefer_static.range(1, 1 + fake_sample_ndims, dtype=tf.int32),
         prefer_static.range(
             1 + fake_sample_ndims + batch_ndims,
             1 + fake_sample_ndims + batch_ndims + event_ndims,
             dtype=tf.int32),
      ], axis=0)
     x = self.distribution.sample(prefer_static.concat([[n], sample_shape],
                                                       axis=0),
                                  seed=seed,
                                  **kwargs)
     return tf.transpose(a=x, perm=perm)
Example #6
 def _inverse(self, y):
     ndims = prefer_static.rank(y)
     indices = prefer_static.reshape(prefer_static.add(self.axis, ndims),
                                     shape=[-1, 1])
     num_left, num_right = prefer_static.unstack(self.paddings,
                                                 num=2,
                                                 axis=-1)
     x = tf.slice(y,
                  begin=prefer_static.tensor_scatter_nd_update(
                      prefer_static.zeros(ndims, dtype=tf.int32), indices,
                      num_left),
                  size=prefer_static.tensor_scatter_nd_sub(
                      prefer_static.shape(y), indices,
                      num_left + num_right))
     if not self.validate_args:
         return x
     assertions = [
         assert_util.assert_equal(
             self._forward(x),
             y,
             message=('Argument `y` to `inverse` was not padded with '
                      '`constant_values`.')),
     ]
     with tf.control_dependencies(assertions):
         return tf.identity(x)
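A minimal sketch of what the slice recovers, with made-up values (rank-1 `y` padded with 1 value on the left and 2 on the right along `axis=-1`); it mirrors the `begin`/`size` computation above rather than calling the bijector:

```python
import numpy as np

y = np.array([0., 1., 2., 3., 4., 0., 0.])   # forward-padded input
num_left, num_right = 1, 2
begin = num_left                             # scatter of num_left into zeros
size = y.shape[-1] - (num_left + num_right)  # scatter-subtract from shape(y)
print(y[begin:begin + size])                 # [1. 2. 3. 4.]
```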
Example #7
    def _transpose_and_reshape_result(self, x, sample_shape, event_shape=None):
        if event_shape is None:
            event_shape = self.event_shape_tensor()

        batch_shape = self.batch_shape_tensor()
        batch_rank = ps.rank_from_shape(batch_shape)

        underlying_batch_shape = self.distribution.batch_shape_tensor()
        underlying_batch_rank = ps.rank_from_shape(underlying_batch_shape)

        # Continuing the example from `_augment_sample_shape`, suppose we have:
        #   - sample shape of `[n]`,
        #   - underlying distribution batch shape of `[2, 1]`,
        #   - final broadcast batch shape of `[4, 2, 3]`.
        # and have drawn an `x` of shape `[n, 12, 2, 1] + event_shape`, which we
        # ultimately want to have shape `[n, 4, 2, 3] + event_shape`.

        # First, we reshape to expand out the batch elements:
        # `shape_with_doubled_batch == [n] + [4, 1, 3] + [1, 2, 1] + event_shape`,
        # where `[1, 2, 1]` is the fully-expanded underlying batch shape, and
        # `[4, 1, 3]` is the shape of the elements being added by broadcasting.
        underlying_bcast_shp = ps.concat([
            ps.ones([ps.maximum(batch_rank - underlying_batch_rank, 0)],
                    dtype=underlying_batch_shape.dtype),
            underlying_batch_shape
        ], axis=0)
        is_dim_bcast = ps.not_equal(batch_shape, underlying_bcast_shp)
        x_with_doubled_batch = tf.reshape(
            x,
            ps.concat([
                sample_shape,
                ps.where(is_dim_bcast, batch_shape, 1),
                underlying_bcast_shp,
                event_shape
            ], axis=0))

        # Next, construct the permutation that interleaves the batch dimensions,
        # resulting in samples with shape
        # `[n] + [4, 1] + [1, 2] + [3, 1] + event_shape`.
        # Note that each interleaved pair of batch dimensions contains exactly one
        # dim of size `1` and one of size `>= 1`.
        sample_ndims = ps.rank_from_shape(sample_shape)
        x_with_interleaved_batch = tf.transpose(
            x_with_doubled_batch,
            perm=ps.concat([
                ps.range(sample_ndims),
                sample_ndims + ps.reshape(
                    ps.stack([ps.range(batch_rank),
                              ps.range(batch_rank) + batch_rank], axis=-1),
                    [-1]),
                sample_ndims + 2 * batch_rank +
                ps.range(ps.rank_from_shape(event_shape))
            ], axis=0))

        # Final reshape to remove the spurious `1` dimensions.
        return tf.reshape(
            x_with_interleaved_batch,
            ps.concat([sample_shape, batch_shape, event_shape], axis=0))
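The shape bookkeeping in the comments can be replayed in NumPy with the same example numbers (`n = 5` and `event_shape = [7]` are made up; only the shapes matter):

```python
import numpy as np

n, event_shape = 5, [7]
# As drawn from the underlying distribution: [n, 12, 2, 1] + event_shape, where
# 12 is the number of broadcast copies (4 * 3) and [2, 1] the underlying batch.
x = np.random.normal(size=[n, 12, 2, 1] + event_shape)

# (1) Expand the broadcast elements: [n] + [4, 1, 3] + [1, 2, 1] + event_shape.
x_doubled = x.reshape([n, 4, 1, 3, 1, 2, 1] + event_shape)

# (2) Interleave broadcast and underlying batch dims:
#     [n] + [4, 1] + [1, 2] + [3, 1] + event_shape.
x_interleaved = np.transpose(x_doubled, [0, 1, 4, 2, 5, 3, 6, 7])

# (3) Collapse each interleaved pair (one member of each pair is always 1):
#     [n] + [4, 2, 3] + event_shape.
result = x_interleaved.reshape([n, 4, 2, 3] + event_shape)
print(result.shape)  # (5, 4, 2, 3, 7)
```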
Example #8
def left_justified_expand_dims_to(x, rank, name=None):
    """Right pads `x` with `rank - rank(x)` ones."""
    with tf.name_scope(name or 'left_justified_expand_dims_to'):
        rank = tf.convert_to_tensor(rank, dtype=tf.int32)
        expand_ndims = prefer_static.maximum(rank - prefer_static.rank(x), 0)
        expand_shape = prefer_static.pad(prefer_static.shape(x),
                                         paddings=[[0, expand_ndims]],
                                         constant_values=1)
        return prefer_static.reshape(x, expand_shape)
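Usage is straightforward; a small sketch, assuming the snippet's `tf`/`prefer_static` imports are in scope and eager execution is enabled:

```python
import tensorflow as tf

x = tf.constant([1., 2., 3.])            # shape [3]
y = left_justified_expand_dims_to(x, 3)  # pad the shape on the right with ones
print(y.shape)                           # (3, 1, 1)
```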
Example #9
def left_justified_expand_dims_to(x, rank, name=None):
  """Right pads `x` with `rank - rank(x)` ones."""
  with tf.name_scope(name or 'left_justified_expand_dims_to'):
    expand_ndims = ps.maximum(rank - ps.rank(x), 0)
    expand_shape = ps.concat(
        [ps.shape(x),
         ps.ones(shape=[expand_ndims], dtype=tf.int32)],
        axis=0)
    return ps.reshape(x, expand_shape)
Example #10
 def _sample_direction_part(state_part, part_seed):
     state_part_shape = ps.shape(state_part)
     batch_shape = state_part_shape[:batch_rank]
     dimension = ps.reduce_prod(state_part_shape[batch_rank:])
     return ps.reshape(
         random_ops.spherical_uniform(shape=batch_shape,
                                      dimension=dimension,
                                      dtype=state_part.dtype,
                                      seed=part_seed), state_part_shape)
Example #11
  def __init__(
      self,
      input_size,
      output_size,
      # Weights
      init_kernel_fn=None,  # tfp.experimental.nn.initializers.glorot_uniform()
      init_bias_fn=None,    # tf.initializers.zeros()
      make_kernel_bias_fn=nn_util_lib.make_kernel_bias,
      dtype=tf.float32,
      batch_shape=(),
      # Misc
      activation_fn=None,
      name=None):
    """Constructs layer.

    Args:
      input_size: ...
      output_size: ...
      init_kernel_fn: ...
        Default value: `None` (i.e.,
        `tfp.experimental.nn.initializers.glorot_uniform()`).
      init_bias_fn: ...
        Default value: `None` (i.e., `tf.initializers.zeros()`).
      make_kernel_bias_fn: ...
        Default value: `tfp.experimental.nn.util.make_kernel_bias`.
      dtype: ...
        Default value: `tf.float32`.
      batch_shape: ...
        Default value: `()`.
      activation_fn: ...
        Default value: `None`.
      name: ...
        Default value: `None` (i.e., `'Affine'`).
    """
    batch_shape = (
        tf.constant([], dtype=tf.int32) if batch_shape is None
        else prefer_static.cast(
            prefer_static.reshape(batch_shape, shape=[-1]), tf.int32))
    batch_ndims = prefer_static.size(batch_shape)
    kernel_shape = prefer_static.concat([
        batch_shape, [input_size, output_size]], axis=0)
    bias_shape = prefer_static.concat([batch_shape, [output_size]], axis=0)
    apply_kernel_fn = lambda x, k: tf.matmul(  # pylint: disable=g-long-lambda
        x[..., tf.newaxis, :], k)[..., 0, :]
    kernel, bias = make_kernel_bias_fn(
        kernel_shape, bias_shape,
        init_kernel_fn, init_bias_fn,
        batch_ndims, batch_ndims,
        dtype)
    self._make_kernel_bias_fn = make_kernel_bias_fn  # For tracking.
    super(Affine, self).__init__(
        kernel=kernel,
        bias=bias,
        apply_kernel_fn=apply_kernel_fn,
        activation_fn=activation_fn,
        dtype=dtype,
        name=name)
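The `apply_kernel_fn` lambda is just a batched matrix-vector product; a NumPy shape check with hypothetical sizes (batch shape `[5]`, `input_size=3`, `output_size=2`):

```python
import numpy as np

k = np.zeros([5, 3, 2])   # kernel_shape = batch_shape + [input_size, output_size]
x = np.zeros([5, 3])      # one input vector per batch member
y = np.matmul(x[..., np.newaxis, :], k)[..., 0, :]
print(y.shape)            # (5, 2)
```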
Example #12
 def _forward(self, x):
   ndims = ps.rank(x)
   indices = ps.reshape(ps.add(self.axis, ndims), shape=[-1, 1])
   return tf.pad(
       x,
       paddings=ps.tensor_scatter_nd_update(
           ps.zeros([ndims, 2], dtype=tf.int32),
           indices, self.paddings),
       mode=self.mode,
       constant_values=ps.cast(self.constant_values, dtype=x.dtype))
Example #13
def _dummy_indices_like(indices):
    """Returns dummy indices ([0, 1, 2, ...]) with batch shape like `indices`."""
    indices_shape = ps.shape(indices)
    num_particles = indices_shape[0]
    return tf.broadcast_to(
        ps.reshape(
            ps.range(num_particles),
            ps.pad([num_particles],
                   paddings=[[0, ps.rank_from_shape(indices_shape) - 1]],
                   constant_values=1)), indices_shape)
Example #14
def left_justified_expand_dims_to(x, rank, name=None):
    """Right pads `x` with `rank - rank(x)` ones."""
    with tf.name_scope(name or 'left_justified_expand_dims_to'):
        rank = tf.convert_to_tensor(rank, dtype=tf.int32)
        expand_ndims = prefer_static.maximum(rank - prefer_static.rank(x), 0)
        expand_shape = prefer_static.concat([
            prefer_static.shape(x),
            prefer_static.ones(shape=[expand_ndims], dtype=tf.int32)
        ], axis=0)
        return prefer_static.reshape(x, expand_shape)
Example #15
def _dummy_indices_like(indices):
    """Returns dummy indices ([0, 1, 2, ...]) with batch shape like `indices`."""
    indices_shape = ps.shape(indices)
    num_particles = indices_shape[0]
    return tf.broadcast_to(
        ps.reshape(
            ps.range(num_particles),
            ps.concat([[num_particles],
                       ps.ones([ps.rank_from_shape(indices_shape) - 1],
                               dtype=np.int32)],
                      axis=0)), indices_shape)
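The two `_dummy_indices_like` variants above compute the same thing; for a hypothetical `indices` of shape `[4, 2]` (4 particles, 2 batch members) the result is:

```python
import numpy as np

indices = np.zeros([4, 2], dtype=np.int32)
dummy = np.broadcast_to(np.arange(4).reshape([4, 1]), indices.shape)
print(dummy)
# [[0 0]
#  [1 1]
#  [2 2]
#  [3 3]]
```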
Example #16
def _initialize(shape, dtype, batch_ndims, scale, mode, distribution,
                seed=None):
  """Samples a random `Tensor` per specified args."""
  if not dtype_util.is_floating(dtype):
    raise TypeError('Argument `dtype` must be float type (saw: "{}").'.format(
        dtype))
  shape = prefer_static.reshape(shape, shape=[-1])  # Ensure shape is vector.
  fan_in, fan_out = _compute_fans_from_shape(shape, batch_ndims)
  fans = _summarize_fans(fan_in, fan_out, mode, dtype)
  scale = prefer_static.cast(scale, dtype)
  return _sample_distribution(shape, scale / fans, distribution, seed, dtype)
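The helpers `_compute_fans_from_shape` and `_summarize_fans` are not shown, so the following is only an assumption about the usual variance-scaling recipe for a plain dense kernel of shape `[fan_in, fan_out]`, not the library's exact code:

```python
import numpy as np

shape = np.array([64, 32])                # hypothetical [fan_in, fan_out], no batch dims
fan_in, fan_out = shape[-2], shape[-1]
scale = 1.0 / ((fan_in + fan_out) / 2.0)  # 'fan_avg' mode
limit = np.sqrt(3.0 * scale)              # uniform-distribution case
sample = np.random.uniform(-limit, limit, size=shape)
print(sample.std(), np.sqrt(scale))       # roughly equal
```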
Example #17
 def sample(self, sample_shape=(), seed=None, name='sample'):  # pylint: disable=unused-argument
     return tf.zeros(
         ps.concat(
             [
                 # sample_shape might be a scalar
                 ps.reshape(ps.convert_to_shape_tensor(
                     sample_shape, tf.int32),
                            shape=[-1]),
                 self.batch_shape_tensor(),
                 self.event_shape_tensor()
             ],
             axis=0))
Example #18
 def sample_shape(self):
   sample_shape = ps.reshape(self._sample_shape, shape=[-1])
   shard_axis_size = sample_shape[self.shard_axis]
   num_devices = self.num_devices
   if shard_axis_size % num_devices != 0:
     raise ValueError('Does not shard evenly.')
   shard_size = shard_axis_size // num_devices
   sample_shape = ps.concat([
       sample_shape[:self.shard_axis], [shard_size],
       sample_shape[self.shard_axis + 1:]
   ], axis=0)
   return sample_shape
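In words: the sharded axis must divide evenly by the device count, and each replica keeps only its slice. With hypothetical numbers, `sample_shape = [8, 3]`, `shard_axis = 0`, `num_devices = 4`:

```python
sample_shape = [8, 3]
shard_axis, num_devices = 0, 4
shard_size = sample_shape[shard_axis] // num_devices  # 2
per_replica = (sample_shape[:shard_axis] + [shard_size]
               + sample_shape[shard_axis + 1:])
print(per_replica)                                    # [2, 3]
```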
Example #19
            def update_running_variance():
                diags = [
                    variance_part.variance()
                    for variance_part in variance_parts
                ]
                new_state_parts = tf.nest.flatten(new_state)
                new_variance_parts = []
                for variance_part, diag, state_part in zip(
                        variance_parts, diags, new_state_parts):
                    # Compute new variance for each variance part, accounting for partial
                    # batching of the variance calculation across chains (ie, some, all,
                    # or none of the chains may share the estimated mass matrix).
                    #
                    # For example, say
                    #
                    # state_part has shape       [2, 3, 4] + [5, 6]  (batch + event)
                    # variance_part has shape          [4] + [5, 6]
                    # log_prob has shape         [2, 3, 4]
                    #
                    # i.e., we have a batch of chains of shape [2, 3, 4], and 4 mass
                    # matrices, each being shared across a [2, 3]-batch of chains. Note
                    # this division is inferred from the shapes of the state part, the
                    # log_prob, and the user-provided initial running variances.
                    #
                    # Until RunningVariance supports rank > 1 chunking, we need to flatten
                    # the states that go into updating the variance estimates. In the
                    # above example, `state_part` will be reshaped to `[6, 4, 5, 6]`, and
                    # fed to `RunningVariance.update(state_part, axis=0)`, recording
                    # 6 new observations in the running variance calculation.
                    # `RunningVariance.variance()` will then be of shape `[4, 5, 6]`, and
                    # the resulting momentum distribution will have batch shape of
                    # `[2, 3, 4]` and event_shape of `[5, 6]`, matching the state_part.
                    state_rank = ps.rank(state_part)
                    variance_rank = ps.rank(diag)
                    num_reduce_dims = state_rank - variance_rank

                    state_part_shape = ps.shape(state_part)
                    # This reshape adds a 1 when reduce_dims==0, and collapses all the
                    # lead dimensions to a single one otherwise.
                    reshaped_state = ps.reshape(
                        state_part,
                        ps.concat([
                            [ps.reduce_prod(state_part_shape[:num_reduce_dims])],
                            state_part_shape[num_reduce_dims:]
                        ], axis=0))

                    # The `axis=0` here removes the leading dimension we got from the
                    # reshape above, so the new_variance_parts have the correct shape
                    # again.
                    new_variance_parts.append(
                        variance_part.update(reshaped_state, axis=0))
                return new_variance_parts
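Replaying the shape example from the comments in NumPy (the `[2, 3, 4] + [5, 6]` numbers come from the comment; the zeros are placeholders):

```python
import numpy as np

state_part = np.zeros([2, 3, 4, 5, 6])  # batch [2, 3, 4] + event [5, 6]
diag = np.zeros([4, 5, 6])              # one variance per shared mass matrix
num_reduce_dims = state_part.ndim - diag.ndim            # 2
lead = int(np.prod(state_part.shape[:num_reduce_dims]))  # 2 * 3 == 6
reshaped = state_part.reshape([lead] + list(state_part.shape[num_reduce_dims:]))
print(reshaped.shape)  # (6, 4, 5, 6): 6 observations fed to update(..., axis=0)
```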
Example #20
 def _finish_log_prob(self, lp, aux):
   (sample_ndims, extra_sample_ndims, batch_ndims) = aux
   # (1) Ensure lp is fully broadcast in the sample dims, i.e. ensure lp has
   #     full sample shape in the sample axes, before we reduce.
   bcast_lp_shape = ps.broadcast_shape(
       ps.shape(lp),
       ps.concat([ps.ones([sample_ndims], tf.int32),
                  ps.reshape(self.sample_shape, shape=[-1]),
                  ps.ones([batch_ndims], tf.int32)], axis=0))
   lp = tf.broadcast_to(lp, bcast_lp_shape)
   # (2) Make the final reduction.
   axis = ps.range(sample_ndims, sample_ndims + extra_sample_ndims)
   return self._sum_fn()(lp, axis=axis)
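A NumPy sketch of step (1) with made-up ranks (`sample_ndims = 1`, `sample_shape = [3]`, `batch_ndims = 2`), showing why the broadcast matters before the reduction:

```python
import numpy as np

lp = np.zeros([4, 1, 5, 6])    # the extra sample axis of size 3 collapsed to 1
target = (1, 3, 1, 1)          # ones + sample_shape + ones
lp_full = np.broadcast_to(lp, np.broadcast_shapes(lp.shape, target))  # (4, 3, 5, 6)
print(lp_full.sum(axis=1).shape)  # (4, 5, 6): reduce over the extra sample dims
```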
Example #21
 def _bcast_and_reduce_logdet(self, underlying_ldj):
   # Ensure ldj is fully broadcast in the sample dims, i.e. ensure ldj has
   # full sample shape in the sample axes, before we reduce.
   batch_ndims = ps.rank_from_shape(self.distribution.batch_shape_tensor,
                                    self.distribution.batch_shape)
   extra_sample_ndims = ps.rank_from_shape(self.sample_shape)
   sample_ndims = ps.rank(underlying_ldj) - extra_sample_ndims - batch_ndims
   bcast_ldj_shape = ps.broadcast_shape(
       ps.shape(underlying_ldj),
       ps.concat([ps.ones([sample_ndims], tf.int32),
                  ps.ones([batch_ndims], tf.int32),
                  ps.reshape(self.sample_shape, shape=[-1])], axis=0))
   ldj = tf.broadcast_to(underlying_ldj, bcast_ldj_shape)
   return self._sum_fn(ldj, axis=-1 - ps.range(extra_sample_ndims))
Example #22
 def _sample_and_log_prob(self, sample_shape, seed, **kwargs):
   sample_ndims = ps.rank_from_shape(sample_shape)
   batch_ndims = ps.rank_from_shape(
       self.distribution.batch_shape_tensor,
       self.distribution.batch_shape)
   extra_sample_shape = ps.reshape(self.sample_shape, shape=[-1])
   extra_sample_ndims = ps.rank_from_shape(extra_sample_shape)
   x, lp = self.distribution.experimental_sample_and_log_prob(
       ps.concat([sample_shape, extra_sample_shape], axis=0), seed=seed,
       **kwargs)
   return (
       tf.transpose(x, perm=self._sampling_permutation(sample_ndims)),
       self._finish_log_prob(
           lp, aux=(sample_ndims, extra_sample_ndims, batch_ndims)))
Example #23
                def loop_body(i, outputs):
                    subkernel_ind = kernels_ind.read(i)
                    fh_, fw_ = ps.unstack(ps.shape(subkernel_ind), num=2)
                    eh = ex_h + fh_ - 1
                    ew = ex_w + fw_ - 1

                    subkernel_ind = ps.reshape(
                        ps.reshape(subkernel_ind * c_in,
                                   shape=[-1])[:, tf.newaxis] + ps.range(c_in),
                        shape=[-1])

                    k = tf.gather(kernel, subkernel_ind, axis=-2)
                    ind, shape = im2row_index([eh, ew, c_in],
                                              block_shape=(fh_, fw_),
                                              slice_step=(1, 1),
                                              dilations=dilations)
                    x_i = x_pad[..., :eh, :ew, :]
                    x_i_shape = ps.shape(x_i)
                    flat_shape = ps.pad(x_i_shape[:-3],
                                        paddings=[[0, 1]],
                                        constant_values=-1)
                    flat_x = tf.reshape(x_i, flat_shape)
                    x_ = tf.gather(flat_x, ind, axis=-1)
                    im_x = tf.reshape(
                        x_, ps.concat([x_i_shape[:-3], shape], axis=0))
                    outputs = outputs.write(
                        i,
                        tf.matmul(
                            im_x,
                            tf.reshape(
                                k,
                                ps.concat(
                                    [kernel_batch, [1, fh_ * fw_ * c_in, c_out]],
                                    axis=0))))
                    return i + 1, outputs
Example #24
def _canonicalize_steps_to_trace(step_indices_to_trace, num_timesteps):
    """Canonicalizes `3` -> `[3]`, `[-2, -1]` -> `[N - 2, N - 1]`, etc."""
    step_indices_to_trace = tf.convert_to_tensor(
        step_indices_to_trace,
        dtype_hint=tf.int32)  # Warning: breaks gradients.
    traced_steps_have_rank_zero = ps.equal(
        ps.rank_from_shape(ps.shape(step_indices_to_trace)), 0)
    # Canonicalize negative step indices as positive.
    step_indices_to_trace = ps.where(step_indices_to_trace < 0,
                                     num_timesteps + step_indices_to_trace,
                                     step_indices_to_trace)
    # Canonicalize scalars as length-one vectors.
    return (ps.reshape(step_indices_to_trace,
                       [ps.size(step_indices_to_trace)]),
            traced_steps_have_rank_zero)
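Example behavior, assuming eager mode and the imports the snippet relies on (expected results shown in comments):

```python
steps, was_scalar = _canonicalize_steps_to_trace(3, num_timesteps=10)
# steps == [3], was_scalar == True

steps, was_scalar = _canonicalize_steps_to_trace([-2, -1], num_timesteps=10)
# steps == [8, 9], was_scalar == False
```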
Example #25
 def _fn(self, **kwargs):
   """Implements summary statistic, eg, mean, stddev, mode."""
   sample_shape = ps.reshape(self.sample_shape, shape=[-1])
   x = getattr(self.distribution, attr)(**kwargs)
   shape = ps.concat([
       self.distribution.batch_shape_tensor(),
       ps.ones(ps.rank_from_shape(sample_shape), dtype=sample_shape.dtype),
       self.distribution.event_shape_tensor(),
   ], axis=0)
   x = tf.reshape(x, shape=shape)
   shape = ps.concat([
       self.distribution.batch_shape_tensor(),
       sample_shape,
       self.distribution.event_shape_tensor(),
   ], axis=0)
   return tf.broadcast_to(x, shape)
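Shape-wise, with a hypothetical batch shape `[2]`, `sample_shape = [3]`, and event shape `[4]`, the two reshapes amount to:

```python
import numpy as np

stat = np.zeros([2, 4])                  # underlying statistic: batch + event
stat = stat.reshape([2, 1, 4])           # insert a 1 where the sample dims go
stat = np.broadcast_to(stat, [2, 3, 4])  # tile across sample_shape
print(stat.shape)                        # (2, 3, 4)
```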
Example #26
 def _sampling_permutation(self, sample_ndims):
   fake_sample_ndims = ps.rank_from_shape(
       ps.reshape(self.sample_shape, shape=[-1]))
   event_ndims = ps.rank_from_shape(
       self.distribution.event_shape_tensor, self.distribution.event_shape)
   batch_ndims = ps.rank_from_shape(
       self.distribution.batch_shape_tensor, self.distribution.batch_shape)
   return ps.concat([
       ps.range(sample_ndims),
       ps.range(sample_ndims + fake_sample_ndims,
                sample_ndims + fake_sample_ndims + batch_ndims,
                dtype=tf.int32),
       ps.range(sample_ndims, sample_ndims + fake_sample_ndims,
                dtype=tf.int32),
       ps.range(sample_ndims + fake_sample_ndims + batch_ndims,
                sample_ndims + fake_sample_ndims + batch_ndims + event_ndims,
                dtype=tf.int32),
   ], axis=0)
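For hypothetical ranks (one caller sample dim, `sample_shape` of rank 2, one batch dim, one event dim) the permutation works out to `[0, 3, 1, 2, 4]`, i.e. the batch dims are moved in front of the `Sample` dims:

```python
import numpy as np

x = np.zeros([7, 2, 3, 4, 5])        # [n] + sample_shape + batch + event
perm = [0, 3, 1, 2, 4]               # _sampling_permutation(sample_ndims=1) here
print(np.transpose(x, perm).shape)   # (7, 4, 2, 3, 5)
```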
Example #27
    def _sample_n(self, n, seed=None):
        batch_shape = self.batch_shape_tensor()
        batch_rank = ps.rank_from_shape(batch_shape)
        n_batch = ps.reduce_prod(batch_shape)

        underlying_batch_shape = self.distribution.batch_shape_tensor()
        underlying_batch_rank = ps.rank_from_shape(underlying_batch_shape)
        underlying_n_batch = ps.reduce_prod(underlying_batch_shape)

        # Left pad underlying shape with any necessary ones.
        underlying_bcast_shp = ps.concat([
            ps.ones([ps.maximum(batch_rank - underlying_batch_rank, 0)],
                    dtype=underlying_batch_shape.dtype),
            underlying_batch_shape
        ], axis=0)

        # Determine how many underlying samples to produce.
        n_bcast_samples = ps.maximum(0, n_batch // underlying_n_batch)
        samps = self.distribution.sample([n, n_bcast_samples], seed=seed)

        is_dim_bcast = ps.not_equal(batch_shape, underlying_bcast_shp)

        event_shape = self.event_shape_tensor()
        event_rank = ps.rank_from_shape(event_shape)
        shp = ps.concat([[n],
                         ps.where(is_dim_bcast, batch_shape, 1),
                         underlying_bcast_shp, event_shape],
                        axis=0)
        # Reshape to expand n_bcast_samples and ones-padded underlying_bcast_shp.
        samps = tf.reshape(samps, shp)
        # Interleave broadcast and underlying axis indices for transpose.
        interleaved_batch_axes = ps.reshape(
            ps.stack([ps.range(batch_rank),
                      ps.range(batch_rank) + batch_rank],
                     axis=-1), [-1]) + 1

        event_axes = ps.range(event_rank) + (1 + 2 * batch_rank)
        perm = ps.concat([[0], interleaved_batch_axes, event_axes], axis=0)
        samps = tf.transpose(samps, perm=perm)
        # Finally, reshape to the fully-broadcast batch shape.
        return tf.reshape(samps,
                          ps.concat([[n], batch_shape, event_shape], axis=0))
Example #28
 def _log_prob(self, x, **kwargs):
     batch_ndims = ps.rank_from_shape(self.distribution.batch_shape_tensor,
                                      self.distribution.batch_shape)
     extra_sample_ndims = ps.rank_from_shape(self.sample_shape)
     event_ndims = ps.rank_from_shape(self.distribution.event_shape_tensor,
                                      self.distribution.event_shape)
     ndims = ps.rank(x)
     # (1) Expand x's dims.
     d = ndims - batch_ndims - extra_sample_ndims - event_ndims
     x = tf.reshape(x,
                    shape=ps.pad(ps.shape(x),
                                 paddings=[[ps.maximum(0, -d), 0]],
                                 constant_values=1))
     ndims = ps.rank(x)
     sample_ndims = ps.maximum(0, d)
     # (2) Transpose x's dims.
     sample_dims = ps.range(0, sample_ndims)
     batch_dims = ps.range(sample_ndims, sample_ndims + batch_ndims)
     extra_sample_dims = ps.range(
         sample_ndims + batch_ndims,
         sample_ndims + batch_ndims + extra_sample_ndims)
     event_dims = ps.range(sample_ndims + batch_ndims + extra_sample_ndims,
                           ndims)
     perm = ps.concat(
         [sample_dims, extra_sample_dims, batch_dims, event_dims], axis=0)
     x = tf.transpose(a=x, perm=perm)
     # (3) Compute x's log_prob.
     lp = self.distribution.log_prob(x, **kwargs)
     # (4) Ensure lp is fully broadcast in the sample dims, i.e. ensure lp has
     #     full sample shape in the sample axes, before we reduce.
     bcast_lp_shape = ps.broadcast_shape(
         ps.shape(lp),
         ps.concat([
             ps.ones([sample_ndims], tf.int32),
             ps.reshape(self.sample_shape, shape=[-1]),
             ps.ones([batch_ndims], tf.int32)
          ], axis=0))
     lp = tf.broadcast_to(lp, bcast_lp_shape)
     # (5) Make the final reduction in x.
     axis = ps.range(sample_ndims, sample_ndims + extra_sample_ndims)
     return tf.reduce_sum(lp, axis=axis)
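A shape walkthrough of steps (2) through (5), with made-up sizes (one caller sample dim of 7, batch `[2]`, `sample_shape = [3]`, event `[4]`):

```python
import numpy as np

x = np.zeros([7, 2, 3, 4])         # sample + batch + extra-sample + event
x = np.transpose(x, [0, 2, 1, 3])  # step (2): perm == [0, 2, 1, 3] -> (7, 3, 2, 4)
# Step (3) reduces over the event dims, so lp has shape (7, 3, 2); steps (4)-(5)
# broadcast lp to that full shape and sum over the extra sample axis.
lp = np.zeros([7, 3, 2])
print(lp.sum(axis=1).shape)        # (7, 2)
```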
Example #29
    def one_step(self, current_state, previous_kernel_results, seed=None):
        with tf.name_scope(
                mcmc_util.make_name(self.name,
                                    'diagonal_mass_matrix_adaptation',
                                    'one_step')):
            variance_parts = previous_kernel_results.running_variance
            diags = [
                variance_part.variance() for variance_part in variance_parts
            ]
            # Set the momentum.
            batch_ndims = ps.rank(
                unnest.get_innermost(previous_kernel_results,
                                     'target_log_prob'))
            state_parts = tf.nest.flatten(current_state)
            new_momentum_distribution = _make_momentum_distribution(
                diags, state_parts, batch_ndims)
            inner_results = self.momentum_distribution_setter_fn(
                previous_kernel_results.inner_results,
                new_momentum_distribution)

            # Step the inner kernel.
            inner_kwargs = {} if seed is None else dict(seed=seed)
            new_state, new_inner_results = self.inner_kernel.one_step(
                current_state, inner_results, **inner_kwargs)
            new_state_parts = tf.nest.flatten(new_state)
            new_variance_parts = []
            for variance_part, diag, state_part in zip(variance_parts, diags,
                                                       new_state_parts):
                # Compute new variance for each variance part, accounting for partial
                # batching of the variance calculation across chains (ie, some, all, or
                # none of the chains may share the estimated mass matrix).
                #
                # For example, say
                #
                # state_part has shape       [2, 3, 4] + [5, 6]  (batch + event)
                # variance_part has shape          [4] + [5, 6]
                # log_prob has shape         [2, 3, 4]
                #
                # i.e., we have a batch of chains of shape [2, 3, 4], and 4 mass
                # matrices, each being shared across a [2, 3]-batch of chains. Note this
                # division is inferred from the shapes of the state part, the log_prob,
                # and the user-provided initial running variances.
                #
                # Until RunningVariance supports rank > 1 chunking, we need to flatten
                # the states that go into updating the variance estimates. In the above
                # example, `state_part` will be reshaped to `[6, 4, 5, 6]`, and
                # fed to `RunningVariance.update(state_part, axis=0)`, recording
                # 6 new observations in the running variance calculation.
                # `RunningVariance.variance()` will then be of shape `[4, 5, 6]`, and
                # the resulting momentum distribution will have batch shape of
                # `[2, 3, 4]` and event_shape of `[5, 6]`, matching the state_part.
                state_rank = ps.rank(state_part)
                variance_rank = ps.rank(diag)
                num_reduce_dims = state_rank - variance_rank

                state_part_shape = ps.shape(state_part)
                # This reshape adds a 1 when reduce_dims==0, and collapses all the lead
                # dimensions to a single one otherwise.
                reshaped_state = ps.reshape(
                    state_part,
                    ps.concat(
                        [[ps.reduce_prod(state_part_shape[:num_reduce_dims])],
                         state_part_shape[num_reduce_dims:]],
                        axis=0))

                # The `axis=0` here removes the leading dimension we got from the
                # reshape above, so the new_variance_parts have the correct shape again.
                new_variance_parts.append(
                    variance_part.update(reshaped_state, axis=0))

            new_kernel_results = previous_kernel_results._replace(
                inner_results=new_inner_results,
                running_variance=new_variance_parts)

            return new_state, new_kernel_results
Example #30
    def __init__(
            self,
            input_size,
            output_size,  # keras::Conv::filters
            # Conv specific.
            filter_shape,  # keras::Conv::kernel_size
            rank=2,  # keras::Conv::rank
            strides=1,  # keras::Conv::strides
            padding='VALID',  # keras::Conv::padding; 'CAUSAL' not implemented.
            # keras::Conv::data_format is not implemented
            dilations=1,  # keras::Conv::dilation_rate
            # Weights
            init_kernel_fn=None,  # tfp.experimental.nn.initializers.glorot_uniform()
            init_bias_fn=None,  # tf.initializers.zeros()
            make_kernel_bias_fn=nn_util_lib.make_kernel_bias,
            dtype=tf.float32,
            batch_shape=(),
            # Misc
            activation_fn=None,
            name=None):
        """Constructs layer.

    Note: `data_format` is not supported since all nn layers operate on
    the rightmost column. If your channel dimension is not rightmost, use
    `tf.transpose` before calling this layer. For example, if your channel
    dimension is second from the left, the following code will move it
    rightmost:

    ```python
    inputs = tf.transpose(inputs, tf.concat([
        [0], tf.range(2, tf.rank(inputs)), [1]], axis=0))
    ```

    Args:
      input_size: ...
        In Keras, this argument is inferred from the rightmost input shape,
        i.e., `tf.shape(inputs)[-1]`. This argument specifies the size of the
        second from the rightmost dimension of both `inputs` and `kernel`.
        Default value: `None`.
      output_size: ...
        In Keras, this argument is called `filters`. This argument specifies the
        rightmost dimension size of both `kernel` and `bias`.
      filter_shape: ...
        In Keras, this argument is called `kernel_size`. This argument specifies
        the leftmost `rank` dimensions' sizes of `kernel`.
      rank: An integer, the rank of the convolution, e.g. "2" for 2D
        convolution. This argument implies the number of `kernel` dimensions,
        i.e.`, `kernel.shape.rank == rank + 2`.
        In Keras, this argument has the same name and semantics.
        Default value: `2`.
      strides: An integer or tuple/list of n integers, specifying the stride
        length of the convolution.
        In Keras, this argument has the same name and semantics.
        Default value: `1`.
      padding: One of `"VALID"` or `"SAME"` (case-insensitive).
        In Keras, this argument has the same name and semantics (except we don't
        support `"CAUSAL"`).
        Default value: `'VALID'`.
      dilations: An integer or tuple/list of `rank` integers, specifying the
        dilation rate to use for dilated convolution. Currently, specifying any
        `dilations` value != 1 is incompatible with specifying any `strides`
        value != 1.
        In Keras, this argument is called `dilation_rate`.
        Default value: `1`.
      init_kernel_fn: ...
        Default value: `None` (i.e.,
        `tfp.experimental.nn.initializers.glorot_uniform()`).
      init_bias_fn: ...
        Default value: `None` (i.e., `tf.initializers.zeros()`).
      make_kernel_bias_fn: ...
        Default value: `tfp.experimental.nn.util.make_kernel_bias`.
      dtype: ...
        Default value: `tf.float32`.
      batch_shape: ...
        Default value: `()`.
      activation_fn: ...
        Default value: `None`.
      name: ...
        Default value: `None` (i.e., `'Convolution'`).
    """
        filter_shape = prepare_tuple_argument(filter_shape,
                                              rank,
                                              arg_name='filter_shape')
        batch_shape = (np.array([], dtype=np.int32) if batch_shape is None else
                       prefer_static.reshape(batch_shape, shape=[-1]))
        batch_ndims = prefer_static.size(batch_shape)
        if tf.get_static_value(batch_ndims) == 0:
            # In this branch, we statically know there are no batch dims.
            kernel_shape = filter_shape + (input_size, output_size)
            bias_shape = [output_size]
            apply_kernel_fn = _make_convolution_fn(rank, strides, padding,
                                                   dilations)
        else:
            # In this branch, there are either static/dynamic batch dims or
            # dynamically no batch dims.
            kernel_shape = prefer_static.concat(
                [batch_shape, filter_shape, [input_size, output_size]], axis=0)
            bias_shape = prefer_static.concat([batch_shape, [output_size]],
                                              axis=0)
            apply_kernel_fn = lambda x, k: convolution_batch(  # pylint: disable=g-long-lambda
                x,
                k,
                rank=rank,
                strides=strides,
                padding=padding,
                data_format='NHWBC',
                dilations=dilations)
        kernel, bias = make_kernel_bias_fn(kernel_shape, bias_shape,
                                           init_kernel_fn, init_bias_fn,
                                           batch_ndims, batch_ndims, dtype)
        self._make_kernel_bias_fn = make_kernel_bias_fn  # For tracking.
        super(Convolution, self).__init__(kernel=kernel,
                                          bias=bias,
                                          apply_kernel_fn=apply_kernel_fn,
                                          dtype=dtype,
                                          activation_fn=activation_fn,
                                          name=name)