示例#1
0
            # W is the within sequence variance (the mean of the chain variances)
            sum_of_chain_squared_residuals = tf.reduce_sum(
                chain_variances.sum_squared_residuals,
                axis=ps.range(chain_ndims))
            w = sum_of_chain_squared_residuals / (m * (n - 1))

            # the `true_variance_estimate` is denoted as sigma^2_+ in the 1998 paper
            true_variance_estimate = ((n - 1) / n) * w + b_div_n
            return (
                (m + 1.) / m) * true_variance_estimate / w - (n - 1.) / (m * n)

        return tf.nest.map_structure(_finalize_for_one_state,
                                     self.independent_chain_ndims,
                                     self.chain_variances,
                                     check_types=False)

    def __repr__(self):
        """Return a multi-line, constructor-style representation.

        Shows the two pieces of state this reducer carries:
        `chain_variances` and `independent_chain_ndims`.
        """
        header = 'RunningPotentialScaleReduction(\n'
        variances_line = f'    chain_variances={self.chain_variances!r},\n'
        ndims_line = f'    independent_chain_ndims={self.independent_chain_ndims!r})'
        return header + variances_line + ndims_line


if JAX_MODE:
    # When running under JAX, register each running-statistics class as a
    # pytree node so instances can flow through JAX transformations.
    # `register_pytree_node_class` expects each class to provide
    # `tree_flatten`/`tree_unflatten` methods (per the JAX pytree API).
    from jax import tree_util  # pylint: disable=g-import-not-at-top
    tree_util.register_pytree_node_class(RunningCentralMoments)
    tree_util.register_pytree_node_class(RunningCovariance)
    tree_util.register_pytree_node_class(RunningVariance)
    tree_util.register_pytree_node_class(RunningMean)
    tree_util.register_pytree_node_class(RunningPotentialScaleReduction)
示例#2
0
        self.grads = grads

    def tree_flatten(self):
        """Flatten into JAX pytree form.

        Returns:
          A pair `(children, aux_data)` where `children` is the one-element
          tuple `(self.grads,)` and `aux_data` is empty.
        """
        children = (self.grads,)
        aux_data = ()
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, _, xs):
        """Rebuild an instance from pytree children (aux data is unused)."""
        # `xs` is the children tuple produced by `tree_flatten`; splat it
        # back into the constructor.
        children = xs
        return cls(*children)

    def __repr__(self):
        """Debug representation showing the wrapped gradients."""
        return '_DummyGrads({})'.format(self.grads)


if JAX_MODE:
    # Under JAX, make `_DummyGrads` a pytree node so the gradients it wraps
    # can be mapped over by JAX tree utilities; relies on the
    # `tree_flatten`/`tree_unflatten` methods defined on the class above.
    from jax import tree_util  # pylint: disable=g-import-not-at-top
    tree_util.register_pytree_node_class(_DummyGrads)


def make_sharded_log_prob_parts(log_prob_parts_fn, axis_names):
    """Constructs a log prob parts function that all-reduces over terms.

  Given a log_prob_parts function, this function will return a new one that
  includes all-reduce sums over terms according to the `is_sharded` property. It
  will also add all-reduce sums for the gradient of sharded terms w.r.t.
  unsharded terms.

  Args:
    log_prob_parts_fn: a callable that takes in a structured value and returns a
      structure of log densities for each of the terms, that when summed returns
      a locally correct log-density.
    axis_names: a structure of values that matches the input and output
示例#3
0
 def __init_subclass__(cls, **kwargs):
     """Register every new subclass as a JAX pytree node.

     Runs automatically when a subclass is defined: the subclass is handed
     to `tree_util.register_pytree_node_class`, then the normal
     subclass-initialization chain continues via `super()`.
     """
     tree_util.register_pytree_node_class(cls)
     super().__init_subclass__(**kwargs)