def testNonEmptyConstantTensor(self):
     x = tf.zeros((2, 3, 4))
     shape = distribution_util.prefer_static_shape(x)
     self.assertIsInstance(shape, np.ndarray)
     self.assertAllEqual([2, 3, 4], shape)
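A hedged sketch of the complementary dynamic-shape case (assumed behavior, not asserted by the test above): when the static shape is not fully known, `prefer_static_shape` is expected to fall back to a dynamic `tf.shape` Tensor rather than a NumPy array.

@tf.function(input_signature=[tf.TensorSpec(shape=[None, 3, 4])])
def dynamic_shape_case(x):
    # The static shape is only partially known here, so the helper should
    # return the dynamic shape Tensor instead of an np.ndarray.
    return distribution_util.prefer_static_shape(x)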
Example #2
    def test_increment_log_prob(self):

        root = tfd.JointDistributionCoroutine.Root
        prior_mean = 3.
        x_size = 100

        def custom_ll(w, x):
            return tf.reduce_sum(tfd.Normal(w, 1.).log_prob(x))

        def ulp_grad(w, x):
            @joint_density_coroutine.JointDensityCoroutine
            def sharded_model():
                w = yield root(tfd.Normal(prior_mean, 1.))
                yield root(
                    sharded.Sharded(increment_log_prob.IncrementLogProb(
                        custom_ll(w, x)),
                                    shard_axis_name=self.axis_name))

            def ulp_fn(w):
                zeros = tf.zeros([x_size, 0])
                return sharded_model.unnormalized_log_prob(w, zeros)

            ulp, g = tfp.math.value_and_gradient(ulp_fn, (w, ))
            return ulp, g

        def true_ulp_grad(w, x):
            @joint_density_coroutine.JointDensityCoroutine
            def model():
                w = yield root(tfd.Normal(prior_mean, 1.))
                yield root(increment_log_prob.IncrementLogProb(custom_ll(w,
                                                                         x)))

            def ulp_fn(w):
                zeros = tf.zeros([x_size, 0])
                return model.unnormalized_log_prob(w, zeros)

            ulp, g = tfp.math.value_and_gradient(ulp_fn, (w, ))
            return ulp, g

        def test_w_x(w, x):
            sharded_x = self.shard_values(
                tf.reshape(x, [test_lib.NUM_DEVICES, -1]))

            lp, g = self.evaluate(
                self.per_replica_to_tensor(
                    self.strategy_run(ulp_grad, (
                        w,
                        sharded_x,
                    ),
                                      in_axes=(None, 0))))
            true_lp, true_g = self.evaluate(true_ulp_grad(w, x))

            self.assertAllClose(true_lp, lp[0])
            self.assertAllClose(true_g[0], g[0][0])

        w = tf.constant(4.)
        zeros = tf.zeros([x_size])
        test_w_x(w, zeros)
        random_x = self.evaluate(
            tfd.Normal(loc=tf.zeros([x_size]),
                       scale=tf.ones([x_size])).sample(seed=self.key))
        test_w_x(w, random_x)
Example #3
 def run(key):
     return tfp_dist.Sharded(
         tfd.Independent(tfd.Normal(tf.zeros(1), tf.ones(1)), 1),
         shard_axis_name=self.axis_name).sample(seed=key)
Example #4
 def _mean(self):
     return tf.zeros(self.batch_shape_tensor())
Example #5
    def _loop_tree_doubling(self, step_size, momentum_state_memory,
                            current_step_meta_info, iter_, initial_step_state,
                            initial_step_metastate, seed):
        """Main loop for tree doubling."""
        with tf.name_scope('loop_tree_doubling'):
            (direction_seed, subtree_seed, acceptance_seed,
             next_seed) = samplers.split_seed(seed, n=4)
            batch_shape = ps.shape(current_step_meta_info.init_energy)
            direction = tf.cast(samplers.uniform(shape=batch_shape,
                                                 minval=0,
                                                 maxval=2,
                                                 dtype=tf.int32,
                                                 seed=direction_seed),
                                dtype=tf.bool)

            tree_start_states = tf.nest.map_structure(
                lambda v: bu.where_left_justified_mask(direction, v[1], v[0]),
                initial_step_state)

            directions_expanded = [
                bu.left_justified_expand_dims_like(direction, state)
                for state in tree_start_states.state
            ]

            integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
                self.target_log_prob_fn,
                step_sizes=[
                    tf.where(d, ss, -ss)
                    for d, ss in zip(directions_expanded, step_size)
                ],
                num_steps=self.unrolled_leapfrog_steps)

            [
                candidate_tree_state, tree_final_states, final_not_divergence,
                continue_tree_final, energy_diff_tree_sum,
                momentum_subtree_cumsum, leapfrogs_taken
            ] = self._build_sub_tree(
                directions_expanded,
                integrator,
                current_step_meta_info,
                # num_steps_at_this_depth = 2**iter_ = 1 << iter_
                tf.bitwise.left_shift(1, iter_),
                tree_start_states,
                initial_step_metastate.continue_tree,
                initial_step_metastate.not_divergence,
                momentum_state_memory,
                seed=subtree_seed)

            last_candidate_state = initial_step_metastate.candidate_state

            energy_diff_sum = (energy_diff_tree_sum +
                               initial_step_metastate.energy_diff_sum)
            if MULTINOMIAL_SAMPLE:
                tree_weight = tf.where(
                    continue_tree_final, candidate_tree_state.weight,
                    tf.constant(-np.inf,
                                dtype=candidate_tree_state.weight.dtype))
                weight_sum = log_add_exp(tree_weight,
                                         last_candidate_state.weight)
                log_accept_thresh = tree_weight - last_candidate_state.weight
            else:
                tree_weight = tf.where(continue_tree_final,
                                       candidate_tree_state.weight,
                                       tf.zeros([], dtype=TREE_COUNT_DTYPE))
                weight_sum = tree_weight + last_candidate_state.weight
                log_accept_thresh = tf.math.log(
                    tf.cast(tree_weight, tf.float32) /
                    tf.cast(last_candidate_state.weight, tf.float32))
            log_accept_thresh = tf.where(tf.math.is_nan(log_accept_thresh),
                                         tf.zeros([], log_accept_thresh.dtype),
                                         log_accept_thresh)
            u = tf.math.log1p(-samplers.uniform(shape=batch_shape,
                                                dtype=log_accept_thresh.dtype,
                                                seed=acceptance_seed))
            is_sample_accepted = u <= log_accept_thresh

            choose_new_state = is_sample_accepted & continue_tree_final

            new_candidate_state = TreeDoublingStateCandidate(
                state=[
                    bu.where_left_justified_mask(choose_new_state, s0, s1)
                    for s0, s1 in zip(candidate_tree_state.state,
                                      last_candidate_state.state)
                ],
                target=bu.where_left_justified_mask(
                    choose_new_state, candidate_tree_state.target,
                    last_candidate_state.target),
                target_grad_parts=[
                    bu.where_left_justified_mask(choose_new_state, grad0,
                                                 grad1)
                    for grad0, grad1 in zip(
                        candidate_tree_state.target_grad_parts,
                        last_candidate_state.target_grad_parts)
                ],
                energy=bu.where_left_justified_mask(
                    choose_new_state, candidate_tree_state.energy,
                    last_candidate_state.energy),
                weight=weight_sum)

            for new_candidate_state_temp, old_candidate_state_temp in zip(
                    new_candidate_state.state, last_candidate_state.state):
                tensorshape_util.set_shape(new_candidate_state_temp,
                                           old_candidate_state_temp.shape)

            for new_candidate_grad_temp, old_candidate_grad_temp in zip(
                    new_candidate_state.target_grad_parts,
                    last_candidate_state.target_grad_parts):
                tensorshape_util.set_shape(new_candidate_grad_temp,
                                           old_candidate_grad_temp.shape)

            # Update the left/right information of the trajectory and check for a
            # trajectory-level U-turn.
            tree_otherend_states = tf.nest.map_structure(
                lambda v: bu.where_left_justified_mask(direction, v[0], v[1]),
                initial_step_state)

            new_step_state = tf.nest.pack_sequence_as(
                initial_step_state,
                [
                    tf.stack(
                        [  # pylint: disable=g-complex-comprehension
                            bu.where_left_justified_mask(
                                direction, right, left),
                            bu.where_left_justified_mask(
                                direction, left, right),
                        ],
                        axis=0) for left, right in zip(
                            tf.nest.flatten(tree_final_states),
                            tf.nest.flatten(tree_otherend_states))
                ])

            momentum_tree_cumsum = []
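            # Add this subtree's momentum sums to the running per-part totals,
            # re-asserting the static shapes that loop-carried tensors may have
            # lost inside the tf.while_loop.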
            for p0, p1 in zip(initial_step_metastate.momentum_sum,
                              momentum_subtree_cumsum):
                momentum_part_temp = p0 + p1
                tensorshape_util.set_shape(momentum_part_temp, p0.shape)
                momentum_tree_cumsum.append(momentum_part_temp)

            for new_state_temp, old_state_temp in zip(
                    tf.nest.flatten(new_step_state),
                    tf.nest.flatten(initial_step_state)):
                tensorshape_util.set_shape(new_state_temp,
                                           old_state_temp.shape)

            if GENERALIZED_UTURN:
                state_diff = momentum_tree_cumsum
            else:
                state_diff = [s[1] - s[0] for s in new_step_state.state]

            no_u_turns_trajectory = has_not_u_turn(
                state_diff, [m[0] for m in new_step_state.momentum],
                [m[1] for m in new_step_state.momentum],
                log_prob_rank=ps.rank_from_shape(batch_shape),
                shard_axis_names=self.experimental_shard_axis_names)

            new_step_metastate = TreeDoublingMetaState(
                candidate_state=new_candidate_state,
                is_accepted=choose_new_state
                | initial_step_metastate.is_accepted,
                momentum_sum=momentum_tree_cumsum,
                energy_diff_sum=energy_diff_sum,
                continue_tree=continue_tree_final & no_u_turns_trajectory,
                not_divergence=final_not_divergence,
                leapfrog_count=(initial_step_metastate.leapfrog_count +
                                leapfrogs_taken))

            return iter_ + 1, next_seed, new_step_state, new_step_metastate
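The `log_add_exp` helper used above is not defined in this snippet; a minimal numerically stable sketch of the assumed behavior (not necessarily the library's exact implementation) is:

def log_add_exp(x, y):
    # Computes log(exp(x) + exp(y)) while avoiding overflow by factoring out
    # the larger argument.
    larger = tf.maximum(x, y)
    return larger + tf.math.log1p(tf.math.exp(-tf.math.abs(x - y)))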
Example #6
 def null_input(self):
   return tf.zeros([1, self._num_tokens], dtype=tf.float32)
Example #7
def sample_annealed_importance_chain(num_steps,
                                     proposal_log_prob_fn,
                                     target_log_prob_fn,
                                     current_state,
                                     make_kernel_fn,
                                     parallel_iterations=10,
                                     seed=None,
                                     name=None):
    """Runs annealed importance sampling (AIS) to estimate normalizing constants.

  This function uses an MCMC transition operator (e.g., Hamiltonian Monte Carlo)
  to sample from a series of distributions that slowly interpolates between
  an initial 'proposal' distribution:

  `exp(proposal_log_prob_fn(x) - proposal_log_normalizer)`

  and the target distribution:

  `exp(target_log_prob_fn(x) - target_log_normalizer)`,

  accumulating importance weights along the way. The product of these
  importance weights gives an unbiased estimate of the ratio of the
  normalizing constants of the initial distribution and the target
  distribution:

  `E[exp(ais_weights)] = exp(target_log_normalizer - proposal_log_normalizer)`.

  Note: When running in graph mode, `proposal_log_prob_fn` and
  `target_log_prob_fn` are called exactly three times (although this may be
  reduced to two times in the future).

  Args:
    num_steps: Integer number of Markov chain updates to run. More
      iterations means more expense, but smoother annealing between q
      and p, which in turn means exponentially lower variance for the
      normalizing constant estimator.
    proposal_log_prob_fn: Python callable that returns the log density of the
      initial distribution.
    target_log_prob_fn: Python callable which takes an argument like
      `current_state` (or `*current_state` if it's a list) and returns its
      (possibly unnormalized) log-density under the target distribution.
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s). The first `r` dimensions index
      independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`.
    make_kernel_fn: Python `callable` which returns a `TransitionKernel`-like
      object. Must take one argument representing the `TransitionKernel`'s
      `target_log_prob_fn`. The `target_log_prob_fn` argument represents the
      `TransitionKernel`'s target log distribution.  Note:
      `sample_annealed_importance_chain` creates a new `target_log_prob_fn`
      which is an interpolation between the supplied `target_log_prob_fn` and
      `proposal_log_prob_fn`; it is this interpolated function which is used as
      an argument to `make_kernel_fn`.
    parallel_iterations: The number of iterations allowed to run in parallel.
      It must be a positive integer. See `tf.while_loop` for more details.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'sample_annealed_importance_chain').

  Returns:
    next_state: `Tensor` or Python list of `Tensor`s representing the
      state(s) of the Markov chain(s) at the final iteration. Has same shape as
      input `current_state`.
    ais_weights: Tensor with the estimated weight(s). Has shape matching
      `target_log_prob_fn(current_state)`.
    kernel_results: `collections.namedtuple` of internal calculations used to
      advance the chain.

  #### Examples

  ##### Estimate the normalizing constant of a log-gamma distribution.

  ```python
  tfd = tfp.distributions

  # Run 100 AIS chains in parallel
  num_chains = 100
  dims = 20
  dtype = np.float32

  proposal = tfd.MultivariateNormalDiag(
     loc=tf.zeros([dims], dtype=dtype))

  target = tfd.TransformedDistribution(
    distribution=tfd.Sample(
        tfd.Gamma(concentration=dtype(2), rate=dtype(3)),
        sample_shape=[dims]),
    bijector=tfp.bijectors.Invert(tfp.bijectors.Exp()))

  chains_state, ais_weights, kernels_results = (
      tfp.mcmc.sample_annealed_importance_chain(
          num_steps=1000,
          proposal_log_prob_fn=proposal.log_prob,
          target_log_prob_fn=target.log_prob,
          current_state=proposal.sample(num_chains),
          make_kernel_fn=lambda tlp_fn: tfp.mcmc.HamiltonianMonteCarlo(
            target_log_prob_fn=tlp_fn,
            step_size=0.2,
            num_leapfrog_steps=2)))

  log_estimated_normalizer = (tf.reduce_logsumexp(ais_weights)
                              - np.log(num_chains))
  log_true_normalizer = tf.math.lgamma(2.) - 2. * tf.math.log(3.)
  ```

  ##### Estimate marginal likelihood of a Bayesian regression model.

  ```python
  tfd = tfp.distributions

  def make_prior(dims, dtype):
    return tfd.MultivariateNormalDiag(
        loc=tf.zeros(dims, dtype))

  def make_likelihood(weights, x):
    return tfd.MultivariateNormalDiag(
        loc=tf.tensordot(weights, x, axes=[[0], [-1]]))

  # Run 100 AIS chains in parallel
  num_chains = 100
  dims = 10
  dtype = np.float32

  # Make training data.
  x = np.random.randn(num_chains, dims).astype(dtype)
  true_weights = np.random.randn(dims).astype(dtype)
  y = np.dot(x, true_weights) + np.random.randn(num_chains)

  # Setup model.
  prior = make_prior(dims, dtype)
  def target_log_prob_fn(weights):
    return prior.log_prob(weights) + make_likelihood(weights, x).log_prob(y)

  proposal = tfd.MultivariateNormalDiag(
      loc=tf.zeros(dims, dtype))

  weight_samples, ais_weights, kernel_results = (
      tfp.mcmc.sample_annealed_importance_chain(
        num_steps=1000,
        proposal_log_prob_fn=proposal.log_prob,
        target_log_prob_fn=target_log_prob_fn,
        current_state=tf.zeros([num_chains, dims], dtype),
        make_kernel_fn=lambda tlp_fn: tfp.mcmc.HamiltonianMonteCarlo(
          target_log_prob_fn=tlp_fn,
          step_size=0.1,
          num_leapfrog_steps=2)))
  log_normalizer_estimate = (tf.reduce_logsumexp(ais_weights)
                             - np.log(num_chains))
  ```

  """
    is_seeded = seed is not None
    seed = samplers.sanitize_seed(seed, salt='mcmc.sample_ais_chain')

    with tf.name_scope(name or 'sample_annealed_importance_chain'):
        num_steps = tf.convert_to_tensor(value=num_steps,
                                         dtype=tf.int32,
                                         name='num_steps')
        if mcmc_util.is_list_like(current_state):
            current_state = [
                tf.convert_to_tensor(s, name='current_state')
                for s in current_state
            ]
        else:
            current_state = tf.convert_to_tensor(value=current_state,
                                                 name='current_state')

        def _make_convex_combined_log_prob_fn(iter_):
            def _fn(*args):
                p = tf.identity(proposal_log_prob_fn(*args),
                                name='proposal_log_prob')
                t = tf.identity(target_log_prob_fn(*args),
                                name='target_log_prob')
                dtype = dtype_util.base_dtype(p.dtype)
                beta = tf.cast(iter_ + 1, dtype) / tf.cast(num_steps, dtype)
                return tf.identity(beta * t + (1. - beta) * p,
                                   name='convex_combined_log_prob')

            return _fn

        def _loop_body(iter_, seed, ais_weights, current_state,
                       kernel_results):
            """Closure which implements `tf.while_loop` body."""
            iter_seed, next_seed = samplers.split_seed(
                seed,
                salt='ais_chain.seeded_one_step') if is_seeded else (seed,
                                                                     seed)

            x = (current_state
                 if mcmc_util.is_list_like(current_state) else [current_state])
            proposal_log_prob = proposal_log_prob_fn(*x)
            target_log_prob = target_log_prob_fn(*x)
            ais_weights += ((target_log_prob - proposal_log_prob) /
                            tf.cast(num_steps, ais_weights.dtype))
            kernel = make_kernel_fn(_make_convex_combined_log_prob_fn(iter_))
            # TODO(b/147676843): Should we warn if the kernel is not calibrated?
            one_step_kwargs = dict(seed=iter_seed) if is_seeded else {}
            next_state, inner_results = kernel.one_step(
                current_state, kernel_results.inner_results, **one_step_kwargs)
            kernel_results = AISResults(
                proposal_log_prob=proposal_log_prob,
                target_log_prob=target_log_prob,
                inner_results=inner_results,
            )
            return [
                iter_ + 1, next_seed, ais_weights, next_state, kernel_results
            ]

        def _bootstrap_results(init_state):
            """Creates first version of `previous_kernel_results`."""
            kernel = make_kernel_fn(_make_convex_combined_log_prob_fn(iter_=0))
            inner_results = kernel.bootstrap_results(init_state)
            mh_results = _find_inner_mh_results(inner_results)

            convex_combined_log_prob = mh_results.accepted_results.target_log_prob
            dtype = dtype_util.as_numpy_dtype(convex_combined_log_prob.dtype)
            shape = tf.shape(convex_combined_log_prob)
            proposal_log_prob = tf.fill(shape,
                                        dtype(np.nan),
                                        name='bootstrap_proposal_log_prob')
            target_log_prob = tf.fill(shape,
                                      dtype(np.nan),
                                      name='bootstrap_target_log_prob')

            return AISResults(
                proposal_log_prob=proposal_log_prob,
                target_log_prob=target_log_prob,
                inner_results=inner_results,
            )

        previous_kernel_results = _bootstrap_results(current_state)
        inner_results = previous_kernel_results.inner_results
        mh_results = _find_inner_mh_results(inner_results)

        ais_weights = tf.zeros(
            shape=tf.broadcast_dynamic_shape(
                tf.shape(mh_results.proposed_results.target_log_prob),
                tf.shape(mh_results.accepted_results.target_log_prob)),
            dtype=mh_results.proposed_results.target_log_prob.dtype)

        [_, _, ais_weights, current_state, kernel_results] = tf.while_loop(
            cond=lambda iter_, *args: iter_ < num_steps,
            body=_loop_body,
            loop_vars=[
                np.int32(0),  # iter_
                seed,
                ais_weights,
                current_state,
                previous_kernel_results,
            ],
            parallel_iterations=parallel_iterations)

        return [current_state, ais_weights, kernel_results]
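`_find_inner_mh_results` is referenced above but not defined in this snippet. A hedged sketch of the assumed behavior (not TFP's actual helper) is to unwrap nested kernel results until Metropolis-Hastings-style fields appear:

def _find_inner_mh_results(results):
    # Walk through nested `inner_results` until we reach a results namedtuple
    # that exposes `proposed_results` and `accepted_results`.
    while not (hasattr(results, 'proposed_results')
               and hasattr(results, 'accepted_results')):
        if not hasattr(results, 'inner_results'):
            raise TypeError('Cannot find Metropolis-Hastings results.')
        results = results.inner_results
    return results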
Example #8
 def _forward_log_det_jacobian(self, x):
     return tf.zeros([], dtype=x.dtype)
Example #9
def sample_sequential_monte_carlo(
        prior_log_prob_fn,
        likelihood_log_prob_fn,
        current_state,
        min_num_steps=2,
        max_num_steps=25,
        max_stage=100,
        make_kernel_fn=make_rwmh_kernel_fn,
        tuning_fn=simple_heuristic_tuning,
        make_tempered_target_log_prob_fn=default_make_tempered_target_log_prob_fn,
        ess_threshold_ratio=0.5,
        parallel_iterations=10,
        seed=None,
        name=None):
    """Runs Sequential Monte Carlo to sample from the posterior distribution.

  This function uses an MCMC transition operator (e.g., Hamiltonian Monte Carlo)
  to sample from a series of distributions that slowly interpolates between
  an initial 'prior' distribution:

    `exp(prior_log_prob_fn(x))`

  and the target 'posterior' distribution:

    `exp(prior_log_prob_fn(x) + target_log_prob_fn(x))`,

  by mutating a collection of MC samples (i.e., particles). The approach is also
  known as a particle filter in some literature. The current implementation is
  largely based on Del Moral et al. [1], which adapts the tempering sequence
  (based on the effective sample size) and the scaling of the mutation kernel
  (based on the sample covariance of the particles) at each stage.

  Args:
    prior_log_prob_fn: Python callable that returns the log density of the
      prior distribution.
    likelihood_log_prob_fn: Python callable which takes an argument like
      `current_state` (or `*current_state` if it's a list) and returns its
      (possibly unnormalized) log-density under the likelihood distribution.
    current_state: Nested structure of `Tensor`s, each of shape
      `concat([[num_particles, b1, ..., bN], latent_part_event_shape])`, where
      `b1, ..., bN` are optional batch dimensions. Each batch represents an
      independent SMC run.
    min_num_steps: The minimum number of kernel transition steps in one mutation
      of the MC samples.
    max_num_steps: The maximum number of kernel transition steps in one mutation
      of the MC samples. Note that the actual number of steps in one mutation is
      tuned during sampling and is likely lower than `max_num_steps`.
    max_stage: Integer maximum number of tempering stages allowed when
      increasing the inverse temperature from 0 to 1.
    make_kernel_fn: Python `callable` which returns a `TransitionKernel`-like
      object. Must take one argument representing the `TransitionKernel`'s
      `target_log_prob_fn`. The `target_log_prob_fn` argument represents the
      `TransitionKernel`'s target log distribution.  Note:
      `sample_sequential_monte_carlo` creates a new `target_log_prob_fn` (via
      `make_tempered_target_log_prob_fn`) that combines the supplied
      `prior_log_prob_fn` with the `likelihood_log_prob_fn` tempered by the
      current inverse temperature; it is this tempered function which is used
      as an argument to `make_kernel_fn`.
    tuning_fn: Python `callable` which takes the number of steps, the log
      scaling, and the log acceptance ratio from the last mutation and outputs
      the number of steps and log scaling for the next mutation.
    make_tempered_target_log_prob_fn: Python `callable` that takes the
      `prior_log_prob_fn`, `likelihood_log_prob_fn`, and `inverse_temperatures`
      and creates a `target_log_prob_fn` `callable` that is passed to
      `make_kernel_fn`.
    ess_threshold_ratio: Target ratio for effective sample size.
    parallel_iterations: The number of iterations allowed to run in parallel.
      It must be a positive integer. See `tf.while_loop` for more details.
    seed: Python integer or TFP seedstream to seed the random number generator.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'sample_sequential_monte_carlo').

  Returns:
    n_stage: Number of mutation stages the SMC run executed.
    final_state: `Tensor` or Python `list` of `Tensor`s representing the
      final state(s) of the Markov chain(s). These are the posterior
      samples.
    final_kernel_results: `collections.namedtuple` of internal calculations used
      to advance the chain.

  #### References

  [1] Del Moral, Pierre, Arnaud Doucet, and Ajay Jasra. An adaptive sequential
      Monte Carlo method for approximate Bayesian computation.
      _Statistics and Computing_, 22(5):1009-1020, 2012.

  """

    with tf.name_scope(name or 'sample_sequential_monte_carlo'):
        seed_stream = SeedStream(seed, salt='smc_seed')

        unwrap_state_list = not tf.nest.is_nested(current_state)
        if unwrap_state_list:
            current_state = [current_state]
        current_state = [
            tf.convert_to_tensor(s, dtype_hint=tf.float32)
            for s in current_state
        ]

        # Initial preprocessing at Stage 0
        likelihood_log_prob = likelihood_log_prob_fn(*current_state)

        likelihood_rank = ps.rank(likelihood_log_prob)
        dimension = ps.reduce_sum([
            ps.reduce_prod(ps.shape(x)[likelihood_rank:])
            for x in current_state
        ])

        # We infer the particle shapes from the resulting likelihood:
        # [num_particles, b1, ..., bN]
        particle_shape = ps.shape(likelihood_log_prob)
        num_particles, batch_shape = particle_shape[0], particle_shape[1:]
        effective_sample_size_threshold = tf.cast(
            num_particles * ess_threshold_ratio, tf.int32)

        # TODO(b/152412213): Revisit this default parameter.
        # Default to the optimal scaling of a random walk kernel for a d-dimensional
        # normal distributed targets: 2.38 ** 2 / d.
        # For more detail see:
        # Roberts GO, Gelman A, Gilks WR. Weak convergence and optimal scaling of
        # random walk Metropolis algorithms. _The annals of applied probability_.
        # 1997;7(1):110-20.
        scale_start = (tf.constant(2.38**2, dtype=likelihood_log_prob.dtype) /
                       tf.constant(dimension, dtype=likelihood_log_prob.dtype))

        inverse_temperature = tf.zeros(batch_shape,
                                       dtype=likelihood_log_prob.dtype)
        scalings = ps.ones_like(likelihood_log_prob) * ps.minimum(
            scale_start, 1.)
        kernel = make_kernel_fn(make_tempered_target_log_prob_fn(
            prior_log_prob_fn, likelihood_log_prob_fn, inverse_temperature),
                                current_state,
                                scalings,
                                seed=seed_stream)
        pkr = kernel.bootstrap_results(current_state)
        _, kernel_target_log_prob = gather_mh_like_result(pkr)

        particle_info = ParticleInfo(
            log_accept_prob=ps.zeros_like(likelihood_log_prob),
            log_scalings=tf.math.log(scalings),
            tempered_log_prob=kernel_target_log_prob,
            likelihood_log_prob=likelihood_log_prob,
        )

        current_pkr = SMCResults(
            num_steps=tf.convert_to_tensor(max_num_steps,
                                           dtype=tf.int32,
                                           name='num_steps'),
            inverse_temperature=inverse_temperature,
            log_marginal_likelihood=tf.zeros_like(inverse_temperature),
            particle_info=particle_info)

        def update_weights_temperature(inverse_temperature,
                                       likelihood_log_prob):
            """Calculate the next inverse temperature and update weights."""
            likelihood_diff = likelihood_log_prob - tf.reduce_max(
                likelihood_log_prob, axis=0)

            def _body_fn(new_beta, upper_beta, lower_beta, eff_size,
                         log_weights):
                """One iteration of the temperature and weight update."""
                new_beta = (lower_beta + upper_beta) / 2.0
                log_weights = (new_beta -
                               inverse_temperature) * likelihood_diff
                log_weights_norm = tf.math.log_softmax(log_weights, axis=0)
                eff_size = tf.cast(
                    tf.exp(-tf.math.reduce_logsumexp(2 * log_weights_norm,
                                                     axis=0)), tf.int32)
                upper_beta = tf.where(
                    eff_size < effective_sample_size_threshold, new_beta,
                    upper_beta)
                lower_beta = tf.where(
                    eff_size < effective_sample_size_threshold, lower_beta,
                    new_beta)
                return new_beta, upper_beta, lower_beta, eff_size, log_weights

            def _cond_fn(new_beta, upper_beta, lower_beta, eff_size, *_):  # pylint: disable=unused-argument
                # TODO(junpenglao): revisit threshold below to be dtype specific.
                threshold = 1e-6
                return (tf.math.reduce_any(upper_beta - lower_beta > threshold)
                        & tf.math.reduce_any(
                            eff_size != effective_sample_size_threshold))

            (new_beta, upper_beta, lower_beta, eff_size,
             log_weights) = tf.while_loop(  # pylint: disable=unused-variable
                 cond=_cond_fn,
                 body=_body_fn,
                 loop_vars=(tf.zeros_like(inverse_temperature),
                            tf.fill(ps.shape(inverse_temperature),
                                    tf.constant(2, inverse_temperature.dtype)),
                            inverse_temperature,
                            tf.zeros_like(inverse_temperature, dtype=tf.int32),
                            tf.zeros_like(likelihood_diff)),
                 parallel_iterations=parallel_iterations)

            log_weights = tf.where(new_beta < 1., log_weights,
                                   (1. - inverse_temperature) *
                                   likelihood_diff)
            marginal_loglike_ = reduce_logmeanexp(
                (new_beta - inverse_temperature) * likelihood_log_prob, axis=0)
            new_inverse_temperature = tf.clip_by_value(new_beta, 0., 1.)

            return marginal_loglike_, new_inverse_temperature, log_weights

        def mutate(current_state, log_scalings, num_steps,
                   inverse_temperature):
            """Mutate the state using a Transition kernel."""
            with tf.name_scope('mutate_states'):
                scalings = tf.exp(log_scalings)
                kernel = make_kernel_fn(make_tempered_target_log_prob_fn(
                    prior_log_prob_fn, likelihood_log_prob_fn,
                    inverse_temperature),
                                        current_state,
                                        scalings,
                                        seed=seed_stream)
                pkr = kernel.bootstrap_results(current_state)
                kernel_log_accept_ratio, _ = gather_mh_like_result(pkr)

                def mutate_onestep(i, state, pkr, log_accept_prob_sum):
                    next_state, next_kernel_results = kernel.one_step(
                        state, pkr)
                    kernel_log_accept_ratio, _ = gather_mh_like_result(pkr)
                    log_accept_prob = tf.minimum(kernel_log_accept_ratio, 0.)
                    log_accept_prob_sum = log_add_exp(log_accept_prob_sum,
                                                      log_accept_prob)
                    return i + 1, next_state, next_kernel_results, log_accept_prob_sum

                (
                    _, next_state, next_kernel_results, log_accept_prob_sum
                ) = tf.while_loop(
                    cond=lambda i, *args: i < num_steps,
                    body=mutate_onestep,
                    loop_vars=(
                        tf.zeros([], dtype=tf.int32),
                        current_state,
                        pkr,
                        # we accumulate the acceptance probability in log space.
                        tf.fill(
                            ps.shape(kernel_log_accept_ratio),
                            tf.constant(-np.inf,
                                        kernel_log_accept_ratio.dtype))),
                    parallel_iterations=parallel_iterations)
                _, kernel_target_log_prob = gather_mh_like_result(
                    next_kernel_results)
                avg_log_accept_prob_per_particle = log_accept_prob_sum - tf.math.log(
                    tf.cast(num_steps + 1, log_accept_prob_sum.dtype))
                return (next_state, avg_log_accept_prob_per_particle,
                        kernel_target_log_prob)

        # One SMC step.
        def smc_body_fn(stage, state, smc_kernel_result):
            """Run one stage of SMC with constant temperature."""
            (new_marginal, new_inv_temperature,
             log_weights) = update_weights_temperature(
                 smc_kernel_result.inverse_temperature,
                 smc_kernel_result.particle_info.likelihood_log_prob)
            # TODO(b/152412213) Use a tf.scan to better collect debug info.
            if PRINT_DEBUG:
                tf.print(
                    'Stage:', stage, 'Beta:', new_inv_temperature, 'n_steps:',
                    smc_kernel_result.num_steps, 'accept:',
                    tf.exp(
                        reduce_logmeanexp(
                            smc_kernel_result.particle_info.log_accept_prob,
                            axis=0)), 'scaling:',
                    tf.exp(
                        reduce_logmeanexp(
                            smc_kernel_result.particle_info.log_scalings,
                            axis=0)))
            (resampled_state,
             resampled_particle_info), _ = resample_particle_and_info(
                 (state, smc_kernel_result.particle_info),
                 log_weights,
                 seed=seed_stream)
            next_num_steps, next_log_scalings = tuning_fn(
                smc_kernel_result.num_steps,
                resampled_particle_info.log_scalings,
                resampled_particle_info.log_accept_prob)
            # Skip tuning at stage 0.
            next_num_steps = tf.where(stage == 0, smc_kernel_result.num_steps,
                                      next_num_steps)
            next_log_scalings = tf.where(stage == 0,
                                         resampled_particle_info.log_scalings,
                                         next_log_scalings)
            next_num_steps = tf.clip_by_value(next_num_steps, min_num_steps,
                                              max_num_steps)

            next_state, log_accept_prob, tempered_log_prob = mutate(
                resampled_state, next_log_scalings, next_num_steps,
                new_inv_temperature)
            next_pkr = SMCResults(
                num_steps=next_num_steps,
                inverse_temperature=new_inv_temperature,
                log_marginal_likelihood=(
                    new_marginal + smc_kernel_result.log_marginal_likelihood),
                particle_info=ParticleInfo(
                    log_accept_prob=log_accept_prob,
                    log_scalings=next_log_scalings,
                    tempered_log_prob=tempered_log_prob,
                    likelihood_log_prob=likelihood_log_prob_fn(*next_state),
                ))
            return stage + 1, next_state, next_pkr

        (n_stage, final_state, final_kernel_results) = tf.while_loop(
            cond=lambda i, state, pkr: (  # pylint: disable=g-long-lambda
                (i < max_stage) & tf.reduce_any(pkr.inverse_temperature < 1.)),
            body=smc_body_fn,
            loop_vars=(tf.zeros([],
                                dtype=tf.int32), current_state, current_pkr),
            parallel_iterations=parallel_iterations)
        if unwrap_state_list:
            final_state = final_state[0]
        return n_stage, final_state, final_kernel_results
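A hedged usage sketch for the function above; the conjugate-normal setup, particle count, and observation value are illustrative assumptions rather than part of the original example:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

num_particles = 1000
# Particles drawn from the prior form the initial state.
init_state = tfd.Normal(0., 1.).sample(num_particles, seed=42)

def prior_log_prob_fn(x):
    return tfd.Normal(0., 1.).log_prob(x)

def likelihood_log_prob_fn(x):
    # Log-likelihood of a single observation y = 3. under Normal(x, 0.5).
    return tfd.Normal(x, 0.5).log_prob(3.)

n_stage, posterior_samples, _ = sample_sequential_monte_carlo(
    prior_log_prob_fn, likelihood_log_prob_fn, init_state)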
Example #10
    def loop_tree_doubling(self, step_size, momentum_state_memory,
                           current_step_meta_info, iter_, initial_step_state,
                           initial_step_metastate):
        """Main loop for tree doubling."""
        with tf.name_scope('loop_tree_doubling'):
            batch_shape = prefer_static.shape(
                current_step_meta_info.init_energy)
            direction = tf.cast(tf.random.uniform(shape=batch_shape,
                                                  minval=0,
                                                  maxval=2,
                                                  dtype=tf.int32,
                                                  seed=self._seed_stream()),
                                dtype=tf.bool)

            tree_start_states = tf.nest.map_structure(
                lambda v: tf.where(  # pylint: disable=g-long-lambda
                    _rightmost_expand_to_rank(
                        direction, prefer_static.rank(v[1])), v[1], v[0]),
                initial_step_state)

            directions_expanded = [
                _rightmost_expand_to_rank(direction, prefer_static.rank(state))
                for state in tree_start_states.state
            ]

            integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
                self.target_log_prob_fn,
                step_sizes=[
                    tf.where(d, ss, -ss)
                    for d, ss in zip(directions_expanded, step_size)
                ],
                num_steps=self.unrolled_leapfrog_steps)

            [
                candidate_tree_state, tree_final_states, final_not_divergence,
                continue_tree_final, energy_diff_tree_sum,
                momentum_tree_cumsum, leapfrogs_taken
            ] = self._build_sub_tree(
                directions_expanded,
                integrator,
                current_step_meta_info,
                # num_steps_at_this_depth = 2**iter_ = 1 << iter_
                tf.bitwise.left_shift(1, iter_),
                tree_start_states,
                initial_step_metastate.continue_tree,
                initial_step_metastate.not_divergence,
                momentum_state_memory)

            last_candidate_state = initial_step_metastate.candidate_state
            tree_weight = candidate_tree_state.weight
            if MULTINOMIAL_SAMPLE:
                weight_sum = log_add_exp(tree_weight,
                                         last_candidate_state.weight)
                log_accept_thresh = tree_weight - last_candidate_state.weight
            else:
                weight_sum = tree_weight + last_candidate_state.weight
                log_accept_thresh = tf.math.log(
                    tf.cast(tree_weight, tf.float32) /
                    tf.cast(last_candidate_state.weight, tf.float32))
            log_accept_thresh = tf.where(tf.math.is_nan(log_accept_thresh),
                                         tf.zeros([], log_accept_thresh.dtype),
                                         log_accept_thresh)
            u = tf.math.log1p(-tf.random.uniform(shape=batch_shape,
                                                 dtype=log_accept_thresh.dtype,
                                                 seed=self._seed_stream()))
            is_sample_accepted = u <= log_accept_thresh

            choose_new_state = is_sample_accepted & continue_tree_final

            new_candidate_state = TreeDoublingStateCandidate(
                state=[
                    tf.where(  # pylint: disable=g-complex-comprehension
                        _rightmost_expand_to_rank(choose_new_state,
                                                  prefer_static.rank(s0)), s0,
                        s1) for s0, s1 in zip(candidate_tree_state.state,
                                              last_candidate_state.state)
                ],
                target=tf.where(
                    _rightmost_expand_to_rank(
                        choose_new_state,
                        prefer_static.rank(candidate_tree_state.target)),
                    candidate_tree_state.target, last_candidate_state.target),
                target_grad_parts=[
                    tf.where(  # pylint: disable=g-complex-comprehension
                        _rightmost_expand_to_rank(choose_new_state,
                                                  prefer_static.rank(grad0)),
                        grad0, grad1) for grad0, grad1 in zip(
                            candidate_tree_state.target_grad_parts,
                            last_candidate_state.target_grad_parts)
                ],
                energy=tf.where(
                    _rightmost_expand_to_rank(
                        choose_new_state,
                        prefer_static.rank(candidate_tree_state.target)),
                    candidate_tree_state.energy, last_candidate_state.energy),
                weight=weight_sum)

            # Update the left/right information of the trajectory and check for a
            # trajectory-level U-turn.
            tree_otherend_states = tf.nest.map_structure(
                lambda v: tf.where(  # pylint: disable=g-long-lambda
                    _rightmost_expand_to_rank(
                        direction, prefer_static.rank(v[1])), v[0], v[1]),
                initial_step_state)

            new_step_state = tf.nest.pack_sequence_as(
                initial_step_state,
                [
                    tf.stack(
                        [  # pylint: disable=g-complex-comprehension
                            tf.where(
                                _rightmost_expand_to_rank(
                                    direction, prefer_static.rank(l)), r, l),
                            tf.where(
                                _rightmost_expand_to_rank(
                                    direction, prefer_static.rank(l)), l, r),
                        ],
                        axis=0)
                    for l, r in zip(tf.nest.flatten(tree_final_states),
                                    tf.nest.flatten(tree_otherend_states))
                ])

            if GENERALIZED_UTURN:
                state_diff = momentum_tree_cumsum
            else:
                state_diff = [s[1] - s[0] for s in new_step_state.state]

            no_u_turns_trajectory = has_not_u_turn(
                state_diff, [m[0] for m in new_step_state.momentum],
                [m[1] for m in new_step_state.momentum],
                log_prob_rank=len(batch_shape))

            new_step_metastate = TreeDoublingMetaState(
                candidate_state=new_candidate_state,
                is_accepted=choose_new_state
                | initial_step_metastate.is_accepted,
                energy_diff_sum=(energy_diff_tree_sum +
                                 initial_step_metastate.energy_diff_sum),
                continue_tree=continue_tree_final & no_u_turns_trajectory,
                not_divergence=final_not_divergence,
                leapfrog_count=(initial_step_metastate.leapfrog_count +
                                leapfrogs_taken))

            return iter_ + 1, new_step_state, new_step_metastate
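`_rightmost_expand_to_rank` is used above but not shown in this snippet. A minimal sketch of the assumed behavior (append singleton dimensions on the right until the requested rank, so a batch-shaped mask can broadcast inside `tf.where`):

def _rightmost_expand_to_rank(tensor, new_rank):
    # Pad the shape with trailing 1s until the tensor has `new_rank` dims.
    return tf.reshape(
        tensor,
        shape=prefer_static.pad(
            prefer_static.shape(tensor),
            paddings=[[0, new_rank - prefer_static.rank(tensor)]],
            constant_values=1))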
Example #11
 def _inverse_log_det_jacobian(self, y):
     return tf.zeros([], dtype=y.dtype)
Example #12
    def one_step(self, current_state, previous_kernel_results):
        with tf.name_scope(self.name + '.one_step'):
            unwrap_state_list = not tf.nest.is_nested(current_state)
            if unwrap_state_list:
                current_state = [current_state]

            current_target_log_prob = previous_kernel_results.target_log_prob
            [init_momentum, init_energy, log_slice_sample
             ] = self._start_trajectory_batched(current_state,
                                                current_target_log_prob)

            def _copy(v):
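                # Broadcast-multiply `v` by a ones Tensor of shape
                # [2, 1, ..., 1] (with rank(v) trailing singletons), duplicating
                # the state along a new leading axis that holds the two ends of
                # the trajectory.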
                return v * prefer_static.ones(prefer_static.pad(
                    [2],
                    paddings=[[0, prefer_static.rank(v)]],
                    constant_values=1),
                                              dtype=v.dtype)

            initial_state = TreeDoublingState(
                momentum=init_momentum,
                state=current_state,
                target=current_target_log_prob,
                target_grad_parts=previous_kernel_results.grads_target_log_prob
            )
            initial_step_state = tf.nest.map_structure(_copy, initial_state)

            if MULTINOMIAL_SAMPLE:
                init_weight = tf.zeros_like(init_energy)
            else:
                init_weight = tf.ones_like(init_energy, dtype=TREE_COUNT_DTYPE)

            candidate_state = TreeDoublingStateCandidate(
                state=current_state,
                target=current_target_log_prob,
                target_grad_parts=previous_kernel_results.
                grads_target_log_prob,
                energy=init_energy,
                weight=init_weight)

            initial_step_metastate = TreeDoublingMetaState(
                candidate_state=candidate_state,
                is_accepted=tf.zeros_like(init_energy, dtype=tf.bool),
                energy_diff_sum=tf.zeros_like(init_energy),
                leapfrog_count=tf.zeros_like(init_energy,
                                             dtype=TREE_COUNT_DTYPE),
                continue_tree=tf.ones_like(init_energy, dtype=tf.bool),
                not_divergence=tf.ones_like(init_energy, dtype=tf.bool))

            # Convert the write/read instruction into TensorArray so that it is
            # compatible with XLA.
            write_instruction = tf.TensorArray(
                TREE_COUNT_DTYPE,
                size=2**(self.max_tree_depth - 1),
                clear_after_read=False).unstack(self._write_instruction)
            read_instruction = tf.TensorArray(
                tf.int32,
                size=2**(self.max_tree_depth - 1),
                clear_after_read=False).unstack(self._read_instruction)

            current_step_meta_info = OneStepMetaInfo(
                log_slice_sample=log_slice_sample,
                init_energy=init_energy,
                write_instruction=write_instruction,
                read_instruction=read_instruction)

            _, _, new_step_metastate = tf.while_loop(
                cond=lambda iter_, state, metastate: (  # pylint: disable=g-long-lambda
                    ((iter_ < self.max_tree_depth) & tf.reduce_any(
                        metastate.continue_tree))),
                body=lambda iter_, state, metastate: self.loop_tree_doubling(  # pylint: disable=g-long-lambda
                    previous_kernel_results.step_size, previous_kernel_results.
                    momentum_state_memory, current_step_meta_info, iter_,
                    state, metastate),
                loop_vars=(tf.zeros([], dtype=tf.int32, name='iter'),
                           initial_step_state, initial_step_metastate),
                parallel_iterations=TF_WHILE_PARALLEL_ITERATIONS,
            )

            kernel_results = NUTSKernelResults(
                target_log_prob=new_step_metastate.candidate_state.target,
                grads_target_log_prob=(
                    new_step_metastate.candidate_state.target_grad_parts),
                momentum_state_memory=previous_kernel_results.
                momentum_state_memory,
                step_size=previous_kernel_results.step_size,
                log_accept_ratio=tf.math.log(
                    new_step_metastate.energy_diff_sum /
                    tf.cast(new_step_metastate.leapfrog_count,
                            dtype=new_step_metastate.energy_diff_sum.dtype)),
                # TODO(junpenglao): return non-cumulated leapfrogs_taken once
                # benchmarking is done.
                leapfrogs_taken=(previous_kernel_results.leapfrogs_taken +
                                 new_step_metastate.leapfrog_count *
                                 self.unrolled_leapfrog_steps),
                is_accepted=new_step_metastate.is_accepted,
                reach_max_depth=new_step_metastate.continue_tree,
                has_divergence=~new_step_metastate.not_divergence,
                energy=new_step_metastate.candidate_state.energy)

            result_state = new_step_metastate.candidate_state.state
            if unwrap_state_list:
                result_state = result_state[0]

            return result_state, kernel_results
Example #13
def transport_implicit_gradients(derivative_cost, transport_matrix, eps, b,
                                 d_p):
    """Application of the transpose of the Jacobians dP/dx and dP/db.

  This is applied to a perturbation of the size of the transport matrix.
  Required to back-propagate through Sinkhorn's output.

  Args:
   derivative_cost: the derivative of the cost function.
   transport_matrix: the obtained transport matrix tensor.
   eps: the value of the entropic regularization parameter.
   b: the target weights.
   d_p: the perturbation of the transport matrix.

  Returns:
   A list of two tensors that correspond to the application of the transposes
   of dP/dx and dP/db to d_p.
  """
    batch_size = tf.shape(b)[0]
    m = tf.shape(b)[1]
    invmargin1 = tf.math.reciprocal(tf.reduce_sum(transport_matrix, axis=2))
    m1 = invmargin1[:, 1:, tf.newaxis] * transport_matrix[:, 1:, :]
    m1 = tf.concat(
        [tf.zeros([tf.shape(m1)[0], 1, tf.shape(m1)[2]]), m1], axis=1)

    invmargin2 = tf.math.reciprocal(tf.reduce_sum(transport_matrix, axis=1))
    m2 = invmargin2[:, :, tf.newaxis] * tf.transpose(transport_matrix,
                                                     [0, 2, 1])
    eye_m = tf.eye(m, batch_shape=[batch_size])
    schur = eye_m - tf.linalg.matmul(m2, m1)

    def jac_b_p_transpose(d_p):
        """Transposed of the jacobian of the transport w.r.t the target weights."""
        d_p_p = d_p * transport_matrix
        u_f = tf.reduce_sum(d_p_p, axis=2) / eps
        u_g = tf.reduce_sum(d_p_p, axis=1) / eps

        m1_tranpose_u_f = tf.linalg.matvec(m1, u_f, transpose_a=True)
        to_invert = tf.concat(
            [m1_tranpose_u_f[:, :, tf.newaxis], u_g[:, :, tf.newaxis]], axis=2)
        inverses = tf.linalg.solve(tf.transpose(schur, [0, 2, 1]), to_invert)
        inv_m1_tranpose_u_f, inv_u_g = inverses[:, :, 0], inverses[:, :, 1]
        jac_2 = -inv_m1_tranpose_u_f + inv_u_g
        return eps * jac_2 / b

    def jac_x_p_transpose(d_p):
        """Transposed of the jacobian of the transport w.r.t the inputs."""
        d_p_p = d_p * transport_matrix
        c_x = -tf.reduce_sum(derivative_cost * d_p_p, axis=2) / eps
        u_f = tf.math.reduce_sum(d_p_p, axis=2) / eps
        u_g = tf.math.reduce_sum(d_p_p, axis=1) / eps
        m1_tranpose_u_f = tf.linalg.matvec(m1, u_f, transpose_a=True)
        to_invert = tf.concat(
            [m1_tranpose_u_f[:, :, tf.newaxis], u_g[:, :, tf.newaxis]], axis=2)
        inverses = tf.linalg.solve(tf.transpose(schur, [0, 2, 1]), to_invert)
        inv_m1_tranpose_u_f, inv_u_g = inverses[:, :, 0], inverses[:, :, 1]
        jac_1 = u_f + tf.linalg.matvec(
            m2, inv_m1_tranpose_u_f - inv_u_g, transpose_a=True)
        jac_2 = -inv_m1_tranpose_u_f + inv_u_g
        jac_1 = jac_1 * tf.reduce_sum(m1 * derivative_cost, axis=2)
        jac_2 = tf.linalg.matvec(
            tf.transpose(m2, [0, 2, 1]) * derivative_cost, jac_2)
        return c_x + jac_1 + jac_2

    return [jac_x_p_transpose(d_p), jac_b_p_transpose(d_p)]
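A hedged toy-shape sketch exercising the function above; the shapes and the random/uniform inputs are assumptions made purely for illustration:

batch, n, m = 2, 4, 5
eps = 0.1
# transport_matrix, derivative_cost and d_p share shape [batch, n, m];
# b holds strictly positive target weights of shape [batch, m].
transport_matrix = tf.nn.softmax(tf.random.normal([batch, n, m]), axis=-1) / n
derivative_cost = tf.random.normal([batch, n, m])
b = tf.fill([batch, m], 1. / m)
d_p = tf.random.normal([batch, n, m])
d_x, d_b = transport_implicit_gradients(
    derivative_cost, transport_matrix, eps, b, d_p)
# Expected: d_x.shape == [batch, n] and d_b.shape == [batch, m].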
Example #14
 def testNonEmptyConstantTensor(self):
     x = tf.zeros((2, 3, 4))
     value = distribution_util.prefer_static_value(x)
     self.assertIsInstance(value, np.ndarray)
     self.assertAllEqual(np.zeros((2, 3, 4)), value)
Example #15
 def _batch_of_zeros_with_rightmost_singletons(n_singletons):
     """Return Tensor of zeros with some singletons on the rightmost dims."""
     ones = tf.ones(shape=[n_singletons], dtype=tf.int32)
     return tf.zeros(shape=tf.concat([batch_shape, ones], axis=0),
                     dtype=dtype)
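 # A hedged illustration (assuming batch_shape = [2, 3] and dtype = tf.float32
 # in the enclosing scope, neither of which is shown in this snippet):
 #   _batch_of_zeros_with_rightmost_singletons(n_singletons=2).shape
 #   ==> TensorShape([2, 3, 1, 1])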
Example #16
 def _entropy(self):
     return tf.zeros(self.batch_shape_tensor(), dtype=self.dtype)
Example #17
def _batch_interp_with_gather_nd(x, x_ref_min, x_ref_max, y_ref, nd,
                                 fill_value, batch_dims):
    """N-D interpolation that works with leading batch dims."""
    dtype = x.dtype

    # In this function,
    # x.shape = [A1, ..., An, D, nd], where n = batch_dims
    # and
    # y_ref.shape = [A1, ..., An, C1, C2,..., Cnd, B1,...,BM]
    # y_ref[A1, ..., An, i1,...,ind] is a shape [B1,...,BM] Tensor with the value
    # at index [i1,...,ind] in the interpolation table.
    # x_ref_min and x_ref_max have shapes [A1, ..., An, nd].

    # ny[k] is number of y reference points in interp dim k.
    ny = tf.cast(tf.shape(y_ref)[batch_dims:batch_dims + nd], dtype)

    # Map [x_ref_min, x_ref_max] to [0, ny - 1].
    # This is the (fractional) index of x.
    # x_idx_unclipped[A1, ..., An, d, k] is the fractional index into dim k of
    # interpolation table for the dth x value.
    x_ref_min_expanded = tf.expand_dims(x_ref_min, axis=-2)
    x_ref_max_expanded = tf.expand_dims(x_ref_max, axis=-2)
    x_idx_unclipped = (ny - 1) * (x - x_ref_min_expanded) / (
        x_ref_max_expanded - x_ref_min_expanded)

    # Wherever x is NaN, x_idx_unclipped will be NaN as well.
    # Keep track of the nan indices here (so we can impute NaN later).
    # Also eliminate any NaN indices, since int32 indices cannot represent NaN.
    nan_idx = tf.math.is_nan(x_idx_unclipped)
    x_idx_unclipped = tf.where(nan_idx, 0., x_idx_unclipped)

    # x_idx.shape = [A1, ..., An, D, nd]
    x_idx = tf.clip_by_value(x_idx_unclipped, tf.zeros((), dtype=dtype),
                             ny - 1)

    # Get the index above and below x_idx.
    # Naively we could set idx_below = floor(x_idx), idx_above = ceil(x_idx),
    # however, this results in idx_below == idx_above whenever x is on a grid.
    # This in turn results in y_ref_below == y_ref_above, and then the gradient
    # at this point is zero.  So here we 'jitter' one of idx_below, idx_above,
    # so that they are at different values.  This jittering does not affect the
    # interpolated value, but does make the gradient nonzero (unless of course
    # the y_ref values are the same).
    idx_below = tf.floor(x_idx)
    idx_above = tf.minimum(idx_below + 1, ny - 1)
    idx_below = tf.maximum(idx_above - 1, 0)

    # These are the values of y_ref corresponding to above/below indices.
    # idx_below_int32.shape = x.shape[:-1] + [nd]
    idx_below_int32 = tf.cast(idx_below, dtype=tf.int32)
    idx_above_int32 = tf.cast(idx_above, dtype=tf.int32)

    # idx_below_list is a length nd list of shape x.shape[:-1] int32 tensors.
    idx_below_list = tf.unstack(idx_below_int32, axis=-1)
    idx_above_list = tf.unstack(idx_above_int32, axis=-1)

    # Use t to get a convex combination of the below/above values.
    # t.shape = [A1, ..., An, D, nd]
    t = x_idx - idx_below

    # x, and tensors shaped like x, need to be added to, and selected with
    # (using tf.where) the output y.  This requires appending singletons.
    def _expand_x_fn(tensor):
        # Reshape tensor to tensor.shape + [1] * M.
        extended_shape = tf.concat([
            tf.shape(tensor),
            tf.ones_like(tf.shape(y_ref)[batch_dims + nd:])
        ],
                                   axis=0)
        return tf.reshape(tensor, extended_shape)

    # Now, t.shape = [A1, ..., An, D, nd] + [1] * (rank(y_ref) - nd - batch_dims)
    t = _expand_x_fn(t)
    s = 1 - t

    # Re-insert NaN wherever x was NaN.
    nan_idx = _expand_x_fn(nan_idx)
    t = tf.where(nan_idx, tf.constant(np.nan, dtype), t)

    terms = []
    # Our work above has located x's fractional index inside a cube of above/below
    # indices. The distance to the below indices is t, and to the above indices
    # is s.
    # Drawing lines from x to the cube walls, we get 2**nd smaller cubes. Each
    # term in the result is a product of a reference point, gathered from y_ref,
    # multiplied by a volume.  The volume is that of the cube opposite to the
    # reference point.  E.g. if the reference point is below x in every axis, the
    # volume is that of the cube with corner above x in every axis, s[0]*...*s[nd-1].
    # We could probably do this with one massive gather, but that would be very
    # unreadable and un-debuggable.  It also would create a large Tensor.
    for zero_ones_list in _binary_count(nd):
        gather_from_y_ref_idx = []
        opposite_volume_t_idx = []
        opposite_volume_s_idx = []
        for k, zero_or_one in enumerate(zero_ones_list):
            if zero_or_one == 0:
                # If the kth iterate has zero_or_one = 0,
                # Will gather from the 'below' reference point along axis k.
                gather_from_y_ref_idx.append(idx_below_list[k])
                # Now append the index to gather for computing opposite_volume.
                # This could be done by initializing opposite_volume to 1, then here:
                #  opposite_volume *= tf.gather(s, indices=k, axis=tf.rank(x) - 1)
                # but that puts a gather in the 'inner loop.'  Better to append the
                # index and do one larger gather down below.
                opposite_volume_s_idx.append(k)
            else:
                gather_from_y_ref_idx.append(idx_above_list[k])
                # Append an index to gather, having the same effect as
                #   opposite_volume *= tf.gather(t, indices=k, axis=tf.rank(x) - 1)
                opposite_volume_t_idx.append(k)

        # Compute opposite_volume (volume of cube opposite the ref point):
        # Recall t.shape = s.shape = [D, nd] + [1, ..., 1]
        # Gather from t and s along the 'nd' axis, which is rank(x) - 1.
        ov_axis = tf.rank(x) - 1
        opposite_volume = (tf.reduce_prod(
            tf.gather(t,
                      indices=tf.cast(opposite_volume_t_idx, dtype=tf.int32),
                      axis=ov_axis),
            axis=ov_axis) * tf.reduce_prod(tf.gather(
                s,
                indices=tf.cast(opposite_volume_s_idx, dtype=tf.int32),
                axis=ov_axis),
                                           axis=ov_axis))  # pyformat: disable

        y_ref_pt = tf.gather_nd(y_ref,
                                tf.stack(gather_from_y_ref_idx, axis=-1),
                                batch_dims=batch_dims)

        terms.append(y_ref_pt * opposite_volume)

    y = tf.math.add_n(terms)

    if tf.debugging.is_numeric_tensor(fill_value):
        # Recall x_idx_unclipped.shape = [D, nd],
        # so here we check if it was out of bounds in any of the nd dims.
        # Thus, oob_idx.shape = [D].
        oob_idx = tf.reduce_any(
            (x_idx_unclipped < 0) | (x_idx_unclipped > ny - 1), axis=-1)

        # Now, y.shape = [D, B1,...,BM], so we'll have to broadcast oob_idx.

        oob_idx = _expand_x_fn(oob_idx)  # Shape [D, 1,...,1]
        oob_idx |= tf.fill(tf.shape(y), False)
        y = tf.where(oob_idx, fill_value, y)
    return y
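For context, `_batch_interp_with_gather_nd` is the workhorse behind TFP's regular-grid N-D interpolation. A hedged usage sketch of the public wrapper (assuming `tfp.math.batch_interp_regular_nd_grid` is available in the installed TensorFlow Probability), interpolating a 1-D table treated as the nd=1 case:

import tensorflow as tf
import tensorflow_probability as tfp

# A table of 200 reference values of exp on [0, 10].
y_ref = tf.exp(tf.linspace(start=0., stop=10., num=200))

# Three query points; each point is an nd=1 coordinate vector.
y = tfp.math.batch_interp_regular_nd_grid(
    x=[[6.0], [0.75], [2.0]],
    x_ref_min=[0.],
    x_ref_max=[10.],
    y_ref=y_ref,
    axis=0)
# y is approximately [exp(6.0), exp(0.75), exp(2.0)].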
Example #18
    def __init__(self,
                 num_timesteps,
                 coefficients,
                 level_scale,
                 initial_state_prior,
                 observation_noise_scale=0.,
                 name=None,
                 **linear_gaussian_ssm_kwargs):
        """Build a state space model implementing an autoregressive process.

    Args:
      num_timesteps: Scalar `int` `Tensor` number of timesteps to model
        with this distribution.
      coefficients: `float` `Tensor` of shape `concat(batch_shape, [order])`
        defining  the autoregressive coefficients. The coefficients are defined
        backwards in time: `coefficients[0] * level[t] + coefficients[1] *
        level[t-1] + ... + coefficients[order-1] * level[t-order+1]`.
      level_scale: Scalar (any additional dimensions are treated as batch
        dimensions) `float` `Tensor` indicating the standard deviation of the
        transition noise at each step.
      initial_state_prior: instance of `tfd.MultivariateNormal`
        representing the prior distribution on latent states.  Must have
        event shape `[order]`.
      observation_noise_scale: Scalar (any additional dimensions are
        treated as batch dimensions) `float` `Tensor` indicating the standard
        deviation of the observation noise.
        Default value: 0.
      name: Python `str` name prefixed to ops created by this class.
        Default value: "AutoregressiveStateSpaceModel".
      **linear_gaussian_ssm_kwargs: Optional additional keyword arguments to
        to the base `tfd.LinearGaussianStateSpaceModel` constructor.
    """
        parameters = dict(locals())
        parameters.update(linear_gaussian_ssm_kwargs)
        del parameters['linear_gaussian_ssm_kwargs']
        with tf.name_scope(name or 'AutoregressiveStateSpaceModel') as name:

            # The initial state prior determines the dtype of sampled values.
            # Other model parameters must have the same dtype.
            dtype = initial_state_prior.dtype

            coefficients = tf.convert_to_tensor(value=coefficients,
                                                name='coefficients',
                                                dtype=dtype)
            level_scale = tf.convert_to_tensor(value=level_scale,
                                               name='level_scale',
                                               dtype=dtype)
            observation_noise_scale = tf.convert_to_tensor(
                value=observation_noise_scale,
                name='observation_noise_scale',
                dtype=dtype)

            order = tf.compat.dimension_value(coefficients.shape[-1])
            if order is None:
                raise ValueError(
                    'Autoregressive coefficients must have static shape.')

            self._order = order
            self._coefficients = coefficients
            self._level_scale = level_scale

            super(AutoregressiveStateSpaceModel, self).__init__(
                num_timesteps=num_timesteps,
                transition_matrix=make_ar_transition_matrix(coefficients),
                transition_noise=tfd.MultivariateNormalDiag(
                    scale_diag=tf.stack([level_scale] +
                                        [tf.zeros_like(level_scale)] *
                                        (self.order - 1),
                                        axis=-1)),
                observation_matrix=tf.concat([
                    tf.ones([1, 1], dtype=dtype),
                    tf.zeros([1, self.order - 1], dtype=dtype)
                ],
                                             axis=-1),
                observation_noise=tfd.MultivariateNormalDiag(
                    scale_diag=observation_noise_scale[..., tf.newaxis]),
                initial_state_prior=initial_state_prior,
                name=name,
                **linear_gaussian_ssm_kwargs)
            self._parameters = parameters
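A hedged sketch of constructing this state space model through the public `tfp.sts.AutoregressiveStateSpaceModel` constructor; the numbers below are illustrative:

import tensorflow_probability as tfp
tfd = tfp.distributions

ar2_ssm = tfp.sts.AutoregressiveStateSpaceModel(
    num_timesteps=50,
    coefficients=[0.6, -0.2],   # order 2 is inferred from the last dimension
    level_scale=0.5,
    initial_state_prior=tfd.MultivariateNormalDiag(scale_diag=[1., 1.]))

series = ar2_ssm.sample(seed=42)    # shape [50, 1]
log_lik = ar2_ssm.log_prob(series)  # scalar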
Example #19
 def normal_cdf(r):
     r = tf.convert_to_tensor(value=r, name='r')
     n = tfd.Normal(loc=tf.zeros([], r.dtype.base_dtype),
                    scale=tf.ones([], r.dtype.base_dtype))
     return n.cdf(r)
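A quick sanity check of the helper above; the expected outputs are standard normal CDF values:

import tensorflow as tf

normal_cdf(tf.constant([0., 1.96]))  # ==> approximately [0.5, 0.975]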
Example #20
 def params_model_fn(out_channels, size, in_channels, dtype):
     yield Root(
         tfd.LogNormal(tf.zeros(
             list(size) + [in_channels, out_channels], dtype),
                       1.,
                       name='kernel'))
Example #21
def reduce_audio_in_batch(tensor, hparams=None, is_training=True):
    instrument_count = hparams.timbre_training_max_instruments
    note_croppping_list = []
    instrument_family_list = []
    samples_list = []
    max_length = 0
    for i in range(instrument_count):
        pitch = tensor['pitch'][i]
        # Move the audio so there are different attack times.
        start_idx = tf.random.uniform((),
                                      minval=0,
                                      maxval=hparams.timbre_max_start_offset,
                                      dtype='int64')
        samples = K.concatenate(
            [tf.zeros(start_idx),
             tf.sparse.to_dense(tensor['audio'])[i]])

        end_idx = (
            start_idx +
            tf.py_function(_get_approx_note_length,
                           [tf.sparse.to_dense(tensor['audio'])[i]], tf.int64))
        if hparams.timbre_max_len and end_idx > hparams.timbre_max_len:
            samples = tf.slice(samples,
                               begin=[0],
                               size=[hparams.timbre_max_len])
            end_idx = hparams.timbre_max_len
        if len(samples) > max_length:
            max_length = len(samples)

        samples_list.append(samples)

        instrument_family = tensor['instrument_family'][i]
        note_croppping_list.append(
            timbre_dataset_util.NoteCropping(pitch=pitch,
                                             start_idx=start_idx,
                                             end_idx=end_idx))
        instrument_family_list.append(
            tf.one_hot(tf.cast(instrument_family, tf.int32),
                       hparams.timbre_num_classes))

    # Pad the end of the shorter audio clips.
    samples_list = list(
        map(lambda x: tf.pad(x, [[0, max_length - len(x)]]), samples_list))

    combined_samples = (
        tf.reduce_sum(tf.convert_to_tensor(samples_list), axis=0) /
        instrument_count)

    # Ensure all audios in batches are the same length.
    if hparams.timbre_max_len:
        pad_length = hparams.timbre_max_len
    else:
        pad_length = hparams.timbre_max_start_offset + 5 * hparams.sample_rate
    combined_samples = tf.pad(
        combined_samples, [[0, pad_length - tf.shape(combined_samples)[0]]])
    note_croppings = tf.convert_to_tensor(note_croppping_list, dtype=tf.int32)
    instrument_families = tf.convert_to_tensor(instrument_family_list,
                                               dtype=tf.int32)

    wav_data = tf.py_function(
        lambda x: audio_io.samples_to_wav_data(
            x.numpy(), sample_rate=hparams.sample_rate), [combined_samples],
        tf.string)

    return dict(
        audio=wav_data,
        note_croppings=note_croppings,
        instrument_families=instrument_families,
    )
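The per-clip right-padding above is the main trick that makes variable-length clips batchable. A minimal, framework-only sketch of that step (no Magenta hparams assumed; the clip lengths are made up):

import tensorflow as tf

clips = [tf.random.normal([30]), tf.random.normal([50]), tf.random.normal([40])]
max_length = max(int(c.shape[0]) for c in clips)

# Right-pad each clip with zeros so all clips share the same length.
padded = [tf.pad(c, [[0, max_length - int(c.shape[0])]]) for c in clips]

# Mix the padded clips, mirroring combined_samples above.
combined = tf.reduce_sum(tf.stack(padded), axis=0) / len(clips)  # shape [50]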
Example #22
def soft_multivariate_quantiles(x, quantiles, quantile_width=None, **kwargs):
    """Computes soft multivariate quantiles via optimal transport.

  Transports multivariate input values in x onto 2^d + 1 weighted points: the
  hypercube vertices {0,1}^d plus the center point [0.5, ..., 0.5]. Target
  weights are adjusted so that the values in x transported to the center point
  are the ones concentrating around the quantile of interest.

  Args:
   x: Tensor<float> of shape [batch, N, d]
   quantiles: Tensor<float> of shape [r, d], r targeted quantiles of dimension d
   quantile_width: (float) mass given to the bucket that is meant to attract
     points whose values concentrate around the desired quantile value. A
     bigger width means the soft quantile is allowed to be a mixture of more
     points further away from the quantile. If None, the width is set to 2/n,
     where n is the number of input points (the size of the second axis of x).
   **kwargs: see sinkhorn.autodiff_sinkhorn for possible extra parameters.

  Returns:
    A Tensor<float> of shape [batch, r, d] holding the multivariate quantiles
    for each batch.

  """
    quantiles = tf.constant(quantiles, tf.float32)
    batch_size = x.shape[0]
    n = tf.cast(x.shape[1], tf.float32)
    d = x.shape[2]
    if quantile_width is None:
        quantile_width = 2 / n
    num_quantiles = tf.shape(quantiles)[0]
    hypercube_vertices = tf.constant(
        list(itertools.product([-1, 1], repeat=d)), tf.float32)
    # Weights attached to vertices for each quantile; shape [num_quantiles, 2^d].
    weights = quantiles[:,
                        tf.newaxis, :]**(0.5 *
                                         (1 - hypercube_vertices))[tf.newaxis,
                                                                   Ellipsis]
    weights *= (1 - quantiles)[:, tf.newaxis, :]**(
        0.5 * (1 + hypercube_vertices))[tf.newaxis, Ellipsis]

    weights = (1 - quantile_width) * tf.reduce_prod(weights, axis=2)
    # adding weights for quantile itself (in position 0).
    weights = tf.concat((quantile_width * tf.ones(
        (num_quantiles, 1)), weights),
                        axis=1)
    # Augment and reformat to shape [batch_size, 2^d + 1, num_quantiles].
    weights = tf.reshape(tf.tile(tf.transpose(weights), [batch_size, 1]),
                         [batch_size, 2**d + 1, num_quantiles])
    # set target locations, by adding the point at 0 that will absorb the quantile
    # augment it with batch_size
    y = tf.concat((tf.zeros((1, d), dtype=tf.float32), hypercube_vertices),
                  axis=0)
    y = tf.reshape(tf.tile(y, [batch_size, 1]), [batch_size, 2**d + 1, d])
    # center x
    x_mean = tf.reduce_mean(x, axis=1)
    x = x - x_mean[:, tf.newaxis, :]
    transports = sinkhorn.autodiff_sinkhorn(
        x, y,
        tf.ones([batch_size, n, num_quantiles], dtype=tf.float32) / n, weights,
        **kwargs)

    # recover convex combinations resulting from transporting to the central
    # point, in all batches and quantile variations.
    transports = 1 / quantile_width * tf.reshape(transports[:, :, 0, :],
                                                 [batch_size, n, -1])
    # apply these convex combinations to data points + recenter.
    all_soft_quantiles = tf.reduce_sum(
        transports[:, :, :, tf.newaxis] * x[:, :, tf.newaxis, :],
        axis=1) + x_mean[:, tf.newaxis, :]
    # reshape those quantiles after having applied convex combinations.
    return tf.reshape(all_soft_quantiles, [batch_size, num_quantiles, d])
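A hedged usage sketch of the function above (it still relies on the `sinkhorn` module and `itertools` import from the surrounding file); asking for the coordinatewise median of standard normal data should return values near zero:

import tensorflow as tf

x = tf.random.normal([2, 100, 2])   # two batches of 100 two-dimensional points
medians = soft_multivariate_quantiles(
    x, quantiles=[[0.5, 0.5]], quantile_width=0.05)
# medians.shape == [2, 1, 2], with entries close to 0 for this input.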
Example #23
    def one_step(self, current_state, previous_kernel_results, seed=None):
        seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.
        start_trajectory_seed, loop_seed = samplers.split_seed(seed)

        with tf.name_scope(self.name + '.one_step'):
            state_structure = current_state
            current_state = tf.nest.flatten(current_state)
            if (tf.nest.is_nested(state_structure)
                    and (not mcmc_util.is_list_like(state_structure)
                         or len(current_state) != len(state_structure))):
                # TODO(b/170865194): Support dictionaries and other non-list-like state.
                raise TypeError(
                    'NUTS does not currently support nested or '
                    'non-list-like state structures (saw: {}).'.format(
                        state_structure))

            current_target_log_prob = previous_kernel_results.target_log_prob
            [init_momentum, init_energy, log_slice_sample
             ] = self._start_trajectory_batched(current_state,
                                                current_target_log_prob,
                                                seed=start_trajectory_seed)

            def _copy(v):
                return v * ps.ones(ps.pad(
                    [2], paddings=[[0, ps.rank(v)]], constant_values=1),
                                   dtype=v.dtype)

            initial_state = TreeDoublingState(
                momentum=init_momentum,
                state=current_state,
                target=current_target_log_prob,
                target_grad_parts=previous_kernel_results.grads_target_log_prob
            )
            initial_step_state = tf.nest.map_structure(_copy, initial_state)

            if MULTINOMIAL_SAMPLE:
                init_weight = tf.zeros_like(init_energy)  # log(exp(H0 - H0))
            else:
                init_weight = tf.ones_like(init_energy, dtype=TREE_COUNT_DTYPE)

            candidate_state = TreeDoublingStateCandidate(
                state=current_state,
                target=current_target_log_prob,
                target_grad_parts=previous_kernel_results.
                grads_target_log_prob,
                energy=init_energy,
                weight=init_weight)

            initial_step_metastate = TreeDoublingMetaState(
                candidate_state=candidate_state,
                is_accepted=tf.zeros_like(init_energy, dtype=tf.bool),
                momentum_sum=init_momentum,
                energy_diff_sum=tf.zeros_like(init_energy),
                leapfrog_count=tf.zeros_like(init_energy,
                                             dtype=TREE_COUNT_DTYPE),
                continue_tree=tf.ones_like(init_energy, dtype=tf.bool),
                not_divergence=tf.ones_like(init_energy, dtype=tf.bool))

            # Convert the write/read instruction into TensorArray so that it is
            # compatible with XLA.
            write_instruction = tf.TensorArray(
                TREE_COUNT_DTYPE,
                size=len(self._write_instruction),
                clear_after_read=False).unstack(self._write_instruction)
            read_instruction = tf.TensorArray(tf.int32,
                                              size=len(self._read_instruction),
                                              clear_after_read=False).unstack(
                                                  self._read_instruction)

            current_step_meta_info = OneStepMetaInfo(
                log_slice_sample=log_slice_sample,
                init_energy=init_energy,
                write_instruction=write_instruction,
                read_instruction=read_instruction)

            _, _, _, new_step_metastate = tf.while_loop(
                cond=lambda iter_, seed, state, metastate: (  # pylint: disable=g-long-lambda
                    (iter_ < self.max_tree_depth) & tf.reduce_any(
                        metastate.continue_tree)),
                body=lambda iter_, seed, state, metastate: self.
                _loop_tree_doubling(  # pylint: disable=g-long-lambda
                    previous_kernel_results.step_size, previous_kernel_results.
                    momentum_state_memory, current_step_meta_info, iter_,
                    state, metastate, seed),
                loop_vars=(tf.zeros([], dtype=tf.int32,
                                    name='iter'), loop_seed,
                           initial_step_state, initial_step_metastate),
                parallel_iterations=self.parallel_iterations,
            )

            kernel_results = NUTSKernelResults(
                target_log_prob=new_step_metastate.candidate_state.target,
                grads_target_log_prob=(
                    new_step_metastate.candidate_state.target_grad_parts),
                momentum_state_memory=previous_kernel_results.
                momentum_state_memory,
                step_size=previous_kernel_results.step_size,
                log_accept_ratio=tf.math.log(
                    new_step_metastate.energy_diff_sum /
                    tf.cast(new_step_metastate.leapfrog_count,
                            dtype=new_step_metastate.energy_diff_sum.dtype)),
                leapfrogs_taken=(new_step_metastate.leapfrog_count *
                                 self.unrolled_leapfrog_steps),
                is_accepted=new_step_metastate.is_accepted,
                reach_max_depth=new_step_metastate.continue_tree,
                has_divergence=~new_step_metastate.not_divergence,
                energy=new_step_metastate.candidate_state.energy,
                seed=seed,
            )

            result_state = tf.nest.pack_sequence_as(
                state_structure, new_step_metastate.candidate_state.state)
            return result_state, kernel_results
Example #24
    def _sample_n(self, n, seed=None):
        dim0_seed, otherdims_seed = samplers.split_seed(
            seed, salt='von_mises_fisher')
        # The sampling strategy relies on the fact that vMF variates are symmetric
        # about the mean direction. Accordingly, if we have a sampling strategy for
        # the away-from-mean angle, then we can uniformly sample the remaining
        # dimensions on the S^{dim-2} sphere for , and rotate these samples from a
        # (1, 0, 0, ..., 0)-mode distribution into the target orientation.
        #
        # This is easy to imagine on the 1-sphere (S^1; in 2-D space): sample a
        # von-Mises distributed `x` value in [-1, 1], then uniformly select what
        # amounts to a "up" or "down" additional degree of freedom after unit
        # normalizing, followed by a final rotation to the desired mean direction
        # from a basis of (1, 0).
        #
        # On S^2 (in 3-D), selecting a vMF `x` identifies a circle in `yz` on the
        # unit sphere over which the distribution is uniform, in particular the
        # circle where x = \hat{x} intersects the unit sphere. We pick a point on
        # that circle, then rotate to the desired mean direction from a basis of
        # (1, 0, 0).
        mean_direction = tf.convert_to_tensor(self.mean_direction)
        concentration = tf.convert_to_tensor(self.concentration)
        event_dim = (
            tf.compat.dimension_value(self.event_shape[0])
            or self._event_shape_tensor(mean_direction=mean_direction)[0])

        sample_batch_shape = ps.concat(
            [[n],
             self._batch_shape_tensor(mean_direction=mean_direction,
                                      concentration=concentration)],
            axis=0)
        dim = tf.cast(event_dim - 1, self.dtype)
        if event_dim == 3:
            samples_dim0 = self._sample_3d(n,
                                           mean_direction=mean_direction,
                                           concentration=concentration,
                                           seed=dim0_seed)
        else:
            # Wood'94 provides a rejection algorithm to sample the x coordinate.
            # Wood'94 definition of b:
            # b = (-2 * kappa + tf.sqrt(4 * kappa**2 + dim**2)) / dim
            # https://stats.stackexchange.com/questions/156729 suggests:
            b = dim / (2 * concentration +
                       tf.sqrt(4 * concentration**2 + dim**2))
            # TODO(bjp): Integrate any useful numerical tricks from hyperspherical VAE
            #     https://github.com/nicola-decao/s-vae-tf/
            x = (1 - b) / (1 + b)
            c = concentration * x + dim * tf.math.log1p(-x**2)
            beta = beta_lib.Beta(dim / 2, dim / 2)

            def cond_fn(w, should_continue, seed):
                del w, seed
                return tf.reduce_any(should_continue)

            def body_fn(w, should_continue, seed):
                """While loop body for sampling the angle `w`."""
                beta_seed, unif_seed, next_seed = samplers.split_seed(seed,
                                                                      n=3)
                z = beta.sample(sample_shape=sample_batch_shape,
                                seed=beta_seed)
                # set_shape needed here because of b/139013403
                tensorshape_util.set_shape(z, w.shape)
                w = tf.where(should_continue,
                             (1. - (1. + b) * z) / (1. - (1. - b) * z), w)
                if not self.allow_nan_stats:
                    w = tf.debugging.check_numerics(w, 'w')
                unif = samplers.uniform(sample_batch_shape,
                                        seed=unif_seed,
                                        dtype=self.dtype)
                # set_shape needed here because of b/139013403
                tensorshape_util.set_shape(unif, w.shape)
                should_continue = should_continue & (
                    concentration * w + dim * tf.math.log1p(-x * w) - c <
                    # Use log1p(-unif) to prevent log(0) and ensure that log(1) is
                    # possible.
                    tf.math.log1p(-unif))
                return w, should_continue, next_seed

            w = tf.zeros(sample_batch_shape, dtype=self.dtype)
            should_continue = tf.ones(sample_batch_shape, dtype=tf.bool)
            samples_dim0, _, _ = tf.while_loop(cond=cond_fn,
                                               body=body_fn,
                                               loop_vars=(w, should_continue,
                                                          dim0_seed))
            samples_dim0 = samples_dim0[..., tf.newaxis]
        if not self._allow_nan_stats:
            # Verify samples are w/in -1, 1, with useful error output tensors (top
            # value rather than all values).
            with tf.control_dependencies([
                    assert_util.assert_less_equal(
                        samples_dim0,
                        dtype_util.as_numpy_dtype(self.dtype)(1.01)),
                    assert_util.assert_greater_equal(
                        samples_dim0,
                        dtype_util.as_numpy_dtype(self.dtype)(-1.01)),
            ]):
                samples_dim0 = tf.identity(samples_dim0)
        samples_otherdims_shape = ps.concat(
            [sample_batch_shape, [event_dim - 1]], axis=0)
        unit_otherdims = tf.math.l2_normalize(samplers.normal(
            samples_otherdims_shape, seed=otherdims_seed, dtype=self.dtype),
                                              axis=-1)
        samples = tf.concat(
            [
                samples_dim0,  # we must avoid sqrt(1 - (>1)**2)
                tf.sqrt(tf.maximum(1 - samples_dim0**2, 0.)) * unit_otherdims
            ],
            axis=-1)
        samples = tf.math.l2_normalize(samples, axis=-1)
        if not self.allow_nan_stats:
            samples = tf.debugging.check_numerics(samples, 'samples')

        # Runtime assert that samples are unit length.
        if not self.allow_nan_stats:
            worst, _ = tf.math.top_k(
                tf.reshape(tf.abs(1 - tf.linalg.norm(samples, axis=-1)), [-1]))
            with tf.control_dependencies([
                    assert_util.assert_near(dtype_util.as_numpy_dtype(
                        self.dtype)(0),
                                            worst,
                                            atol=1e-4,
                                            summarize=100)
            ]):
                samples = tf.identity(samples)
        # The samples generated are symmetric around a mode at (1, 0, 0, ...., 0).
        # Now, we move the mode to `self.mean_direction` using a rotation matrix.
        if not self.allow_nan_stats:
            # Assert that the basis vector rotates to the mean direction, as expected.
            basis = tf.cast(
                tf.concat([[1.], tf.zeros([event_dim - 1])], axis=0),
                self.dtype)
            with tf.control_dependencies([
                    assert_util.assert_less(
                        tf.linalg.norm(self._rotate(
                            basis, mean_direction=mean_direction) -
                                       mean_direction,
                                       axis=-1),
                        dtype_util.as_numpy_dtype(self.dtype)(1e-5))
            ]):
                return self._rotate(samples, mean_direction=mean_direction)
        return self._rotate(samples, mean_direction=mean_direction)
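This sampler backs the public `tfd.VonMisesFisher` distribution; a hedged sketch of drawing samples and checking they lie on the unit sphere:

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

vmf = tfd.VonMisesFisher(mean_direction=[0., 0., 1.], concentration=10.)
samples = vmf.sample(5, seed=42)          # shape [5, 3]
norms = tf.linalg.norm(samples, axis=-1)  # all approximately 1.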
Example #25
    def _build_sub_tree(self,
                        directions,
                        integrator,
                        current_step_meta_info,
                        nsteps,
                        initial_state,
                        continue_tree,
                        not_divergence,
                        momentum_state_memory,
                        seed,
                        name=None):
        with tf.name_scope('build_sub_tree'):
            batch_shape = ps.shape(current_step_meta_info.init_energy)
            # We never want to select the initial state.
            if MULTINOMIAL_SAMPLE:
                init_weight = tf.fill(
                    batch_shape,
                    tf.constant(
                        -np.inf,
                        dtype=current_step_meta_info.init_energy.dtype))
            else:
                init_weight = tf.zeros(batch_shape, dtype=TREE_COUNT_DTYPE)

            init_momentum_cumsum = [
                tf.zeros_like(x) for x in initial_state.momentum
            ]
            initial_state_candidate = TreeDoublingStateCandidate(
                state=initial_state.state,
                target=initial_state.target,
                target_grad_parts=initial_state.target_grad_parts,
                energy=initial_state.target,
                weight=init_weight)
            energy_diff_sum = tf.zeros_like(current_step_meta_info.init_energy,
                                            name='energy_diff_sum')
            [
                _,
                _,
                energy_diff_tree_sum,
                momentum_tree_cumsum,
                leapfrogs_taken,
                final_state,
                candidate_tree_state,
                final_continue_tree,
                final_not_divergence,
                momentum_state_memory,
            ] = tf.while_loop(
                cond=lambda iter_, seed, energy_diff_sum, init_momentum_cumsum,  # pylint: disable=g-long-lambda
                leapfrogs_taken, state, state_c, continue_tree, not_divergence,
                momentum_state_memory: (
                    (iter_ < nsteps) & tf.reduce_any(continue_tree)),
                body=lambda iter_, seed, energy_diff_sum, init_momentum_cumsum,  # pylint: disable=g-long-lambda
                leapfrogs_taken, state, state_c, continue_tree, not_divergence,
                momentum_state_memory: (self._loop_build_sub_tree(
                    directions, integrator, current_step_meta_info, iter_,
                    energy_diff_sum, init_momentum_cumsum, leapfrogs_taken,
                    state, state_c, continue_tree, not_divergence,
                    momentum_state_memory, seed)),
                loop_vars=(
                    tf.zeros([], dtype=tf.int32, name='iter'),
                    seed,
                    energy_diff_sum,
                    init_momentum_cumsum,
                    tf.zeros(batch_shape, dtype=TREE_COUNT_DTYPE),
                    initial_state,
                    initial_state_candidate,
                    continue_tree,
                    not_divergence,
                    momentum_state_memory,
                ),
                parallel_iterations=self.parallel_iterations)

        return (
            candidate_tree_state,
            final_state,
            final_not_divergence,
            final_continue_tree,
            energy_diff_tree_sum,
            momentum_tree_cumsum,
            leapfrogs_taken,
        )
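The tree-building loop above is internal to TFP's No-U-Turn Sampler. A hedged sketch of how the public kernel is typically driven (assuming `tfp.mcmc.NoUTurnSampler` and `tfp.mcmc.sample_chain` in the installed version):

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

target = tfd.Normal(loc=0., scale=1.)
kernel = tfp.mcmc.NoUTurnSampler(
    target_log_prob_fn=target.log_prob,
    step_size=0.3)

# Draw 500 post-burn-in samples from a scalar chain.
samples = tfp.mcmc.sample_chain(
    num_results=500,
    num_burnin_steps=200,
    current_state=tf.zeros([]),
    kernel=kernel,
    trace_fn=None,
    seed=42)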
Example #26
    def __init__(self,
                 kernel,
                 index_points=None,
                 mean_fn=None,
                 observation_noise_variance=0.,
                 jitter=1e-6,
                 validate_args=False,
                 allow_nan_stats=False,
                 name='GaussianProcess'):
        """Instantiate a GaussianProcess Distribution.

    Args:
      kernel: `PositiveSemidefiniteKernel`-like instance representing the
        GP's covariance function.
      index_points: `float` `Tensor` representing finite (batch of) vector(s) of
        points in the index set over which the GP is defined. Shape has the form
        `[b1, ..., bB, e, f1, ..., fF]` where `F` is the number of feature
        dimensions and must equal `kernel.feature_ndims` and `e` is the number
        (size) of index points in each batch. Ultimately this distribution
        corresponds to a `e`-dimensional multivariate normal. The batch shape
        must be broadcastable with `kernel.batch_shape` and any batch dims
        yielded by `mean_fn`.
      mean_fn: Python `callable` that acts on `index_points` to produce a (batch
        of) vector(s) of mean values at `index_points`. Takes a `Tensor` of
        shape `[b1, ..., bB, f1, ..., fF]` and returns a `Tensor` whose shape is
        broadcastable with `[b1, ..., bB]`. Default value: `None` implies
        constant zero function.
      observation_noise_variance: `float` `Tensor` representing (batch of)
        scalar variance(s) of the noise in the Normal likelihood
        distribution of the model. If batched, the batch shape must be
        broadcastable with the shapes of all other batched parameters
        (`kernel.batch_shape`, `index_points`, etc.).
        Default value: `0.`
      jitter: `float` scalar `Tensor` added to the diagonal of the covariance
        matrix to ensure positive definiteness of the covariance matrix.
        Default value: `1e-6`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
        Default value: `False`.
      allow_nan_stats: Python `bool`, default `False`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
        Default value: `False`.
      name: Python `str` name prefixed to Ops created by this class.
        Default value: "GaussianProcess".

    Raises:
      ValueError: if `mean_fn` is not `None` and is not callable.
    """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype(
                [index_points, observation_noise_variance, jitter], tf.float32)
            if index_points is not None:
                index_points = tf.convert_to_tensor(index_points,
                                                    dtype=dtype,
                                                    name='index_points')
            jitter = tf.convert_to_tensor(jitter, dtype=dtype, name='jitter')
            observation_noise_variance = tf.convert_to_tensor(
                observation_noise_variance,
                dtype=dtype,
                name='observation_noise_variance')

            self._kernel = kernel
            self._index_points = index_points
            # Default to a constant zero function, borrowing the dtype from
            # index_points to ensure consistency.
            if mean_fn is None:
                mean_fn = lambda x: tf.zeros([1], dtype=dtype)
            else:
                if not callable(mean_fn):
                    raise ValueError('`mean_fn` must be a Python callable')
            self._mean_fn = mean_fn
            self._observation_noise_variance = observation_noise_variance
            self._jitter = jitter

            graph_parents = [observation_noise_variance, jitter]
            if index_points is not None: graph_parents.append(index_points)

            with tf.name_scope('init'):
                super(GaussianProcess, self).__init__(
                    dtype=dtype,
                    reparameterization_type=reparameterization.
                    FULLY_REPARAMETERIZED,
                    validate_args=validate_args,
                    allow_nan_stats=allow_nan_stats,
                    parameters=parameters,
                    graph_parents=graph_parents,
                    name=name)
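A hedged sketch of instantiating this distribution through the public API, with an ExponentiatedQuadratic kernel over a 1-D index set (the grid and hyperparameters below are illustrative):

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
psd_kernels = tfp.math.psd_kernels

index_points = tf.linspace(-1., 1., 20)[..., tf.newaxis]   # shape [20, 1]
gp = tfd.GaussianProcess(
    kernel=psd_kernels.ExponentiatedQuadratic(amplitude=1., length_scale=0.5),
    index_points=index_points,
    observation_noise_variance=1e-3)

draw = gp.sample(seed=42)   # shape [20]
lp = gp.log_prob(draw)      # scalar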
Example #27
 def ulp_fn(w):
     zeros = tf.zeros([x_size, 0])
     return model.unnormalized_log_prob(w, zeros)
Example #28
def _interp_regular_1d_grid_impl(x,
                                 x_ref_min,
                                 x_ref_max,
                                 y_ref,
                                 axis=-1,
                                 batch_y_ref=False,
                                 fill_value='constant_extension',
                                 fill_value_below=None,
                                 fill_value_above=None,
                                 grid_regularizing_transform=None,
                                 name=None):
    """1-D interpolation that works with/without batching."""
    # Note: we do *not* make the no-batch version a special case of the batch
    # version, because that would be an inefficient use of batch_gather with
    # unnecessarily broadcast args.
    with tf.name_scope(name or 'interp_regular_1d_grid_impl'):

        # Arg checking.
        allowed_fv_st = ('constant_extension', 'extrapolate')
        for fv in (fill_value, fill_value_below, fill_value_above):
            if isinstance(fv, str) and fv not in allowed_fv_st:
                raise ValueError(
                    'A fill value ({}) was not an allowed string ({})'.format(
                        fv, allowed_fv_st))

        # Separate value fills for below/above incurs extra cost, so keep track of
        # whether this is needed.
        need_separate_fills = (
            fill_value_above is not None or fill_value_below is not None or
            fill_value == 'extrapolate'  # always requires separate below/above
        )
        if need_separate_fills and fill_value_above is None:
            fill_value_above = fill_value
        if need_separate_fills and fill_value_below is None:
            fill_value_below = fill_value

        dtype = dtype_util.common_dtype([x, x_ref_min, x_ref_max, y_ref],
                                        dtype_hint=tf.float32)
        x = tf.convert_to_tensor(x, name='x', dtype=dtype)

        x_ref_min = tf.convert_to_tensor(x_ref_min,
                                         name='x_ref_min',
                                         dtype=dtype)
        x_ref_max = tf.convert_to_tensor(x_ref_max,
                                         name='x_ref_max',
                                         dtype=dtype)
        if not batch_y_ref:
            _assert_ndims_statically(x_ref_min, expect_ndims=0)
            _assert_ndims_statically(x_ref_max, expect_ndims=0)

        y_ref = tf.convert_to_tensor(y_ref, name='y_ref', dtype=dtype)

        if batch_y_ref:
            # If we're batching,
            #   x.shape ~ [A1,...,AN, D],  x_ref_min/max.shape ~ [A1,...,AN]
            # So to add together we'll append a singleton.
            # If not batching, x_ref_min/max are scalar, so this isn't an issue,
            # moreover, if not batching, x can be scalar, and expanding x_ref_min/max
            # would cause a bad expansion of x when added to x (confused yet?).
            x_ref_min = x_ref_min[..., tf.newaxis]
            x_ref_max = x_ref_max[..., tf.newaxis]

        axis = tf.convert_to_tensor(axis, name='axis', dtype=tf.int32)
        axis = prefer_static.non_negative_axis(axis, tf.rank(y_ref))
        _assert_ndims_statically(axis, expect_ndims=0)

        ny = tf.cast(tf.shape(y_ref)[axis], dtype)

        # Map [x_ref_min, x_ref_max] to [0, ny - 1].
        # This is the (fractional) index of x.
        if grid_regularizing_transform is None:
            g = lambda x: x
        else:
            g = grid_regularizing_transform
        fractional_idx = ((g(x) - g(x_ref_min)) /
                          (g(x_ref_max) - g(x_ref_min)))
        x_idx_unclipped = fractional_idx * (ny - 1)

        # Wherever x is NaN, x_idx_unclipped will be NaN as well.
        # Keep track of the nan indices here (so we can impute NaN later).
        # Also eliminate any NaN indices, since NaN cannot be represented once
        # the indices are cast to int32.
        nan_idx = tf.math.is_nan(x_idx_unclipped)
        zero = tf.zeros((), dtype=dtype)
        x_idx_unclipped = tf.where(nan_idx, zero, x_idx_unclipped)
        x_idx = tf.clip_by_value(x_idx_unclipped, zero, ny - 1)

        # Get the index above and below x_idx.
        # Naively we could set idx_below = floor(x_idx), idx_above = ceil(x_idx),
        # however, this results in idx_below == idx_above whenever x is on a grid.
        # This in turn results in y_ref_below == y_ref_above, and then the gradient
        # at this point is zero.  So here we 'jitter' one of idx_below, idx_above,
        # so that they are at different values.  This jittering does not affect the
        # interpolated value, but does make the gradient nonzero (unless of course
        # the y_ref values are the same).
        idx_below = tf.floor(x_idx)
        idx_above = tf.minimum(idx_below + 1, ny - 1)
        idx_below = tf.maximum(idx_above - 1, 0)

        # These are the values of y_ref corresponding to above/below indices.
        idx_below_int32 = tf.cast(idx_below, dtype=tf.int32)
        idx_above_int32 = tf.cast(idx_above, dtype=tf.int32)
        if batch_y_ref:
            # If y_ref.shape ~ [A1,...,AN, C, B1,...,BN],
            # and x.shape, x_ref_min/max.shape ~ [A1,...,AN, D]
            # Then y_ref_below.shape ~ [A1,...,AN, D, B1,...,BN]
            y_ref_below = _batch_gather_with_broadcast(y_ref, idx_below_int32,
                                                       axis)
            y_ref_above = _batch_gather_with_broadcast(y_ref, idx_above_int32,
                                                       axis)
        else:
            # Here, y_ref_below.shape =
            #   y_ref.shape[:axis] + x.shape + y_ref.shape[axis + 1:]
            y_ref_below = tf.gather(y_ref, idx_below_int32, axis=axis)
            y_ref_above = tf.gather(y_ref, idx_above_int32, axis=axis)

        # Use t to get a convex combination of the below/above values.
        t = x_idx - idx_below

        # x, and tensors shaped like x, need to be added to, and selected with
        # (using tf.where) the output y.  This requires appending singletons.
        # Make functions appropriate for batch/no-batch.
        if batch_y_ref:
            # In the batch case, the output shape is going to be
            #   Broadcast(y_ref.shape[:axis], x.shape[:-1]) +
            #   x.shape[-1:] + y_ref.shape[axis+1:]
            expand_x_fn = _make_expand_x_fn_for_batch_interpolation(
                y_ref, axis)
        else:
            # In the non-batch case, the output shape is going to be
            #   y_ref.shape[:axis] + x.shape + y_ref.shape[axis+1:]
            expand_x_fn = _make_expand_x_fn_for_non_batch_interpolation(
                y_ref, axis)

        t = expand_x_fn(t)
        nan_idx = expand_x_fn(nan_idx, broadcast=True)
        x_idx_unclipped = expand_x_fn(x_idx_unclipped, broadcast=True)

        y = t * y_ref_above + (1 - t) * y_ref_below

        # Now begins a long excursion to fill values outside [x_min, x_max].

        # Re-insert NaN wherever x was NaN.
        y = tf.where(nan_idx, tf.constant(np.nan, y.dtype), y)

        if not need_separate_fills:
            if fill_value == 'constant_extension':
                pass  # Already handled by clipping x_idx_unclipped.
            else:
                y = tf.where(
                    (x_idx_unclipped < 0) | (x_idx_unclipped > ny - 1),
                    fill_value, y)
        else:
            # Fill values below x_ref_min <==> x_idx_unclipped < 0.
            if fill_value_below == 'constant_extension':
                pass  # Already handled by the clipping that created x_idx_unclipped.
            elif fill_value_below == 'extrapolate':
                if batch_y_ref:
                    # For every batch member, gather the first two elements of y across
                    # `axis`.
                    y_0 = tf.gather(y_ref, [0], axis=axis)
                    y_1 = tf.gather(y_ref, [1], axis=axis)
                else:
                    # If not batching, we want to gather the first two elements, just like
                    # above.  However, these results need to be replicated for every
                    # member of x.  An easy way to do that is to gather using
                    # indices = zeros/ones(x.shape).
                    y_0 = tf.gather(y_ref,
                                    tf.zeros(tf.shape(x), dtype=tf.int32),
                                    axis=axis)
                    y_1 = tf.gather(y_ref,
                                    tf.ones(tf.shape(x), dtype=tf.int32),
                                    axis=axis)
                x_delta = (x_ref_max - x_ref_min) / (ny - 1)
                x_factor = expand_x_fn((x - x_ref_min) / x_delta,
                                       broadcast=True)
                y = tf.where(x_idx_unclipped < 0, y_0 + x_factor * (y_1 - y_0),
                             y)
            else:
                y = tf.where(x_idx_unclipped < 0, fill_value_below, y)
            # Fill values above x_ref_max <==> x_idx_unclipped > ny - 1.
            if fill_value_above == 'constant_extension':
                pass  # Already handled by the clipping that created x_idx_unclipped.
            elif fill_value_above == 'extrapolate':
                ny_int32 = tf.shape(y_ref)[axis]
                if batch_y_ref:
                    y_n1 = tf.gather(y_ref, [tf.shape(y_ref)[axis] - 1],
                                     axis=axis)
                    y_n2 = tf.gather(y_ref, [tf.shape(y_ref)[axis] - 2],
                                     axis=axis)
                else:
                    y_n1 = tf.gather(y_ref,
                                     tf.fill(tf.shape(x), ny_int32 - 1),
                                     axis=axis)
                    y_n2 = tf.gather(y_ref,
                                     tf.fill(tf.shape(x), ny_int32 - 2),
                                     axis=axis)
                x_delta = (x_ref_max - x_ref_min) / (ny - 1)
                x_factor = expand_x_fn((x - x_ref_max) / x_delta,
                                       broadcast=True)
                y = tf.where(x_idx_unclipped > ny - 1,
                             y_n1 + x_factor * (y_n1 - y_n2), y)
            else:
                y = tf.where(x_idx_unclipped > ny - 1, fill_value_above, y)

        return y
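A hedged sketch of the public 1-D wrapper built on this implementation (assuming `tfp.math.interp_regular_1d_grid`), including linear extrapolation outside the grid:

import tensorflow as tf
import tensorflow_probability as tfp

y_ref = tf.exp(tf.linspace(start=0., stop=10., num=200))

y = tfp.math.interp_regular_1d_grid(
    x=[6.0, 0.75, 11.0],
    x_ref_min=0.,
    x_ref_max=10.,
    y_ref=y_ref,
    fill_value='extrapolate')
# The first two entries approximate exp(6.0) and exp(0.75); the last is a
# linear extrapolation beyond x_ref_max rather than exp(11.0).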
Example #29
    def set_model(self, model):
        """Sets Keras model and creates summary ops."""

        self.model = model
        self._init_writer(model)
        # histogram summaries only enabled in graph mode
        if not tf.executing_eagerly():
            self._make_histogram_ops(model)
            self.merged = tf.compat.v1.summary.merge_all()

        # If both embeddings_freq and embeddings_data are available, we will
        # visualize embeddings.
        if self.embeddings_freq and self.embeddings_data is not None:
            # Avoid circular dependency.
            from keras.engine import (
                training_utils_v1, )  # pylint: disable=g-import-not-at-top

            self.embeddings_data = training_utils_v1.standardize_input_data(
                self.embeddings_data, model.input_names)

            # If embeddings_layer_names are not provided, get all of the embedding
            # layers from the model.
            embeddings_layer_names = self.embeddings_layer_names
            if not embeddings_layer_names:
                embeddings_layer_names = [
                    layer.name for layer in self.model.layers
                    if type(layer).__name__ == "Embedding"
                ]

            self.assign_embeddings = []
            embeddings_vars = {}

            self.batch_id = batch_id = tf.compat.v1.placeholder(tf.int32)
            self.step = step = tf.compat.v1.placeholder(tf.int32)

            for layer in self.model.layers:
                if layer.name in embeddings_layer_names:
                    embedding_input = self.model.get_layer(layer.name).output
                    embedding_size = np.prod(embedding_input.shape[1:])
                    embedding_input = tf.reshape(embedding_input,
                                                 (step, int(embedding_size)))
                    shape = (
                        self.embeddings_data[0].shape[0],
                        int(embedding_size),
                    )
                    embedding = tf.Variable(tf.zeros(shape),
                                            name=layer.name + "_embedding")
                    embeddings_vars[layer.name] = embedding
                    batch = tf.compat.v1.assign(
                        embedding[batch_id:batch_id + step], embedding_input)
                    self.assign_embeddings.append(batch)

            self.saver = tf.compat.v1.train.Saver(
                list(embeddings_vars.values()))

            # Create embeddings_metadata dictionary
            if isinstance(self.embeddings_metadata, str):
                embeddings_metadata = {
                    layer_name: self.embeddings_metadata
                    for layer_name in embeddings_vars.keys()
                }
            else:
                # If embeddings_metadata is already a dictionary
                embeddings_metadata = self.embeddings_metadata

            try:
                from tensorboard.plugins import projector
            except ImportError:
                raise ImportError(
                    "Failed to import TensorBoard. Please make sure that "
                    "TensorBoard integration is complete.")

            # TODO(psv): Add integration tests to test embedding visualization
            # with TensorBoard callback. We are unable to write a unit test for this
            # because TensorBoard dependency assumes TensorFlow package is installed.
            config = projector.ProjectorConfig()
            for layer_name, tensor in embeddings_vars.items():
                embedding = config.embeddings.add()
                embedding.tensor_name = tensor.name

                if (embeddings_metadata is not None
                        and layer_name in embeddings_metadata):
                    embedding.metadata_path = embeddings_metadata[layer_name]

            projector.visualize_embeddings(self.writer, config)
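A hedged end-to-end sketch exercising the embedding path of this (legacy, graph-mode oriented) TensorBoard callback; argument names follow the v1 callback, and only `embeddings_freq` is guaranteed in newer releases, so treat this as an illustration rather than a canonical recipe:

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Embedding(1000, 8, input_length=10, name='embedding'),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(1, activation='sigmoid'),
])
model.compile(optimizer='adam', loss='binary_crossentropy')

x = np.random.randint(0, 1000, size=(64, 10))
y = np.random.randint(0, 2, size=(64, 1))

# set_model above is invoked by Keras when fit() starts.
callback = tf.keras.callbacks.TensorBoard(log_dir='/tmp/tb_logs', embeddings_freq=1)
model.fit(x, y, epochs=1, callbacks=[callback])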
Example #30
 def testNonEmptyConstantTensor(self):
     x = tf.zeros([2, 3, 4])
     rank = distribution_util.prefer_static_rank(x)
     if not tf.executing_eagerly():
         self.assertIsInstance(rank, np.ndarray)
     self.assertEqual(3, rank)