def _sample_n(self, n, seed):
        seed = SeedStream(seed, salt='MixtureSameFamily')
        x = self.components_distribution.sample(n, seed=seed())  # [n, B, k, E]

        event_shape = None
        event_ndims = tensorshape_util.rank(self.event_shape)
        if event_ndims is None:
            event_shape = self.components_distribution.event_shape_tensor()
            event_ndims = prefer_static.rank_from_shape(event_shape)
        event_ndims_static = tf.get_static_value(event_ndims)

        num_components = None
        if event_ndims_static is not None:
            num_components = tf.compat.dimension_value(
                x.shape[-1 - event_ndims_static])
        # We could also check if num_components can be computed statically from
        # self.mixture_distribution's logits or probs.
        if num_components is None:
            num_components = tf.shape(x)[-1 - event_ndims]

        # TODO(jvdillon): Consider using tf.gather (by way of index unrolling).
        npdt = dtype_util.as_numpy_dtype(x.dtype)
        mask = tf.one_hot(
            indices=self.mixture_distribution.sample(
                n, seed=seed()),  # [n, B] or [n]
            depth=num_components,
            on_value=npdt(1),
            off_value=npdt(0))  # [n, B, k] or [n, k]

        # Pad `mask` to [n, B, k, [1]*e] or [n, [1]*b, k, [1]*e] .
        batch_ndims = prefer_static.rank(x) - event_ndims - 1
        mask_batch_ndims = prefer_static.rank(mask) - 1
        pad_ndims = batch_ndims - mask_batch_ndims
        mask_shape = prefer_static.shape(mask)
        mask = tf.reshape(
            mask,
            shape=prefer_static.concat([
                mask_shape[:-1],
                prefer_static.ones([pad_ndims], dtype=tf.int32),
                mask_shape[-1:],
                prefer_static.ones([event_ndims], dtype=tf.int32),
            ], axis=0))

        ret = tf.reduce_sum(x * mask, axis=-1 - event_ndims)  # [n, B, E]

        if self._reparameterize:
            if event_shape is None:
                event_shape = self.components_distribution.event_shape_tensor()
            ret = self._reparameterize_sample(ret, event_shape=event_shape)

        return ret
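A toy sketch (shapes and values invented for illustration) of the one-hot masking trick used above to select one component draw per sample without `tf.gather`:

import tensorflow as tf

x = tf.constant([[0.1, 2.0, -3.0],
                 [0.4, 5.0, -6.0]])                # [n=2, k=3] component draws
choice = tf.constant([1, 2])                       # [n] sampled mixture indices
mask = tf.one_hot(choice, depth=3, dtype=x.dtype)  # [n, k]
selected = tf.reduce_sum(x * mask, axis=-1)        # [n] ==> [2.0, -6.0]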
Example #2
 def _finish_log_prob(self, lp, aux):
   (sample_ndims, extra_sample_ndims, batch_ndims) = aux
   # (1) Ensure lp is fully broadcast in the sample dims, i.e. ensure lp has
   #     full sample shape in the sample axes, before we reduce.
   bcast_lp_shape = ps.broadcast_shape(
       ps.shape(lp),
       ps.concat([ps.ones([sample_ndims], tf.int32),
                  ps.reshape(self.sample_shape, shape=[-1]),
                  ps.ones([batch_ndims], tf.int32)], axis=0))
   lp = tf.broadcast_to(lp, bcast_lp_shape)
   # (2) Make the final reduction.
   axis = ps.range(sample_ndims, sample_ndims + extra_sample_ndims)
   return self._sum_fn()(lp, axis=axis)
Example #3
 def _bcast_and_reduce_logdet(self, underlying_ldj):
   # Ensure ldj is fully broadcast in the sample dims, i.e. ensure ldj has
   # full sample shape in the sample axes, before we reduce.
   batch_ndims = ps.rank_from_shape(self.distribution.batch_shape_tensor,
                                    self.distribution.batch_shape)
   extra_sample_ndims = ps.rank_from_shape(self.sample_shape)
   sample_ndims = ps.rank(underlying_ldj) - extra_sample_ndims - batch_ndims
   bcast_ldj_shape = ps.broadcast_shape(
       ps.shape(underlying_ldj),
       ps.concat([ps.ones([sample_ndims], tf.int32),
                  ps.ones([batch_ndims], tf.int32),
                  ps.reshape(self.sample_shape, shape=[-1])], axis=0))
   ldj = tf.broadcast_to(underlying_ldj, bcast_ldj_shape)
   return self._sum_fn(ldj, axis=-1 - ps.range(extra_sample_ndims))
Example #4
    def _sample_n(self, n, seed):
        components_seed, mix_seed = samplers.split_seed(
            seed, salt='MixtureSameFamily')
        mixture_distribution, components_distribution = (
            self._get_distributions_with_broadcast_batch_shape())
        x = components_distribution.sample(  # [n, B, k, E]
            n, seed=components_seed)

        event_ndims = ps.rank_from_shape(self.event_shape_tensor,
                                         self.event_shape)
        # We could also check if num_components can be computed statically from
        # self.mixture_distribution's logits or probs.
        num_components = ps.dimension_size(x, idx=-1 - event_ndims)

        # TODO(jvdillon): Consider using tf.gather (by way of index unrolling).
        npdt = dtype_util.as_numpy_dtype(x.dtype)
        mix_sample = mixture_distribution.sample(n, seed=mix_seed)  # [n, B]
        mask = tf.one_hot(
            indices=mix_sample,  # [n, B]
            depth=num_components,
            on_value=npdt(1),
            off_value=npdt(0))  # [n, B, k]

        # Pad `mask` to [n, B, k, [1]*e].
        batch_ndims = ps.rank(x) - event_ndims - 1
        mask_batch_ndims = ps.rank(mask) - 1
        pad_ndims = batch_ndims - mask_batch_ndims
        mask_shape = ps.shape(mask)
        target_shape = ps.concat([
            mask_shape[:-1],
            ps.ones([pad_ndims], dtype=tf.int32),
            mask_shape[-1:],
            ps.ones([event_ndims], dtype=tf.int32),
        ], axis=0)
        mask = tf.reshape(mask, shape=target_shape)

        if dtype_util.is_floating(x.dtype) or dtype_util.is_complex(x.dtype):
            masked = tf.math.multiply_no_nan(x, mask)
        else:
            masked = x * mask
        ret = tf.reduce_sum(masked, axis=-1 - event_ndims)  # [n, B, E]

        if self._reparameterize:
            ret = self._reparameterize_sample(
                ret, event_shape=components_distribution.event_shape_tensor())

        return ret
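A minimal illustration (toy values) of why `multiply_no_nan` is used above: a NaN draw from an unselected component must not poison the masked sum.

import numpy as np
import tensorflow as tf

x = tf.constant([np.nan, 3.0])                   # draws; the unselected one is NaN
mask = tf.constant([0.0, 1.0])                   # one-hot selection
tf.reduce_sum(x * mask)                          # ==> nan
tf.reduce_sum(tf.math.multiply_no_nan(x, mask))  # ==> 3.0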
Example #5
 def assertDistributionIsApproximatelyStandardNormal(
         self, dist, logprob_atol=1e-2, grad_atol=1e-2):
     """Verifies that dist's lps and gradients match those of Normal(0., 1.)."""
     event_ndims = ps.rank_from_shape(dist.event_shape_tensor,
                                      dist.event_shape)
     batch_ndims = ps.rank_from_shape(dist.batch_shape_tensor,
                                      dist.batch_shape)
     dist_shape = ps.concat(
         [dist.batch_shape_tensor(),
          dist.event_shape_tensor()], axis=0)
      reference_dist = tfd.Independent(
          tfd.Normal(loc=tf.zeros(dist_shape, dtype=dist.dtype), scale=1.),
          reinterpreted_batch_ndims=event_ndims)
     zs = tf.reshape(
         [-4., -2., 0., 2., 4.],
         ps.concat([[5],
                    ps.ones([batch_ndims + event_ndims], dtype=np.int32)],
                   axis=0))
     zs = tf.broadcast_to(zs, ps.concat([[5], dist_shape], axis=0))
     lp_dist, grad_dist = tfp.math.value_and_gradient(dist.log_prob, zs)
     lp_reference, grad_reference = tfp.math.value_and_gradient(
         reference_dist.log_prob, zs)
     self.assertAllClose(lp_reference, lp_dist, atol=logprob_atol)
     self.assertAllClose(grad_reference, grad_dist, atol=grad_atol)
Example #6
def _broadcast_to_full_batch_shape_helper(data,
                                          event_ndims,
                                          batch_shape,
                                          sample_ndims=0):
  """Broadcasts `[sample, ?, event]` to `[sample, batch, event]`."""
  if data is None:
    return None
  data_shape = ps.shape(data)
  data_rank = ps.rank_from_shape(data_shape)
  batch_ndims = ps.rank_from_shape(batch_shape)

  # Reshape the data to have full batch rank. For example, given
  # `batch_shape==[3, 2]`, this would reshape `data.shape==[S, 2, E]` to
  # `[S, 1, 2, E]`.
  # This reshaping is not necessary when `sample_ndims==0`, since with no sample
  # dimensions the batch shape itself is leftmost and can broadcast. For
  # example, we would not need to reshape `[2, E] -> [1, 2, E]`.
  if sample_ndims != 0:
    padding_ndims = batch_ndims - (data_rank - sample_ndims - event_ndims)
    padded_shape = ps.concat([data_shape[:sample_ndims],
                              ps.ones([padding_ndims], dtype=np.int32),
                              data_shape[sample_ndims:]], axis=0)
    data = tf.reshape(data, padded_shape)
    data_shape = padded_shape
    data_rank = ps.rank_from_shape(data_shape)

  # Broadcast the data to have full batch shape. For example, given
  # `batch_shape==[3, 2]`, this would broadcast `data.shape==[S, 1, 2, E]` to
  # `[S, 3, 2, E]`.
  new_shape = tf.concat([data_shape[:sample_ndims],
                         batch_shape,
                         data_shape[data_rank - event_ndims:]], axis=0)
  return tf.broadcast_to(data, new_shape)
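The reshape-then-broadcast pattern above, spelled out with the concrete shapes from the comments (`batch_shape == [3, 2]`, one sample dim, scalar event); a toy sketch:

import tensorflow as tf

data = tf.zeros([5, 2])                  # [S=5, partial batch 2]
data = tf.reshape(data, [5, 1, 2])       # pad batch rank -> [S, 1, 2]
data = tf.broadcast_to(data, [5, 3, 2])  # full batch shape -> [S, 3, 2]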
Example #7
def _design_matrix_for_one_seasonal_effect(num_steps, duration, period, dtype):
  current_period = np.int32(np.arange(num_steps) / duration) % period
  return np.transpose([
      ps.where(current_period == p,  # pylint: disable=g-complex-comprehension
               ps.ones([], dtype=dtype),
               ps.zeros([], dtype=dtype))
      for p in range(period)])
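A worked instance of the design matrix above, computed with NumPy in place of `ps` (assumed toy inputs: num_steps=4, duration=1, period=2):

import numpy as np

current_period = np.int32(np.arange(4) / 1) % 2          # [0, 1, 0, 1]
design = np.transpose([np.where(current_period == p, 1., 0.)
                       for p in range(2)])
# design == [[1., 0.], [0., 1.], [1., 0.], [0., 1.]]      # [num_steps, period]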
Example #8
    def adjacent_swaps(num_replica,
                       batch_shape=(),
                       step_count=None,
                       seed=None):
        """Make random shuffle using only one time swaps."""
        del step_count  # Unused for this function.
        with tf.name_scope(name or 'adjacent_swaps'):
            parity_seed, proposal_seed = samplers.split_seed(seed)
            # u selects parity.  E.g.,
            #  u==False ==> [1, 0, 3, 2, 4] even parity swaps
            #  u==True ==>  [0, 2, 1, 4, 3] odd parity swaps
            # If there are only 2 replicas, then the "True" swaps are null
            # swaps...which would contradict the user provided `prob_swap`.
            # So special case num_replica==2, forcing u==False in this case.
            u_shape = ps.concat(
                (ps.ones(1, dtype=tf.int32), ps.cast(batch_shape, tf.int32)),
                axis=0)
            u = samplers.uniform(u_shape, seed=parity_seed) < 0.5
            u = tf.where(num_replica > 2, u, False)

            x = bu.left_justified_expand_dims_to(ps.range(num_replica,
                                                          dtype=tf.int64),
                                                 rank=ps.size(u_shape))
            y = tf.where(tf.equal(x % 2, tf.cast(u, dtype=tf.int64)), x + 1,
                         x - 1)
            y = tf.clip_by_value(y, 0, num_replica - 1)
            # TODO(b/142689785): Consider using tf.cond and returning an empty list
            # then in REMC consider using a tf.cond for short-circuiting.
            return tf.where(
                samplers.uniform(batch_shape, seed=proposal_seed) < prob_swap,
                y, x)
Example #9
    def _variance(self):
        probs = self.mixture_distribution.probs_parameter()  # [B, k] or [k]
        component_means = self.components_distribution.mean()  # [B, k, E]
        component_vars = self.components_distribution.variance()  # [B, k, E]
        event_ndims = self._event_ndims()

        # reshape probs to [B, k, [1]*e] or [k, [1]*e]
        probs = tf.reshape(
            probs,
            prefer_static.concat([
                prefer_static.shape(probs),
                prefer_static.ones([event_ndims], dtype=tf.int32)
            ], axis=0))

        # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X])
        mean_cond_var = tf.reduce_sum(probs * component_vars,
                                      axis=-1 - event_ndims)  # [B, E]
        mean = tf.reduce_sum(probs * component_means,
                             axis=-1 - event_ndims,
                             keepdims=True)  # [B, 1, E]
        var_cond_mean = tf.reduce_sum(
            probs * tf.math.squared_difference(component_means, mean),
            axis=-1 - event_ndims)  # [B, E]
        return mean_cond_var + var_cond_mean
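A hand-checkable instance of the law of total variance used above, via the public `tfd.MixtureSameFamily` API (toy two-component scalar mixture):

import tensorflow_probability as tfp
tfd = tfp.distributions

mix = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=[0.5, 0.5]),
    components_distribution=tfd.Normal(loc=[0., 2.], scale=[1., 1.]))
mix.variance()  # ==> 2.0 == E[Var(Y|X)] (1.0) + Var(E[Y|X]) (1.0)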
Example #10
def _init_momentum(initial_transformed_position):
    """Initialize momentum so trace_fn can be concatenated."""
    event_shape = ps.shape(initial_transformed_position)[-1]
    return dmma._make_momentum_distribution(  # pylint: disable=protected-access
        running_variance_parts=[ps.ones(event_shape)],
        state_parts=tf.nest.flatten(initial_transformed_position),
        batch_ndims=1)
Example #11
def _expand_dims_under_batch_dim(tensor, new_rank):
    """Adds size-1 dimensions below the first until `tensor` has `new_rank`."""
    ones = prefer_static.ones([new_rank - prefer_static.rank(tensor)],
                              dtype=tf.int32)
    shape = prefer_static.shape(tensor)
    new_shape = prefer_static.concat([shape[:1], ones, shape[1:]], axis=0)
    return tf.reshape(tensor, new_shape)
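A usage sketch of the helper above (assuming the snippet's own imports and definition are in scope):

t = tf.zeros([4, 5])
padded = _expand_dims_under_batch_dim(t, new_rank=4)
# padded.shape == [4, 1, 1, 5]: size-1 dims inserted just below the leading dim.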
Example #12
def _right_pad(x, final_rank):
    """Pads the shape of x to the right to be of rank final_rank.

  Expands the dims of `x` to the right such that its rank is equal to
  final_rank. For example, if `x` is of shape [1, 5, 7, 2] and `final_rank` is
  7, we return padded_x, which is of shape [1, 5, 7, 2, 1, 1, 1].

  Args:
    x: The tensor whose shape is to be padded.
    final_rank: Scalar int32 `Tensor` or Python `int`. The desired rank of x.

  Returns:
    padded_x: A tensor of rank final_rank.
  """
    padded_shape = ps.concat(
        [ps.shape(x),
         ps.ones(final_rank - ps.rank(x), dtype=tf.int32)],
        axis=0)
    static_padded_shape = None
    if tensorshape_util.is_fully_defined(x.shape) and isinstance(
            final_rank, int):
        static_padded_shape = tensorshape_util.as_list(x.shape)
        extra_dims = final_rank - len(static_padded_shape)
        static_padded_shape.extend([1] * extra_dims)

    padded_x = tf.reshape(x, static_padded_shape or padded_shape)
    return padded_x
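A usage sketch of `_right_pad` above, matching its docstring's example (assuming the snippet's imports and definition are in scope):

x = tf.zeros([1, 5, 7, 2])
padded_x = _right_pad(x, final_rank=7)
# padded_x.shape == [1, 5, 7, 2, 1, 1, 1]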
Example #13
def _init_momentum(initial_transformed_position):
  """Initialize momentum so trace_fn can be concatenated."""
  event_shape = ps.shape(initial_transformed_position)[-1]
  return preconditioning_utils.make_momentum_distribution(
      state_parts=tf.nest.flatten(initial_transformed_position),
      batch_ndims=1,
      running_variance_parts=[ps.ones(event_shape)])
Example #14
    def _transpose_and_reshape_result(self, x, sample_shape, event_shape=None):
        if event_shape is None:
            event_shape = self.event_shape_tensor()

        batch_shape = self.batch_shape_tensor()
        batch_rank = ps.rank_from_shape(batch_shape)

        underlying_batch_shape = self.distribution.batch_shape_tensor()
        underlying_batch_rank = ps.rank_from_shape(underlying_batch_shape)

        # Continuing the example from `_augment_sample_shape`, suppose we have:
        #   - sample shape of `[n]`,
        #   - underlying distribution batch shape of `[2, 1]`,
        #   - final broadcast batch shape of `[4, 2, 3]`.
        # and have drawn an `x` of shape `[n, 12, 2, 1] + event_shape`, which we
        # ultimately want to have shape `[n, 4, 2, 3] + event_shape`.

        # First, we reshape to expand out the batch elements:
        # `shape_with_doubled_batch == [n] + [4, 1, 3] + [1, 2, 1] + event_shape`,
        # where `[1, 2, 1]` is the fully-expanded underlying batch shape, and
        # `[4, 1, 3]` is the shape of the elements being added by broadcasting.
        underlying_bcast_shp = ps.concat([
            ps.ones([ps.maximum(batch_rank - underlying_batch_rank, 0)],
                    dtype=underlying_batch_shape.dtype),
            underlying_batch_shape,
        ], axis=0)
        is_dim_bcast = ps.not_equal(batch_shape, underlying_bcast_shp)
        x_with_doubled_batch = tf.reshape(
            x,
            ps.concat([
                sample_shape,
                ps.where(is_dim_bcast, batch_shape, 1),
                underlying_bcast_shp,
                event_shape,
            ], axis=0))

        # Next, construct the permutation that interleaves the batch dimensions,
        # resulting in samples with shape
        # `[n] + [4, 1] + [1, 2] + [3, 1] + event_shape`.
        # Note that each interleaved pair of batch dimensions contains exactly one
        # dim of size `1` and one of size `>= 1`.
        sample_ndims = ps.rank_from_shape(sample_shape)
        interleaved_batch_perm = ps.reshape(
            ps.stack([ps.range(batch_rank),
                      ps.range(batch_rank) + batch_rank], axis=-1),
            [-1])
        x_with_interleaved_batch = tf.transpose(
            x_with_doubled_batch,
            perm=ps.concat([
                ps.range(sample_ndims),
                sample_ndims + interleaved_batch_perm,
                sample_ndims + 2 * batch_rank +
                ps.range(ps.rank_from_shape(event_shape)),
            ], axis=0))

        # Final reshape to remove the spurious `1` dimensions.
        return tf.reshape(
            x_with_interleaved_batch,
            ps.concat([sample_shape, batch_shape, event_shape], axis=0))
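The `ps.stack`/`ps.reshape` idiom above builds the permutation that interleaves the two sets of batch axes; isolated with plain TF ops and a toy `batch_rank=3`, it yields `[0, 3, 1, 4, 2, 5]`:

import tensorflow as tf

batch_rank = 3
perm = tf.reshape(
    tf.stack([tf.range(batch_rank), tf.range(batch_rank) + batch_rank],
             axis=-1),
    [-1])
# perm == [0, 3, 1, 4, 2, 5]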
Example #15
def _slice_single_param(param, param_event_ndims, slices, dist_batch_shape):
  """Slices a single parameter of a distribution.

  Args:
    param: A `Tensor`, the original parameter to slice.
    param_event_ndims: `int` event parameterization rank for this parameter.
    slices: A `tuple` of normalized slices.
    dist_batch_shape: The distribution's batch shape `Tensor`.

  Returns:
    new_param: A `Tensor`, batch-sliced according to slices.
  """
  # Extend param shape with ones on the left to match dist_batch_shape.
  param_shape = ps.shape(param)
  insert_ones = ps.ones(
      [ps.size(dist_batch_shape) + param_event_ndims - ps.rank(param)],
      dtype=param_shape.dtype)
  new_param_shape = ps.concat([insert_ones, param_shape], axis=0)
  full_batch_param = tf.reshape(param, new_param_shape)
  param_slices = []
  # We separately track the batch axis from the parameter axis because we want
  # them to align for positive indexing, and be offset by param_event_ndims for
  # negative indexing.
  param_dim_idx = 0
  batch_dim_idx = 0
  for slc in slices:
    if slc is tf.newaxis:
      param_slices.append(slc)
      continue
    if slc is Ellipsis:
      if batch_dim_idx < 0:
        raise ValueError('Found multiple `...` in slices {}'.format(slices))
      param_slices.append(slc)
      # Switch over to negative indexing for the broadcast check.
      num_remaining_non_newaxis_slices = sum(
          [s is not tf.newaxis for s in slices[slices.index(Ellipsis) + 1:]])
      batch_dim_idx = -num_remaining_non_newaxis_slices
      param_dim_idx = batch_dim_idx - param_event_ndims
      continue
    # Find the batch dimension sizes for both parameter and distribution.
    param_dim_size = new_param_shape[param_dim_idx]
    batch_dim_size = dist_batch_shape[batch_dim_idx]
    is_broadcast = batch_dim_size > param_dim_size
    # Slices are denoted by start:stop:step.
    if isinstance(slc, slice):
      start, stop, step = slc.start, slc.stop, slc.step
      if start is not None:
        start = ps.where(is_broadcast, 0, start)
      if stop is not None:
        stop = ps.where(is_broadcast, 1, stop)
      if step is not None:
        step = ps.where(is_broadcast, 1, step)
      param_slices.append(slice(start, stop, step))
    else:  # int, or int Tensor, e.g. d[d.batch_shape_tensor()[0] // 2]
      param_slices.append(ps.where(is_broadcast, 0, slc))
    param_dim_idx += 1
    batch_dim_idx += 1
  param_slices.extend([ALL_SLICE] * param_event_ndims)
  return full_batch_param.__getitem__(tuple(param_slices))
Example #16
def left_justified_expand_dims_to(x, rank, name=None):
  """Right pads `x` with `rank - rank(x)` ones."""
  with tf.name_scope(name or 'left_justified_expand_dims_to'):
    expand_ndims = ps.maximum(rank - ps.rank(x), 0)
    expand_shape = ps.concat(
        [ps.shape(x),
         ps.ones(shape=[expand_ndims], dtype=tf.int32)],
        axis=0)
    return tf.reshape(x, expand_shape)
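A usage sketch of the helper above (assuming the snippet's imports and definition are in scope):

x = tf.zeros([5, 4])
left_justified_expand_dims_to(x, rank=4)   # result shape ==> [5, 4, 1, 1]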
Example #17
 def expand_right_dims(x, broadcast=False):
   """Expand x so it can bcast w/ tensors of output shape."""
   expanded_shape_left = ps.broadcast_shape(
       ps.shape(x)[:-1],
       ps.ones([ps.size(y_ref_shape_left)], dtype=tf.int32))
   expanded_shape = ps.concat(
       (expanded_shape_left, ps.shape(x)[-1:],
        ps.ones([ps.size(y_ref_shape_right)], dtype=tf.int32)),
       axis=0)
   x_expanded = tf.reshape(x, expanded_shape)
   if broadcast:
     broadcast_shape_left = ps.broadcast_shape(
         ps.shape(x)[:-1], y_ref_shape_left)
     broadcast_shape = ps.concat(
         (broadcast_shape_left, ps.shape(x)[-1:], y_ref_shape_right),
         axis=0)
     x_expanded = _broadcast_with(x_expanded, broadcast_shape)
   return x_expanded
Example #18
def _add_event_dims_to_mask(validity_mask, *, dist=None, event_ndims=None):
    validity_mask = tf.convert_to_tensor(validity_mask)
    if event_ndims is None:
        event_ndims = ps.rank_from_shape(dist.event_shape_tensor())
    return tf.reshape(
        validity_mask,
        ps.concat(
            [ps.shape(validity_mask),
             ps.ones(event_ndims, dtype=tf.int32)],
            axis=0))
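A usage sketch of the helper above (assuming the snippet's imports and definition are in scope): aligning a per-batch validity mask of shape [7] against a 2-dimensional event.

mask = tf.constant([True] * 7)
_add_event_dims_to_mask(mask, event_ndims=2)   # result shape ==> [7, 1, 1]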
Example #19
 def _log_prob(self, x, **kwargs):
     batch_ndims = ps.rank_from_shape(self.distribution.batch_shape_tensor,
                                      self.distribution.batch_shape)
     extra_sample_ndims = ps.rank_from_shape(self.sample_shape)
     event_ndims = ps.rank_from_shape(self.distribution.event_shape_tensor,
                                      self.distribution.event_shape)
     ndims = ps.rank(x)
     # (1) Expand x's dims.
     d = ndims - batch_ndims - extra_sample_ndims - event_ndims
     x = tf.reshape(x,
                    shape=ps.pad(ps.shape(x),
                                 paddings=[[ps.maximum(0, -d), 0]],
                                 constant_values=1))
     ndims = ps.rank(x)
     sample_ndims = ps.maximum(0, d)
     # (2) Transpose x's dims.
     sample_dims = ps.range(0, sample_ndims)
     batch_dims = ps.range(sample_ndims, sample_ndims + batch_ndims)
     extra_sample_dims = ps.range(
         sample_ndims + batch_ndims,
         sample_ndims + batch_ndims + extra_sample_ndims)
     event_dims = ps.range(sample_ndims + batch_ndims + extra_sample_ndims,
                           ndims)
     perm = ps.concat(
         [sample_dims, extra_sample_dims, batch_dims, event_dims], axis=0)
     x = tf.transpose(a=x, perm=perm)
     # (3) Compute x's log_prob.
     lp = self.distribution.log_prob(x, **kwargs)
     # (4) Ensure lp is fully broadcast in the sample dims, i.e. ensure lp has
     #     full sample shape in the sample axes, before we reduce.
      bcast_lp_shape = ps.broadcast_shape(
          ps.shape(lp),
          ps.concat([
              ps.ones([sample_ndims], tf.int32),
              ps.reshape(self.sample_shape, shape=[-1]),
              ps.ones([batch_ndims], tf.int32)
          ], axis=0))
     lp = tf.broadcast_to(lp, bcast_lp_shape)
     # (5) Make the final reduction in x.
     axis = ps.range(sample_ndims, sample_ndims + extra_sample_ndims)
     return tf.reduce_sum(lp, axis=axis)
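The transpose-and-reduce machinery above is what makes a sample-shape wrapper such as the public `tfd.Sample` sum base log-probs over its extra sample dims; a minimal check with the public API (illustrative only):

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

base = tfd.Normal(loc=0., scale=1.)
dist = tfd.Sample(base, sample_shape=[3])
x = tf.constant([0.5, -1.0, 2.0])
dist.log_prob(x)   # ==> tf.reduce_sum(base.log_prob(x))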
Example #20
def left_justified_expand_dims_to(x, rank, name=None):
    """Right pads `x` with `rank - rank(x)` ones."""
    with tf.name_scope(name or 'left_justified_expand_dims_to'):
        rank = tf.convert_to_tensor(rank, dtype=tf.int32)
        expand_ndims = prefer_static.maximum(rank - prefer_static.rank(x), 0)
        expand_shape = prefer_static.concat([
            prefer_static.shape(x),
            prefer_static.ones(shape=[expand_ndims], dtype=tf.int32)
        ], axis=0)
        return prefer_static.reshape(x, expand_shape)
Example #21
 def _broadcast_transition_probs(self, sample_and_batch_shape) -> tf.Tensor:
     transition_probs_shape = ps.shape(self.transition_probs_tree.branch_lengths)
     transition_probs_batch_shape = transition_probs_shape[:-3]
     additional_dims = (
         ps.shape(sample_and_batch_shape)[0]
         - ps.shape(transition_probs_batch_shape)[0]
     )
     new_shape = ps.concat(
         [ps.ones(additional_dims, dtype=tf.int32), transition_probs_shape], axis=0
     )
     return tf.reshape(self.transition_probs_tree.branch_lengths, new_shape)
Example #22
def _dummy_indices_like(indices):
    """Returns dummy indices ([0, 1, 2, ...]) with batch shape like `indices`."""
    indices_shape = ps.shape(indices)
    num_particles = indices_shape[0]
    return tf.broadcast_to(
        ps.reshape(
            ps.range(num_particles),
            ps.concat([[num_particles],
                       ps.ones([ps.rank_from_shape(indices_shape) - 1],
                               dtype=np.int32)],
                      axis=0)), indices_shape)
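A usage sketch of `_dummy_indices_like` above (assuming the snippet's imports and definition are in scope), with 4 particles and a batch dimension of 2:

indices = tf.zeros([4, 2], dtype=tf.int32)
_dummy_indices_like(indices)
# ==> [[0, 0], [1, 1], [2, 2], [3, 3]]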
Example #23
 def do_padding(observed_time_series_tensor):
     current_sample_shape = ps.shape(
         observed_time_series_tensor)[:-(model_batch_ndims + event_ndims)]
     current_batch_and_event_shape = ps.shape(
         observed_time_series_tensor)[-(model_batch_ndims + event_ndims):]
      return tf.reshape(
          tensor=observed_time_series_tensor,
          shape=ps.concat([
              current_sample_shape,
              ps.ones([chain_batch_ndims], dtype=tf.int32),
              current_batch_and_event_shape
          ], axis=0))
Example #24
    def even_odd_swaps(num_replica,
                       batch_shape=(),
                       step_count=None,
                       seed=None):
        """Make deterministic even_odd one time swaps."""
        if step_count is None:
            raise ValueError('`step_count` must be supplied. Found `None`.')
        del seed  # Unused for this function.
        with tf.name_scope(name or 'even_odd_swaps'):
            # Period is 1 / frequency, and we want period = Inf if frequency = 0.
            # safe_swap_period is the correct swap period in case swap_frequency > 0.
            # If swap_frequency == 0, safe_swap_period is set to 1 (to avoid integer
            # div by zero below). We will hard-set this case to "null swap."
            swap_freq = tf.convert_to_tensor(swap_frequency,
                                             name='swap_frequency')
            safe_swap_period = tf.cast(
                tf.where(swap_freq > 0,
                         tf.math.ceil(tf.math.reciprocal_no_nan(swap_freq)),
                         1),
                # Although period = 1 / frequency may have roundoff error, and result
                # in a period different than what the user intended, the
                # user will end up with a single integer period, and thus well defined
                # deterministic swaps.
                tf.int32,
            )

            # u selects parity.  E.g.,
            #  u==False ==> [1, 0, 3, 2, 4] even parity swaps
            #  u==True ==>  [0, 2, 1, 4, 3] odd parity swaps
            # If there are 2 replicas, then the "True" swaps are null
            # swaps...which would contradict the user provided `swap_frequency`.
            # So special case num_replica==2, forcing u==False in this case.
            u_shape = ps.concat(
                (ps.ones(1, dtype=tf.int32), ps.cast(batch_shape, tf.int32)),
                axis=0)
            u = tf.fill(u_shape,
                        tf.cast((step_count // safe_swap_period) % 2, tf.bool))
            u = tf.where(num_replica > 2, u, False)

            x = bu.left_justified_expand_dims_to(tf.range(num_replica,
                                                          dtype=tf.int64),
                                                 rank=ps.size(u_shape))
            y = tf.where(tf.equal(x % 2, tf.cast(u, dtype=tf.int64)), x + 1,
                         x - 1)
            y = tf.clip_by_value(y, 0, num_replica - 1)
            # TODO(b/142689785): Consider using tf.cond and returning an empty list
            # then in REMC consider using a tf.cond for short-circuiting.
            return tf.where(
                (tf.cast(step_count % safe_swap_period, tf.bool)
                 | tf.math.equal(swap_freq, 0)),
                x,  # Don't swap
                y,  # Swap
            )
Example #25
  def _mean(self):
    probs = self.mixture_distribution.probs_parameter()  # [B, k] or [k]
    component_means = self.components_distribution.mean()  # [B, k, E]
    event_ndims = self._event_ndims()

    # reshape probs to [B, k, [1]*e] or [k, [1]*e]
    probs = tf.reshape(probs, ps.concat([
        ps.shape(probs),
        ps.ones([event_ndims], dtype=tf.int32)
    ], axis=0))

    return tf.reduce_sum(probs * component_means,
                         axis=-1 - event_ndims)  # [B, E]
Example #26
 def weighted_reduce_sum(x, axis=0):
     """Weighted sum over an axis of `x`."""
     # Extend the weights to broadcast over any event dimensions of `x`.
     # This assumes that `weights` and `x` have the same sample and batch
     # dimensions, e.g., that they come from the same `sample_and_log_prob` call.
     event_ndims = ps.rank(x) - ps.rank(weights)
     aligned_weights = tf.reshape(
         weights,
         ps.concat(
             [ps.shape(weights),
              ps.ones([event_ndims], dtype=tf.int32)],
             axis=0))
     return tf.reduce_sum(aligned_weights * tf.cast(x, weights.dtype),
                          axis=axis)
Example #27
 def _fn(self):
   """Implements summary statistic, eg, mean, stddev, mode."""
   x = getattr(self.distribution, attr)()
   shape = prefer_static.concat([
       self.distribution.batch_shape_tensor(),
       prefer_static.ones(prefer_static.rank_from_shape(self.sample_shape),
                          dtype=self.sample_shape.dtype),
       self.distribution.event_shape_tensor(),
   ], axis=0)
   x = tf.reshape(x, shape=shape)
   shape = prefer_static.concat([
       self.distribution.batch_shape_tensor(),
       self.sample_shape,
       self.distribution.event_shape_tensor(),
   ], axis=0)
   return tf.broadcast_to(x, shape)
Example #28
 def _fn(self, **kwargs):
   """Implements summary statistic, eg, mean, stddev, mode."""
   sample_shape = ps.reshape(self.sample_shape, shape=[-1])
   x = getattr(self.distribution, attr)(**kwargs)
   shape = ps.concat([
       self.distribution.batch_shape_tensor(),
       ps.ones(ps.rank_from_shape(sample_shape), dtype=sample_shape.dtype),
       self.distribution.event_shape_tensor(),
   ], axis=0)
   x = tf.reshape(x, shape=shape)
   shape = ps.concat([
       self.distribution.batch_shape_tensor(),
       sample_shape,
       self.distribution.event_shape_tensor(),
   ], axis=0)
   return tf.broadcast_to(x, shape)
Example #29
    def test_dynamic_shape(self):
        x = tf.Variable(ps.ones([7, 3]), shape=[7, None])
        self.evaluate(x.initializer)

        # Check that the shape is actually `None`.
        if not tf.executing_eagerly():
            last_shape = x.shape[-1]
            if last_shape is not None:  # This is a `tf.Dimension` in tf1.
                last_shape = last_shape.value
            self.assertIsNone(last_shape)
        dynamic_dist = tfd_e.MultivariateNormalPrecisionFactorLinearOperator(
            precision_factor=tf.linalg.LinearOperatorDiag(tf.ones_like(x)))
        static_dist = tfd_e.MultivariateNormalPrecisionFactorLinearOperator(
            precision_factor=tf.linalg.LinearOperatorDiag(tf.ones([7, 3])))
        in_ = tf.zeros([7, 3])
        self.assertAllClose(self.evaluate(dynamic_dist.log_prob(in_)),
                            static_dist.log_prob(in_))
Example #30
    def _mean(self):
        mixture_distribution, components_distribution = (
            self._get_distributions_with_broadcast_batch_shape())
        probs = mixture_distribution.probs_parameter()  # [B, k]
        component_means = components_distribution.mean()  # [B, k, E]
        event_ndims = self._event_ndims()

        # reshape probs to [B, k, [1]*e]
        probs = tf.reshape(
            probs,
            ps.concat(
                [ps.shape(probs),
                 ps.ones([event_ndims], dtype=tf.int32)],
                axis=0))

        return tf.reduce_sum(probs * component_means,
                             axis=-1 - event_ndims)  # [B, E]
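A hand-checkable instance of the probability-weighted mean above, via the public `tfd.MixtureSameFamily` API (toy two-component scalar mixture):

import tensorflow_probability as tfp
tfd = tfp.distributions

mix = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=[0.25, 0.75]),
    components_distribution=tfd.Normal(loc=[0., 4.], scale=[1., 1.]))
mix.mean()   # ==> 0.25 * 0. + 0.75 * 4. == 3.0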