def moments_of_masked_time_series(time_series_tensor, broadcast_mask):
    """Compute mean and variance, accounting for a mask.

  Args:
    time_series_tensor: float `Tensor` time series of shape
      `concat([batch_shape, [num_timesteps]])`.
    broadcast_mask: bool `Tensor` of the same shape as `time_series_tensor`.
  Returns:
    mean: float `Tensor` of shape `batch_shape`.
    variance: float `Tensor` of shape `batch_shape`.
  """
    num_unmasked_entries = ps.cast(
        ps.reduce_sum(ps.cast(~broadcast_mask, np.int32), axis=-1),
        time_series_tensor.dtype)

    # Manually compute mean and variance, excluding masked entries.
    mean = (tf.reduce_sum(tf.where(
        broadcast_mask, tf.zeros([], dtype=time_series_tensor.dtype),
        time_series_tensor),
                          axis=-1) / num_unmasked_entries)
    variance = (tf.reduce_sum(tf.where(
        broadcast_mask, tf.zeros([], dtype=time_series_tensor.dtype),
        (time_series_tensor - mean[..., tf.newaxis])**2),
                              axis=-1) / num_unmasked_entries)
    return mean, variance
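# A minimal usage sketch (plain TF/NumPy rather than the module's `ps` helpers):
# the masked mean/variance computed this way agrees with `numpy.ma` on a toy series.
import numpy as np
import tensorflow as tf

series = tf.constant([[1., 2., 100.], [4., 5., 6.]])
mask = tf.constant([[False, False, True], [False, False, False]])

num_unmasked = tf.cast(tf.reduce_sum(tf.cast(~mask, tf.int32), axis=-1),
                       series.dtype)
masked = tf.where(mask, tf.zeros([], series.dtype), series)
mean = tf.reduce_sum(masked, axis=-1) / num_unmasked
sq_dev = tf.where(mask, tf.zeros([], series.dtype),
                  (series - mean[..., tf.newaxis])**2)
variance = tf.reduce_sum(sq_dev, axis=-1) / num_unmasked

ref = np.ma.masked_array(series.numpy(), mask=mask.numpy())
np.testing.assert_allclose(mean.numpy(), ref.mean(axis=-1))      # [1.5, 5.0]
np.testing.assert_allclose(variance.numpy(), ref.var(axis=-1))   # [0.25, 2/3]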
Example #2
def _axis_size(x, axis=None):
    """Get number of elements of `x` in `axis`, as type `x.dtype`."""
    if axis is None:
        return prefer_static.cast(prefer_static.size(x), x.dtype)
    return prefer_static.cast(
        prefer_static.reduce_prod(
            prefer_static.gather(prefer_static.shape(x), axis)), x.dtype)
Example #3
def iid_sample(sample_fn, sample_shape):
  """Lift a sampling function to one that draws multiple iid samples.

  Args:
    sample_fn: Python `callable` that returns a (possibly nested) structure of
      `Tensor`s. May optionally take a `seed` named arg: if so, any `int`
      seeds (for stateful samplers) are passed through directly, while any
      pair-of-`int` seeds (for stateless samplers) are split into independent
      seeds for each sample.
    sample_shape: `int` `Tensor` shape of iid samples to draw.
  Returns:
    iid_sample_fn: Python `callable` taking the same arguments as `sample_fn`
      and returning iid samples. Each returned `Tensor` will have shape
      `concat([sample_shape, shape_of_original_returned_tensor])`.
  """
  sample_shape = distribution_util.expand_to_vector(
      ps.cast(sample_shape, np.int32), tensor_name='sample_shape')
  n = ps.cast(ps.reduce_prod(sample_shape), dtype=np.int32)

  def unflatten(x):
    unflattened_shape = ps.cast(
        ps.concat([sample_shape, ps.shape(x)[1:]], axis=0),
        dtype=np.int32)
    return tf.reshape(x, unflattened_shape)

  def iid_sample_fn(*args, **kwargs):
    """Draws iid samples from `fn`."""

    with tf.name_scope('iid_sample_fn'):

      seed = kwargs.pop('seed', None)
      if samplers.is_stateful_seed(seed):
        kwargs = dict(kwargs, seed=SeedStream(seed, salt='iid_sample')())
        def pfor_loop_body(_):
          with tf.name_scope('iid_sample_fn_stateful_body'):
            return sample_fn(*args, **kwargs)
      else:
        # If a stateless seed arg is passed, split it into `n` different
        # stateless seeds, so that we don't just get a bunch of copies of the
        # same sample.
        if not JAX_MODE:
          warnings.warn(
              'Saw Tensor seed {}, implying stateless sampling. Autovectorized '
              'functions that use stateless sampling may be quite slow because '
              'the current implementation falls back to an explicit loop. This '
              'will be fixed in the future. For now, you will likely see '
              'better performance from stateful sampling, which you can invoke '
              'by passing a Python `int` seed.'.format(seed))
        seed = samplers.split_seed(seed, n=n, salt='iid_sample_stateless')
        def pfor_loop_body(i):
          with tf.name_scope('iid_sample_fn_stateless_body'):
            return sample_fn(*args, seed=tf.gather(seed, i), **kwargs)

      draws = parallel_for.pfor(pfor_loop_body, n)
      return tf.nest.map_structure(unflatten, draws, expand_composites=True)

  return iid_sample_fn
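# A sketch of the idea behind the stateless-seed path above, written with plain
# TF 2.x ops (`stateless_split` and `map_fn` stand in for the `samplers`/`pfor`
# machinery): split one seed into `n` independent seeds, draw one sample per
# seed, then reshape the leading axis back to `sample_shape`.
import tensorflow as tf

sample_shape = [2, 3]
n = 6
base_seed = tf.constant([1, 2], dtype=tf.int32)
seeds = tf.random.experimental.stateless_split(base_seed, num=n)     # [n, 2]

def sample_fn(seed):
  return tf.random.stateless_normal([4], seed=seed)

draws = tf.map_fn(sample_fn, seeds, fn_output_signature=tf.float32)  # [n, 4]
iid = tf.reshape(draws, sample_shape + [4])                          # [2, 3, 4]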
Example #4
def iid_sample(sample_fn, sample_shape):
    """Lift a sampling function to one that draws multiple iid samples.

  Args:
    sample_fn: Python `callable` that returns a (possibly nested) structure of
      `Tensor`s. May optionally take a `seed` named arg: if so, any `int`
      seeds (for stateful samplers) are passed through directly, while any
      pair-of-`int` seeds (for stateless samplers) are split into independent
      seeds for each sample.
    sample_shape: `int` `Tensor` shape of iid samples to draw.
  Returns:
    iid_sample_fn: Python `callable` taking the same arguments as `sample_fn`
      and returning iid samples. Each returned `Tensor` will have shape
      `concat([sample_shape, shape_of_original_returned_tensor])`.
  """
    sample_shape = distribution_util.expand_to_vector(
        prefer_static.cast(sample_shape, np.int32), tensor_name='sample_shape')
    n = prefer_static.cast(prefer_static.reduce_prod(sample_shape),
                           dtype=np.int32)

    def unflatten(x):
        unflattened_shape = prefer_static.cast(prefer_static.concat(
            [sample_shape, prefer_static.shape(x)[1:]], axis=0),
                                               dtype=np.int32)
        return tf.reshape(x, unflattened_shape)

    def iid_sample_fn(*args, **kwargs):
        """Draws iid samples from `fn`."""

        pfor_loop_body = lambda _: sample_fn(*args, **kwargs)

        seed = kwargs.pop('seed', None)
        try:  # Assume that `seed` is a valid stateful seed (Python `int`).
            kwargs = dict(kwargs, seed=SeedStream(seed, salt='iid_sample')())
            pfor_loop_body = lambda _: sample_fn(*args, **kwargs)
        except TypeError as e:
            # If a stateless seed arg is passed, split it into `n` different stateless
            # seeds, so that we don't just get a bunch of copies of the same sample.
            if TENSOR_SEED_MSG_PREFIX not in str(e):
                raise
            warnings.warn(
                'Saw non-`int` seed {}, implying stateless sampling. '
                'Autovectorized functions that use stateless sampling '
                'may be quite slow because the current implementation '
                'falls back to an explicit loop. This will be fixed in the '
                'future. For now, you will likely see better performance '
                'from stateful sampling, which you can invoke by passing a '
                'traditional Python `int` seed.'.format(seed))
            seed = samplers.split_seed(seed, n=n, salt='iid_sample_stateless')
            pfor_loop_body = (
                lambda i: sample_fn(*args, seed=tf.gather(seed, i), **kwargs))

        draws = parallel_for.pfor(pfor_loop_body, n)
        return tf.nest.map_structure(unflatten, draws, expand_composites=True)

    return iid_sample_fn
Example #5
def _summarize_fans(fan_in, fan_out, mode, dtype):
    """Combines `fan_in`, `fan_out` per specified `mode`."""
    fan_in = prefer_static.cast(fan_in, dtype)
    fan_out = prefer_static.cast(fan_out, dtype)
    mode = str(mode).lower()
    if mode == 'fan_in':
        return fan_in
    elif mode == 'fan_out':
        return fan_out
    elif mode == 'fan_avg':
        return (fan_in + fan_out) / 2.
    raise ValueError('Unrecognized mode: "{}".'.format(mode))
Example #6
 def _calculate_batch_shape(self):
   """Computes fully defined batch shape for the new distribution."""
   all_batch_shapes = [d.batch_shape.as_list()
                       if tensorshape_util.is_fully_defined(d.batch_shape)
                       else d.batch_shape_tensor() for d in self.distributions]
   original_shape = ps.stack(all_batch_shapes, axis=0)
   index_mask = ps.cast(
       ps.one_hot(self._axis, ps.shape(original_shape)[1]),
       dtype=tf.bool)
   new_concat_dim = ps.cast(
       ps.reduce_sum(original_shape, axis=0)[self._axis], dtype=tf.int32)
   return ps.where(index_mask, new_concat_dim,
                   ps.reduce_max(original_shape, axis=0))
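# Worked example of the shape rule above (NumPy sketch): concatenating
# distributions with batch shapes [2, 3] and [4, 3] along axis 0 sums the sizes
# on the concatenation axis and takes the maximum elsewhere, giving [6, 3].
import numpy as np

all_batch_shapes = np.stack([[2, 3], [4, 3]], axis=0)
axis = 0
index_mask = np.eye(all_batch_shapes.shape[1], dtype=bool)[axis]  # [True, False]
new_concat_dim = all_batch_shapes.sum(axis=0)[axis]               # 6
print(np.where(index_mask, new_concat_dim, all_batch_shapes.max(axis=0)))  # [6 3]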
Example #7
def _validate_elem_length(max_num_levels, elems_flat):
  """Checks that elems all have the same length, and returns that length."""
  assertions = []

  elem_length = prefer_static.shape(elems_flat[0])[0]

  # The default size limit will overflow a 32-bit int, so make sure we're
  # using 64-bit.
  size_limit = 2**(prefer_static.cast(max_num_levels, np.int64) + 1)
  enough_levels = prefer_static.less(
      prefer_static.cast(elem_length, np.int64), size_limit)
  enough_levels_ = tf.get_static_value(enough_levels)
  if enough_levels_ is None:
    assertions.append(
        tf.debugging.assert_equal(
            enough_levels, True,
            message='Input `Tensor`s must have first axis dimension less than'
            ' `2**(max_num_levels + 1)`'
            ' (saw: {} which is not less than 2**{} == {})'.format(
                elem_length,
                max_num_levels,
                size_limit)))
  elif not enough_levels_:
    raise ValueError(
        'Input `Tensor`s must have first axis dimension less than'
        ' `2**(max_num_levels + 1)`'
        ' (saw: {} which is not less than 2**{} == {})'.format(
            elem_length,
            max_num_levels,
            size_limit))

  is_consistent = prefer_static.reduce_all([
      prefer_static.equal(
          prefer_static.shape(elem)[0], elem_length)
      for elem in elems_flat[1:]])

  is_consistent_ = tf.get_static_value(is_consistent)
  if is_consistent_ is None:
    assertions.append(
        tf.debugging.assert_equal(
            is_consistent, True,
            message='Input `Tensor`s must have the same first dimension.'
            ' (saw: {})'.format([elem.shape for elem in elems_flat])))
  elif not is_consistent_:
    raise ValueError(
        'Input `Tensor`s must have the same first dimension.'
        ' (saw: {})'.format([elem.shape for elem in elems_flat]))
  return elem_length, assertions
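# A minimal sketch of the static-or-runtime validation pattern used above
# (assuming TF 2.x): if a predicate's value is known at trace time, fail eagerly
# with a Python exception; otherwise return a runtime assertion op.
import tensorflow as tf

def check_positive(x):
  predicate = x > 0
  predicate_ = tf.get_static_value(predicate)
  assertions = []
  if predicate_ is None:        # Only known at runtime.
    assertions.append(tf.debugging.assert_equal(
        predicate, True, message='`x` must be positive.'))
  elif not predicate_:          # Statically known to be violated.
    raise ValueError('`x` must be positive (saw: {}).'.format(
        tf.get_static_value(x)))
  return assertions

check_positive(tf.constant(3.))    # Static check passes; returns no assertions.
# check_positive(tf.constant(-1.)) would raise ValueError immediately.

@tf.function(input_signature=[tf.TensorSpec([], tf.float32)])
def f(x):
  with tf.control_dependencies(check_positive(x)):  # Value unknown: runtime check.
    return tf.identity(x)

f(tf.constant(2.))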
Example #8
def expand_dims(x, axis, name=None):
    """Like `tf.expand_dims` but accepts a vector of axes to expand."""
    with tf.name_scope(name or 'expand_dims'):
        x = tf.convert_to_tensor(x, name='x')
        axis = tf.convert_to_tensor(axis, dtype_hint=tf.int32, name='axis')
        nx = prefer_static.rank(x)
        na = prefer_static.size(axis)
        is_neg_axis = axis < 0
        k = prefer_static.reduce_sum(
            prefer_static.cast(is_neg_axis, axis.dtype))
        axis = prefer_static.where(is_neg_axis, axis + nx, axis)
        axis = prefer_static.sort(axis)
        axis_neg, axis_pos = prefer_static.split(axis, [k, -1])
        idx = prefer_static.argsort(prefer_static.concat([
            axis_pos,
            prefer_static.range(nx),
            axis_neg,
        ],
                                                         axis=0),
                                    stable=True)
        shape = prefer_static.pad(prefer_static.shape(x),
                                  paddings=[[na - k, k]],
                                  constant_values=1)
        shape = prefer_static.gather(shape, idx)
        return tf.reshape(x, shape)
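# Usage sketch (assuming the module's `tf`/`prefer_static` imports are in place):
# unlike `tf.expand_dims`, the helper above accepts a vector of axes, similar to
# `np.expand_dims` with a tuple of axes.
import numpy as np

x = tf.zeros([2, 3])
y = expand_dims(x, axis=[1, -1])
print(y.shape)                                           # (2, 1, 3, 1)
print(np.expand_dims(np.zeros([2, 3]), (1, -1)).shape)   # (2, 1, 3, 1)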
Example #9
 def _reshape_part(part):
   part = tf.cast(part, dtype)
   new_shape = ps.concat(
       [batch_shape, [-1]],
       axis=-1,
   )
   return tf.reshape(part, ps.cast(new_shape, tf.int32))
Example #10
def _update_loop_variables(step, current_step_results,
                           accumulated_traced_results, trace_fn,
                           step_indices_to_trace, num_steps_traced):
    """Update the loop state to reflect a step of filtering."""

    # Write particles, indices, and likelihoods to their respective arrays.
    trace_this_step = True
    if step_indices_to_trace is not None:
        trace_this_step = ps.equal(
            step_indices_to_trace[ps.minimum(
                num_steps_traced,
                ps.cast(ps.size0(step_indices_to_trace) - 1, dtype=np.int32))],
            step)
    num_steps_traced, accumulated_traced_results = ps.cond(
        trace_this_step,
        lambda: (
            num_steps_traced + 1,  # pylint: disable=g-long-lambda
            tf.nest.map_structure(lambda x, y: x.write(num_steps_traced, y),
                                  accumulated_traced_results,
                                  trace_fn(current_step_results))),
        lambda: (num_steps_traced, accumulated_traced_results))

    return ParticleFilterLoopVariables(
        step=step + 1,
        previous_step_results=current_step_results,
        accumulated_traced_results=accumulated_traced_results,
        num_steps_traced=num_steps_traced)
Example #11
 def _reshape_part(part, dtype, event_shape):
   part = tf.cast(part, dtype)
   static_rank = tf.get_static_value(ps.rank_from_shape(event_shape))
   if static_rank == 1:
     return part
   new_shape = ps.concat([ps.shape(part)[:-1], event_shape], axis=-1)
   return tf.reshape(part, ps.cast(new_shape, tf.int32))
Example #12
def _im2row_index(input_shape,
                  block_shape,
                  slice_step=(1, 1),
                  data_format='NHWC',
                  padding='VALID',
                  dtype=tf.int64,
                  name=None):
    """Computes indexes into a flattened image for building `im2col`."""
    with tf.name_scope(name or 'im2row_index'):
        # 1) Process input arguments.
        batch_shape, s3, s2, s1 = prefer_static.split(
            prefer_static.cast(input_shape, tf.int32),
            num_or_size_splits=[-1, 1, 1, 1])
        fh, fw = _split_pair(block_shape)
        sh, sw = _split_pair(slice_step)
        data_format = _validate_data_format(data_format)
        padding = _validate_padding(padding)

        # 2) Assemble all block start positions as indexes into the flattened image.
        if data_format == 'NHWC':
            h, w, c = s3[0], s2[0], s1[0]
            # start_idx.shape = [fh, fw, c]
            start_idx = _cartesian_add([
                prefer_static.range(c * w * fh, delta=c * w, dtype=dtype),
                prefer_static.range(c * fw, delta=c, dtype=dtype),
                prefer_static.range(c, delta=1, dtype=dtype),
            ])
        elif data_format == 'NCHW':
            c, h, w = s3[0], s2[0], s1[0]
            # start_idx.shape = [c, fh, fw]
            start_idx = _cartesian_add([
                prefer_static.range(w * h * c, delta=w * h, dtype=dtype),
                prefer_static.range(w * fh, delta=w, dtype=dtype),
                prefer_static.range(fw, delta=1, dtype=dtype),
            ])
        else:
            assert False  # Can't be here.

        # 3) Assemble all block offsets (into flattened image).
        if padding == 'VALID':
            eh = h - fh + 1  # extent height
            ew = w - fw + 1  # extent width
            # offset_idx.shape = [eh // sh, ew // sw]
            offset_idx = _cartesian_add([
                prefer_static.range(w * eh, delta=w * sh, dtype=dtype),
                prefer_static.range(ew, delta=sw, dtype=dtype),
            ])
            if data_format == 'NHWC':
                offset_idx *= c
            oh = eh // sh  # out height
            ow = ew // sw  # out width
        else:
            assert False  # Can't be here.

        # 4) Combine block start/offset pairs.
        # shape = [(eh // sh) * (ew // sw), fh * fw * c]
        idx = _cartesian_add([offset_idx, start_idx])
        new_shape = [oh, ow, fh * fw * c]
        new_shape = prefer_static.concat([batch_shape, new_shape], axis=0)
        return idx, new_shape
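# Worked example (NumPy sketch; NHWC, VALID padding, c=1, 2x2 blocks, stride 1):
# the index tensor is the sum of per-block-element offsets (`start_idx`) and
# per-output-position offsets (`offset_idx`) into the flattened image.
import numpy as np

h, w, c = 3, 4, 1
fh, fw = 2, 2
sh, sw = 1, 1
eh, ew = h - fh + 1, w - fw + 1

start_idx = (np.arange(fh)[:, None, None] * c * w +
             np.arange(fw)[None, :, None] * c +
             np.arange(c)[None, None, :]).reshape(-1)         # [fh * fw * c]
offset_idx = (np.arange(eh)[:, None] * w * sh +
              np.arange(ew)[None, :] * sw).reshape(-1) * c     # [eh * ew]
idx = offset_idx[:, None] + start_idx[None, :]                 # [eh*ew, fh*fw*c]

image = np.arange(h * w * c).reshape(h, w, c)
rows = image.reshape(-1)[idx]                                  # im2row matrix
# The row for output position (1, 2) is exactly the 2x2 patch at (1:3, 2:4).
np.testing.assert_array_equal(rows[1 * ew + 2], image[1:3, 2:4, :].reshape(-1))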
Example #13
    def adjacent_swaps(num_replica,
                       batch_shape=(),
                       step_count=None,
                       seed=None):
        """Make random shuffle using only one time swaps."""
        del step_count  # Unused for this function.
        with tf.name_scope(name or 'adjacent_swaps'):
            parity_seed, proposal_seed = samplers.split_seed(seed)
            # u selects parity.  E.g.,
            #  u==False ==> [1, 0, 3, 2, 4] even parity swaps
            #  u==True ==>  [0, 2, 1, 4, 3] odd parity swaps
            # If there are only 2 replicas, then the "True" swaps are null
            # swaps...which would contradict the user provided `prob_swap`.
            # So special case num_replica==2, forcing u==False in this case.
            u_shape = ps.concat(
                (ps.ones(1, dtype=tf.int32), ps.cast(batch_shape, tf.int32)),
                axis=0)
            u = samplers.uniform(u_shape, seed=parity_seed) < 0.5
            u = tf.where(num_replica > 2, u, False)

            x = bu.left_justified_expand_dims_to(ps.range(num_replica,
                                                          dtype=tf.int64),
                                                 rank=ps.size(u_shape))
            y = tf.where(tf.equal(x % 2, tf.cast(u, dtype=tf.int64)), x + 1,
                         x - 1)
            y = tf.clip_by_value(y, 0, num_replica - 1)
            # TODO(b/142689785): Consider using tf.cond and returning an empty list
            # then in REMC consider using a tf.cond for short-circuiting.
            return tf.where(
                samplers.uniform(batch_shape, seed=proposal_seed) < prob_swap,
                y, x)
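# Worked example (NumPy sketch) of the parity logic above: with 5 replicas,
# even-parity swaps give [1, 0, 3, 2, 4] and odd-parity swaps give [0, 2, 1, 4, 3].
import numpy as np

num_replica = 5
x = np.arange(num_replica)
for u in (0, 1):
  y = np.where(x % 2 == u, x + 1, x - 1)
  y = np.clip(y, 0, num_replica - 1)
  print(u, y)   # 0 -> [1 0 3 2 4],  1 -> [0 2 1 4 3]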
Example #14
 def _multi_gamma_sequence(self, a, p, name='multi_gamma_sequence'):
     """Creates sequence used in multivariate (di)gamma; shape = shape(a)+[p]."""
     with tf.name_scope(name):
         # Linspace only takes scalars, so we'll add in the offset afterwards.
         seq = ps.linspace(tf.constant(0., dtype=self.dtype), 0.5 - 0.5 * p,
                           ps.cast(p, tf.int32))
         return seq + a[..., tf.newaxis]
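# Worked example: for p = 3 and a = 2.5 the sequence is a + [0, -0.5, -1.0],
# i.e. the arguments a + (1 - j) / 2, j = 1..p, that appear in the multivariate
# (di)gamma function. A NumPy sketch:
import numpy as np

a, p = 2.5, 3
seq = np.linspace(0., 0.5 - 0.5 * p, p)   # [ 0. , -0.5, -1. ]
print(a + seq)                            # [2.5, 2. , 1.5]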
Example #15
 def expand_dims_(x):
     """Implementation of `expand_dims`."""
     with tf.name_scope(name or 'expand_dims'):
         x = tf.convert_to_tensor(x, name='x')
         new_axis = tf.convert_to_tensor(axis,
                                         dtype_hint=tf.int32,
                                         name='axis')
         nx = prefer_static.rank(x)
         na = prefer_static.size(new_axis)
         is_neg_axis = new_axis < 0
         k = prefer_static.reduce_sum(
             prefer_static.cast(is_neg_axis, new_axis.dtype))
         new_axis = prefer_static.where(is_neg_axis, new_axis + nx,
                                        new_axis)
         new_axis = prefer_static.sort(new_axis)
         axis_neg, axis_pos = prefer_static.split(new_axis, [k, -1])
         idx = prefer_static.argsort(prefer_static.concat([
             axis_pos,
             prefer_static.range(nx),
             axis_neg,
         ],
                                                          axis=0),
                                     stable=True)
         shape = prefer_static.pad(prefer_static.shape(x),
                                   paddings=[[na - k, k]],
                                   constant_values=1)
         shape = prefer_static.gather(shape, idx)
         return tf.reshape(x, shape)
Example #16
    def _apply_with_distance(self,
                             x1,
                             x2,
                             pairwise_square_distance,
                             example_ndims=0):
        exponent = -2. * pairwise_square_distance
        locs = util.pad_shape_with_ones(self.locs,
                                        ndims=example_ndims,
                                        start=-(self.feature_ndims + 1))
        cos_coeffs = tf.math.cos(2 * np.pi * (x1 - x2) * locs)
        feature_ndims = ps.cast(self.feature_ndims, ps.rank(cos_coeffs).dtype)
        reduction_axes = ps.range(
            ps.rank(cos_coeffs) - feature_ndims, ps.rank(cos_coeffs))
        coeff_sign = tf.math.reduce_prod(tf.math.sign(cos_coeffs),
                                         axis=reduction_axes)
        log_cos_coeffs = tf.math.reduce_sum(tf.math.log(
            tf.math.abs(cos_coeffs)),
                                            axis=reduction_axes)

        logits = util.pad_shape_with_ones(self.logits,
                                          ndims=example_ndims,
                                          start=-1)

        log_result, sign = tfp_math.reduce_weighted_logsumexp(
            exponent + log_cos_coeffs + logits,
            coeff_sign,
            return_sign=True,
            axis=-(example_ndims + 1))

        return sign * tf.math.exp(log_result)
Example #17
    def _joint_sample_n(self, n, seed=None):
        """Draw a joint sample from the prior over latents and observations.

    This sampler is specific to LocalLevel models and is faster than the
    generic LinearGaussianStateSpaceModel implementation.

    Args:
      n: `int` `Tensor` number of samples to draw.
      seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
    Returns:
      latents: `float` `Tensor` of shape `concat([[n], self.batch_shape,
        [self.num_timesteps, self.latent_size]], axis=0)` representing samples
        of latent trajectories.
      observations: `float` `Tensor` of shape `concat([[n], self.batch_shape,
        [self.num_timesteps, self.observation_size]], axis=0)` representing
        samples of observed series generated from the sampled `latents`.
    """
        with tf.name_scope('joint_sample_n'):
            (initial_level_seed, level_jumps_seed,
             prior_observation_seed) = samplers.split_seed(
                 seed, n=3, salt='LocalLevelStateSpaceModel_joint_sample_n')

            if self.batch_shape.is_fully_defined():
                batch_shape = self.batch_shape
            else:
                batch_shape = self.batch_shape_tensor()
            sample_and_batch_shape = ps.cast(
                ps.concat([[n], batch_shape], axis=0), tf.int32)

            # Sample the initial timestep from the prior.  Since we want
            # this sample to have full batch shape (not just the batch shape
            # of the self.initial_state_prior object which might in general be
            # smaller), we augment the sample shape to include whatever
            # extra batch dimensions are required.
            initial_level = self.initial_state_prior.sample(
                linear_gaussian_ssm._augment_sample_shape(  # pylint: disable=protected-access
                    self.initial_state_prior, sample_and_batch_shape,
                    self.validate_args),
                seed=initial_level_seed)

            # Sample the latent random walk and observed noise, more efficiently than
            # the generic loop in `LinearGaussianStateSpaceModel`.
            level_jumps = self.level_scale[..., tf.newaxis] * samplers.normal(
                ps.concat([sample_and_batch_shape, [self.num_timesteps - 1]],
                          axis=0),
                dtype=self.dtype,
                seed=level_jumps_seed)
            prior_level_sample = tf.cumsum(tf.concat(
                [initial_level, level_jumps], axis=-1),
                                           axis=-1)
            prior_observation_sample = prior_level_sample + (  # Sample noise.
                self.observation_noise_scale[..., tf.newaxis] *
                samplers.normal(ps.shape(prior_level_sample),
                                dtype=self.dtype,
                                seed=prior_observation_seed))

            return (prior_level_sample[..., tf.newaxis],
                    prior_observation_sample[..., tf.newaxis])
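# A minimal sketch (plain TF, no batch dimensions beyond the sample axis) of the
# fast path above: a local level model is a Gaussian random walk, so the latent
# path is the cumulative sum of an initial level and iid jumps, and observations
# add independent noise. The scales below are illustrative.
import tensorflow as tf

n, num_timesteps = 4, 100
level_scale, observation_noise_scale = 0.1, 0.5

initial_level = tf.random.normal([n, 1])
level_jumps = level_scale * tf.random.normal([n, num_timesteps - 1])
latents = tf.cumsum(tf.concat([initial_level, level_jumps], axis=-1), axis=-1)
observations = latents + observation_noise_scale * tf.random.normal(
    tf.shape(latents))
print(latents.shape, observations.shape)   # (4, 100) (4, 100)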
Example #18
  def __init__(
      self,
      input_size,
      output_size,
      # Weights
      init_kernel_fn=None,  # tfp.experimental.nn.initializers.glorot_uniform()
      init_bias_fn=None,    # tf.initializers.zeros()
      make_kernel_bias_fn=nn_util_lib.make_kernel_bias,
      dtype=tf.float32,
      batch_shape=(),
      # Misc
      activation_fn=None,
      name=None):
    """Constructs layer.

    Args:
      input_size: ...
      output_size: ...
      init_kernel_fn: ...
        Default value: `None` (i.e.,
        `tfp.experimental.nn.initializers.glorot_uniform()`).
      init_bias_fn: ...
        Default value: `None` (i.e., `tf.initializers.zeros()`).
      make_kernel_bias_fn: ...
        Default value: `tfp.experimental.nn.util.make_kernel_bias`.
      dtype: ...
        Default value: `tf.float32`.
      batch_shape: ...
        Default value: `()`.
      activation_fn: ...
        Default value: `None`.
      name: ...
        Default value: `None` (i.e., `'Affine'`).
    """
    batch_shape = tf.constant(
        [], dtype=tf.int32) if batch_shape is None else prefer_static.cast(
            prefer_static.reshape(batch_shape, shape=[-1]), tf.int32)
    batch_ndims = prefer_static.size(batch_shape)
    kernel_shape = prefer_static.concat([
        batch_shape, [input_size, output_size]], axis=0)
    bias_shape = prefer_static.concat([batch_shape, [output_size]], axis=0)
    apply_kernel_fn = lambda x, k: tf.matmul(  # pylint: disable=g-long-lambda
        x[..., tf.newaxis, :], k)[..., 0, :]
    kernel, bias = make_kernel_bias_fn(
        kernel_shape, bias_shape,
        init_kernel_fn, init_bias_fn,
        batch_ndims, batch_ndims,
        dtype)
    self._make_kernel_bias_fn = make_kernel_bias_fn  # For tracking.
    super(Affine, self).__init__(
        kernel=kernel,
        bias=bias,
        apply_kernel_fn=apply_kernel_fn,
        activation_fn=activation_fn,
        dtype=dtype,
        name=name)
Example #19
  def _log_prob(self, value):
    """Log probability of multivariate normal.

    Costs a log_abs_determinant, a matvec, and a reduce_sum over a squared
    (batch of) vector(s).

    Args:
      value: Floating point `Tensor`.

    Returns:
      Floating point `Tensor` with batch shape.
    """
    dim = self.precision_factor.domain_dimension_tensor()
    return (ps.cast(-0.5 * np.log(2 * np.pi), self.dtype) *
            ps.cast(dim, self.dtype) +
            # Notice the sign on the LinearOperator.log_abs_determinant is
            # positive, since it is precision_factor not scale.
            self._precision_factor.log_abs_determinant() +
            self._log_prob_unnormalized(value))
Example #20
 def _forward_log_det_jacobian(self, x):
     # This code is similar to tf.math.log_softmax but different because we have
     # an implicit zero column to handle. I.e., instead of:
     #   reduce_sum(logits - reduce_sum(exp(logits), dim))
     # we must do:
     #   log_normalization = 1 + reduce_sum(exp(logits))
     #   -log_normalization + reduce_sum(logits - log_normalization)
     np1 = prefer_static.cast(1 + prefer_static.shape(x)[-1], dtype=x.dtype)
     return (0.5 * prefer_static.log(np1) + tf.reduce_sum(x, axis=-1) -
             np1 * tf.math.softplus(tf.reduce_logsumexp(x, axis=-1)))
Example #21
 def _reshape_part(part, event_shape):
     part = tf.cast(part, self.dtype)
     new_shape = ps.concat(
         [
             ps.shape(part)[:ps.size(ps.shape(part)) -
                            ps.size(event_shape)], [-1]
         ],
         axis=-1,
     )
     return tf.reshape(part, ps.cast(new_shape, tf.int32))
Example #22
 def _forward(self, x):
   ndims = ps.rank(x)
   indices = ps.reshape(ps.add(self.axis, ndims), shape=[-1, 1])
   return tf.pad(
       x,
       paddings=ps.tensor_scatter_nd_update(
           ps.zeros([ndims, 2], dtype=tf.int32),
           indices, self.paddings),
       mode=self.mode,
       constant_values=ps.cast(self.constant_values, dtype=x.dtype))
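# Worked example (NumPy sketch) of how the full `paddings` matrix is assembled:
# for a rank-2 input with axis=-1 and per-axis paddings [[1, 2]], the scatter
# produces [[0, 0], [1, 2]], i.e. pad only the last axis (1 before, 2 after).
import numpy as np

ndims = 2
axis = np.array([-1])
paddings = np.array([[1, 2]])
full = np.zeros([ndims, 2], dtype=np.int32)
full[axis + ndims] = paddings
print(full)   # [[0 0]
              #  [1 2]]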
Example #23
    def __init__(self,
                 samples,
                 event_ndims=0,
                 validate_args=False,
                 allow_nan_stats=True,
                 name='Empirical'):
        """Initialize `Empirical` distributions.

    Args:
      samples: Numeric `Tensor` of shape `[B1, ..., Bk, S, E1, ..., En]`,
        `k, n >= 0`. Samples or batches of samples on which the distribution
        is based. The first `k` dimensions index into a batch of independent
        distributions. The length of the `S` dimension determines the number
        of samples in each multiset. The last `n` dimensions represent samples
        for each distribution; `n` is specified by the `event_ndims` argument.
      event_ndims: Python `int32`, default `0`. Number of dimensions for each
        event. When `0` this distribution has scalar samples. When `1` this
        distribution has vector-like samples.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value `NaN` to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if the rank of `samples` is statically known and less than
        event_ndims + 1.
    """

        parameters = dict(locals())
        with tf.name_scope(name):
            self._samples = tensor_util.convert_nonref_to_tensor(samples)
            dtype = dtype_util.common_dtype([self._samples],
                                            dtype_hint=self._samples.dtype)
            self._event_ndims = event_ndims

            # Note: this tf.rank call affects the graph, but is ok in `__init__`
            # because we don't expect shapes (or ranks) to be runtime-variable, nor
            # ever need to differentiate with respect to them.
            samples_rank = prefer_static.rank(self._samples)
            self._samples_axis = prefer_static.cast(
                samples_rank - self._event_ndims - 1, tf.int32)

            super(Empirical, self).__init__(
                dtype=dtype,
                reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                name=name)
Example #24
def _initialize(shape, dtype, batch_ndims, scale, mode, distribution,
                seed=None):
  """Samples a random `Tensor` per specified args."""
  if not dtype_util.is_floating(dtype):
    raise TypeError('Argument `dtype` must be float type (saw: "{}").'.format(
        dtype))
  shape = prefer_static.reshape(shape, shape=[-1])  # Ensure shape is vector.
  fan_in, fan_out = _compute_fans_from_shape(shape, batch_ndims)
  fans = _summarize_fans(fan_in, fan_out, mode, dtype)
  scale = prefer_static.cast(scale, dtype)
  return _sample_distribution(shape, scale / fans, distribution, seed, dtype)
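# A NumPy sketch of the variance-scaling recipe these helpers implement: each
# weight is drawn with variance `scale / fans`, where `fans` summarizes fan-in /
# fan-out per `mode`. With scale=1 and mode='fan_avg' this is Glorot-style
# initialization. The function name below is illustrative, not part of the API.
import numpy as np

def glorot_normal_like(shape, scale=1., seed=0):
  rng = np.random.default_rng(seed)
  fan_in, fan_out = shape[-2], shape[-1]   # dense kernel shaped [in, out]
  fans = (fan_in + fan_out) / 2.           # mode == 'fan_avg'
  stddev = np.sqrt(scale / fans)
  return rng.normal(0., stddev, size=shape)

w = glorot_normal_like([256, 128])
print(w.std())   # roughly sqrt(2 / (256 + 128)) ~= 0.072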
Example #25
    def sample(self, sample_shape=(), seed=None, name=None):
        with tf.name_scope(name or 'sample'):
            # Grab the required number of values from the provided tensors.
            sample_shape = dist_util.expand_to_vector(sample_shape)
            n = ps.cast(ps.reduce_prod(sample_shape), dtype=tf.int32)

            # Check that we're not trying to draw too many samples.
            assertions = []
            will_overflow_ = tf.get_static_value(n > self.max_num_samples)
            if will_overflow_:
                raise ValueError(
                    'Trying to draw {} samples from a '
                    '`DeterministicEmpirical` instance for which only {} '
                    'samples were provided.'.format(
                        tf.get_static_value(n),
                        tf.get_static_value(self.max_num_samples)))
            elif (will_overflow_ is None  # Couldn't determine statically.
                  and self.validate_args):
                assertions.append(
                    tf.debugging.assert_less_equal(
                        n,
                        self.max_num_samples,
                        message='Number of samples to draw '
                        'from a `DeterministicEmpirical` instance must not exceed the '
                        'number provided at construction.'))

            # Extract the appropriate number of sampled values.
            with tf.control_dependencies(assertions):
                sampled = tf.nest.map_structure(lambda x: x[:n, ...],
                                                self.values_with_sample_dim)

            # Reshape the values to the appropriate sample shape.
            return tf.nest.map_structure(
                lambda x: tf.reshape(
                    x,  # pylint: disable=g-long-lambda
                    ps.concat([
                        ps.cast(sample_shape, tf.int32),
                        ps.cast(ps.shape(x)[1:], tf.int32)
                    ],
                              axis=0)),
                sampled)
Example #26
  def _get_reinterpreted_batch_ndims(self,
                                     distribution_batch_shape_tensor=None):
    if self._static_reinterpreted_batch_ndims is not None:
      return self._static_reinterpreted_batch_ndims
    if self._reinterpreted_batch_ndims is not None:
      return tf.convert_to_tensor(self._reinterpreted_batch_ndims)

    if distribution_batch_shape_tensor is None:
      distribution_batch_shape_tensor = self.distribution.batch_shape_tensor()
    return ps.cast(
        ps.maximum(0, ps.size(distribution_batch_shape_tensor) - 1),
        np.int32)
Example #27
    def even_odd_swaps(num_replica,
                       batch_shape=(),
                       step_count=None,
                       seed=None):
        """Make deterministic even_odd one time swaps."""
        if step_count is None:
            raise ValueError('`step_count` must be supplied. Found `None`.')
        del seed  # Unused for this function.
        with tf.name_scope(name or 'even_odd_swaps'):
            # Period is 1 / frequency, and we want period = Inf if frequency = 0.
            # safe_swap_period is the correct swap period in case swap_frequency > 0.
            # If swap_frequency == 0, safe_swap_period is set to 1 (to avoid integer
            # div by zero below). We will hard-set this case to "null swap."
            swap_freq = tf.convert_to_tensor(swap_frequency,
                                             name='swap_frequency')
            safe_swap_period = tf.cast(
                tf.where(swap_freq > 0,
                         tf.math.ceil(tf.math.reciprocal_no_nan(swap_freq)),
                         1),
                # Although period = 1 / frequency may have roundoff error, and result
                # in a period different than what the user intended, the
                # user will end up with a single integer period, and thus well defined
                # deterministic swaps.
                tf.int32,
            )

            # u selects parity.  E.g.,
            #  u==False ==> [1, 0, 3, 2, 4] even parity swaps
            #  u==True ==>  [0, 2, 1, 4, 3] odd parity swaps
            # If there are 2 replicas, then the "True" swaps are null
            # swaps...which would contradict the user provided `swap_frequency`.
            # So special case num_replica==2, forcing u==False in this case.
            u_shape = ps.concat(
                (ps.ones(1, dtype=tf.int32), ps.cast(batch_shape, tf.int32)),
                axis=0)
            u = tf.fill(u_shape,
                        tf.cast((step_count // safe_swap_period) % 2, tf.bool))
            u = tf.where(num_replica > 2, u, False)

            x = bu.left_justified_expand_dims_to(tf.range(num_replica,
                                                          dtype=tf.int64),
                                                 rank=ps.size(u_shape))
            y = tf.where(tf.equal(x % 2, tf.cast(u, dtype=tf.int64)), x + 1,
                         x - 1)
            y = tf.clip_by_value(y, 0, num_replica - 1)
            # TODO(b/142689785): Consider using tf.cond and returning an empty list
            # then in REMC consider using a tf.cond for short-circuiting.
            return tf.where(
                (tf.cast(step_count % safe_swap_period, tf.bool)
                 | tf.math.equal(swap_freq, 0)),
                x,  # Don't swap
                y,  # Swap
            )
Example #28
def _make_input_and_kernel(
    make_input, input_batch_shape, input_shape, kernel_batch_shape,
    filter_shape, channels_out, dtype):
  total_input_shape = ps.concat([input_batch_shape, input_shape], axis=0)
  total_kernel_shape = ps.concat(
      [kernel_batch_shape, [filter_shape[0] * filter_shape[1] * input_shape[-1],
                            channels_out]], axis=0)
  # Use integers for numerical stability.
  sample_fn = lambda s: make_input(tf.cast(  # pylint: disable=g-long-lambda
      tf.random.uniform(
          ps.cast(s, tf.int32), minval=-10, maxval=10, dtype=tf.int32),
      dtype=dtype))
  return sample_fn(total_input_shape), sample_fn(total_kernel_shape)
Example #29
def _scatter_nd_batch(indices, updates, shape, batch_dims=0):
  """A partial implementation of `scatter_nd` supporting `batch_dims`."""

  # `tf.scatter_nd` does not support a `batch_dims` argument.
  # Instead we use the gradient of `tf.gather_nd`.
  # From a purely mathematical perspective this works because
  # (if `tf.scatter_nd` supported `batch_dims`)
  # `gather_nd` and `scatter_nd` (with matching `indices`) are
  # adjoint linear operators and
  # the gradient w.r.t `x` of `dot(y, A(x))` is `adjoint(A)(y)`.
  #
  # Another perspective: back propagating through a "neural" network
  # containing a gather operation carries derivatives backwards through the
  # network, accumulating the derivatives in the locations that
  # were gathered from, i.e. they are scattered.
  # If the network multiplies each gathered element by
  # some quantity, then the backwardly propagating derivatives are scaled
  # by this quantity before being scattered.
  # Combining this with the fact that `GradientTape.gradient`
  # starts back-propagation with derivatives equal to `1`, this allows us
  # to use the multipliers to determine the quantities scattered.
  #
  # However, derivatives are only supported for floating point types
  # so we 'tunnel' our types through the `float64` type.
  # So the implementation is "partial" in the sense that it supports
  # data that can be losslessly converted to `tf.float64` and back.
  dtype = updates.dtype
  internal_dtype = tf.float64
  multipliers = ps.cast(updates, internal_dtype)
  with tf.GradientTape() as tape:
    zeros = tf.zeros(shape, dtype=internal_dtype)
    tape.watch(zeros)
    weighted_gathered = multipliers * tf.gather_nd(
        zeros,
        indices,
        batch_dims=batch_dims)
  grad = tape.gradient(weighted_gathered, zeros)
  return ps.cast(grad, dtype=dtype)
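# A small eager-mode check of the trick described above (assuming TF 2.x): the
# gradient of `multipliers * gather_nd(zeros, indices, batch_dims=1)` with
# respect to `zeros` scatters the multipliers back to the gathered positions.
import tensorflow as tf

indices = tf.constant([[[1], [3]], [[0], [2]]])   # [batch=2, num_updates=2, 1]
updates = tf.constant([[10., 20.], [30., 40.]])

with tf.GradientTape() as tape:
  zeros = tf.zeros([2, 4], dtype=tf.float64)
  tape.watch(zeros)
  weighted_gathered = tf.cast(updates, tf.float64) * tf.gather_nd(
      zeros, indices, batch_dims=1)
scattered = tape.gradient(weighted_gathered, zeros)
print(scattered.numpy())   # [[ 0. 10.  0. 20.]
                           #  [30.  0. 40.  0.]]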
Example #30
def make_rwmh_kernel_fn(target_log_prob_fn, init_state, scalings):
    """Generate a Random Walk MH kernel."""
    with tf.name_scope('make_rwmh_kernel_fn'):
        state_std = [
            tf.math.reduce_std(x, axis=0, keepdims=True) for x in init_state
        ]
        step_size = [
            s * ps.cast(  # pylint: disable=g-complex-comprehension
                bu.left_justified_expand_dims_like(scalings, s), s.dtype)
            for s in state_std
        ]
        return random_walk_metropolis.RandomWalkMetropolis(
            target_log_prob_fn,
            new_state_fn=random_walk_metropolis.random_walk_normal_fn(
                scale=step_size))
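# Worked example (NumPy sketch) of the step-size construction above: per-dimension
# standard deviations computed across chains are rescaled by a per-chain scaling
# factor, broadcast from the left.
import numpy as np

init_state_part = np.random.default_rng(0).normal(size=[8, 3])   # [num_chains, d]
scalings = np.array([0.5, 1.0, 1.0, 2.0, 0.5, 1.0, 1.0, 2.0])    # [num_chains]

state_std = init_state_part.std(axis=0, keepdims=True)           # [1, d]
step_size = scalings[:, None] * state_std                        # [num_chains, d]
print(step_size.shape)   # (8, 3)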