Example #1
    def _build_sub_tree(self,
                        directions,
                        integrator,
                        current_step_meta_info,
                        nsteps,
                        initial_state,
                        continue_tree,
                        not_divergence,
                        momentum_state_memory,
                        seed,
                        name=None):
        with tf.name_scope('build_sub_tree'):
            batch_shape = ps.shape(current_step_meta_info.init_energy)
            # We never want to select the initial state
            if MULTINOMIAL_SAMPLE:
                init_weight = tf.fill(
                    batch_shape,
                    tf.constant(
                        -np.inf,
                        dtype=current_step_meta_info.init_energy.dtype))
            else:
                init_weight = tf.zeros(batch_shape, dtype=TREE_COUNT_DTYPE)

            init_momentum_cumsum = [
                tf.zeros_like(x) for x in initial_state.momentum
            ]
            initial_state_candidate = TreeDoublingStateCandidate(
                state=initial_state.state,
                target=initial_state.target,
                target_grad_parts=initial_state.target_grad_parts,
                energy=initial_state.target,
                weight=init_weight)
            energy_diff_sum = tf.zeros_like(current_step_meta_info.init_energy,
                                            name='energy_diff_sum')
            [
                _,
                _,
                energy_diff_tree_sum,
                momentum_tree_cumsum,
                leapfrogs_taken,
                final_state,
                candidate_tree_state,
                final_continue_tree,
                final_not_divergence,
                momentum_state_memory,
            ] = tf.while_loop(
                cond=lambda iter_, seed, energy_diff_sum, init_momentum_cumsum,  # pylint: disable=g-long-lambda
                leapfrogs_taken, state, state_c, continue_tree, not_divergence,
                momentum_state_memory: (
                    (iter_ < nsteps) & tf.reduce_any(continue_tree)),
                body=lambda iter_, seed, energy_diff_sum, init_momentum_cumsum,  # pylint: disable=g-long-lambda
                leapfrogs_taken, state, state_c, continue_tree, not_divergence,
                momentum_state_memory: (self._loop_build_sub_tree(
                    directions, integrator, current_step_meta_info, iter_,
                    energy_diff_sum, init_momentum_cumsum, leapfrogs_taken,
                    state, state_c, continue_tree, not_divergence,
                    momentum_state_memory, seed)),
                loop_vars=(
                    tf.zeros([], dtype=tf.int32, name='iter'),
                    seed,
                    energy_diff_sum,
                    init_momentum_cumsum,
                    tf.zeros(batch_shape, dtype=TREE_COUNT_DTYPE),
                    initial_state,
                    initial_state_candidate,
                    continue_tree,
                    not_divergence,
                    momentum_state_memory,
                ),
                parallel_iterations=self.parallel_iterations)

        return (
            candidate_tree_state,
            final_state,
            final_not_divergence,
            final_continue_tree,
            energy_diff_tree_sum,
            momentum_tree_cumsum,
            leapfrogs_taken,
        )
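A minimal sketch of the `tf.while_loop` pattern this subtree builder relies on: `cond`, `body`, and `loop_vars` must share the same (possibly nested) structure of loop variables. The function and variable names below are illustrative only, not part of the tree-doubling sampler code above.

import tensorflow as tf

def count_and_sum_to_limit(limit):
    """Accumulate 1 + 2 + ... until the running total reaches `limit`."""
    cond = lambda i, total: total < limit           # same signature as `body`
    body = lambda i, total: (i + 1, total + i + 1)  # returns the same structure
    return tf.while_loop(
        cond=cond,
        body=body,
        loop_vars=(tf.constant(0), tf.constant(0)),
        parallel_iterations=10)

final_i, final_total = count_and_sum_to_limit(tf.constant(10))  # (4, 10)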
Example #2
def sample_sequential_monte_carlo(
        prior_log_prob_fn,
        likelihood_log_prob_fn,
        current_state,
        min_num_steps=2,
        max_num_steps=25,
        max_stage=100,
        make_kernel_fn=make_rwmh_kernel_fn,
        tuning_fn=simple_heuristic_tuning,
        make_tempered_target_log_prob_fn=default_make_tempered_target_log_prob_fn,
        resample_fn=weighted_resampling.resample_systematic,
        ess_threshold_ratio=0.5,
        parallel_iterations=10,
        seed=None,
        name=None):
    """Runs Sequential Monte Carlo to sample from the posterior distribution.

  This function uses an MCMC transition operator (e.g., Hamiltonian Monte Carlo)
  to sample from a series of distributions that slowly interpolates between
  an initial 'prior' distribution:

    `exp(prior_log_prob_fn(x))`

  and the target 'posterior' distribution:

    `exp(prior_log_prob_fn(x) + target_log_prob_fn(x))`,

  by mutating a collection of MC samples (i.e., particles). The approach is also
  known as a Particle Filter in some literature. The current implementation is
  largely based on Del Moral et al. [1], which adapts both the tempering
  sequence (based on the effective sample size) and the scaling of the mutation
  kernel (based on the sample covariance of the particles) at each stage.

  Args:
    prior_log_prob_fn: Python callable that returns the log density of the
      prior distribution.
    likelihood_log_prob_fn: Python callable which takes an argument like
      `current_state` (or `*current_state` if it's a list) and returns its
      (possibly unnormalized) log-density under the likelihood distribution.
    current_state: Nested structure of `Tensor`s, each of shape
      `concat([[num_particles, b1, ..., bN], latent_part_event_shape])`, where
      `b1, ..., bN` are optional batch dimensions. Each batch represents an
      independent SMC run.
    min_num_steps: The minimal number of kernel transition steps in one mutation
      of the MC samples.
    max_num_steps: The maximum number of kernel transition steps in one mutation
      of the MC samples. Note that the actual number of steps in one mutation is
      tuned during sampling and is likely lower than `max_num_steps`.
    max_stage: Integer maximum number of stages for increasing the inverse
      temperature from 0 to 1.
    make_kernel_fn: Python `callable` which returns a `TransitionKernel`-like
      object. Must take one argument representing the `TransitionKernel`'s
      `target_log_prob_fn`. The `target_log_prob_fn` argument represents the
      `TransitionKernel`'s target log distribution.  Note:
        `sample_sequential_monte_carlo` creates a new `target_log_prob_fn` which
        is an interpolation between the supplied `target_log_prob_fn` and
        `proposal_log_prob_fn`; it is this interpolated function which is used
        as an argument to `make_kernel_fn`.
    tuning_fn: Python `callable` which takes the number of steps, the log
      scaling, and the log acceptance ratio from the last mutation and outputs
      the number of steps and log scaling for the next mutation.
    make_tempered_target_log_prob_fn: Python `callable` that takes the
      `prior_log_prob_fn`, `likelihood_log_prob_fn`, and `inverse_temperatures`
      and creates a `target_log_prob_fn` `callable` that is passed to
      `make_kernel_fn`.
    resample_fn: Python `callable` to generate the indices of resampled
      particles, given their weights. Generally, one of
      `tfp.experimental.mcmc.resample_independent` or
      `tfp.experimental.mcmc.resample_systematic`, or any function with the same
      signature.
      Default value: `tfp.experimental.mcmc.resample_systematic`.
    ess_threshold_ratio: Target ratio for effective sample size.
    parallel_iterations: The number of iterations allowed to run in parallel. It
      must be a positive integer. See `tf.while_loop` for more details.
    seed: Python integer or TFP seedstream to seed the random number generator.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'sample_sequential_monte_carlo').

  Returns:
    n_stage: Number of mutation stages SMC ran.
    final_state: `Tensor` or Python `list` of `Tensor`s representing the
      final state(s) of the Markov chain(s). The output represents the
      posterior samples.
    final_kernel_results: `collections.namedtuple` of internal calculations used
      to advance the chain.

  #### References

  [1] Del Moral, Pierre, Arnaud Doucet, and Ajay Jasra. An adaptive sequential
      Monte Carlo method for approximate Bayesian computation.
      _Statistics and Computing_, 22(5):1009-1020, 2012.

  """

    with tf.name_scope(name or 'sample_sequential_monte_carlo'):
        is_seeded = seed is not None
        seed = samplers.sanitize_seed(seed, salt='mcmc.sample_smc')

        unwrap_state_list = not tf.nest.is_nested(current_state)
        if unwrap_state_list:
            current_state = [current_state]
        current_state = [
            tf.convert_to_tensor(s, dtype_hint=tf.float32)
            for s in current_state
        ]

        # Initial preprocessing at Stage 0
        likelihood_log_prob = likelihood_log_prob_fn(*current_state)

        likelihood_rank = ps.rank(likelihood_log_prob)
        dimension = ps.reduce_sum([
            ps.reduce_prod(ps.shape(x)[likelihood_rank:])
            for x in current_state
        ])

        # We infer the particle shapes from the resulting likelihood:
        # [num_particles, b1, ..., bN]
        particle_shape = ps.shape(likelihood_log_prob)
        num_particles, batch_shape = particle_shape[0], particle_shape[1:]
        effective_sample_size_threshold = tf.cast(
            num_particles * ess_threshold_ratio, tf.int32)

        # TODO(b/152412213): Revisit this default parameter.
        # Default to the optimal scaling of a random walk kernel for
        # d-dimensional normally distributed targets: 2.38 ** 2 / d.
        # For more detail see:
        # Roberts GO, Gelman A, Gilks WR. Weak convergence and optimal scaling of
        # random walk Metropolis algorithms. _The annals of applied probability_.
        # 1997;7(1):110-20.
        scale_start = (tf.constant(2.38**2, dtype=likelihood_log_prob.dtype) /
                       tf.constant(dimension, dtype=likelihood_log_prob.dtype))

        inverse_temperature = tf.zeros(batch_shape,
                                       dtype=likelihood_log_prob.dtype)
        scalings = ps.ones_like(likelihood_log_prob) * ps.minimum(
            scale_start, 1.)
        kernel = make_kernel_fn(
            make_tempered_target_log_prob_fn(prior_log_prob_fn,
                                             likelihood_log_prob_fn,
                                             inverse_temperature),
            current_state, scalings)
        pkr = kernel.bootstrap_results(current_state)
        _, kernel_target_log_prob = gather_mh_like_result(pkr)

        particle_info = ParticleInfo(
            log_accept_prob=ps.zeros_like(likelihood_log_prob),
            log_scalings=tf.math.log(scalings),
            tempered_log_prob=kernel_target_log_prob,
            likelihood_log_prob=likelihood_log_prob,
        )

        current_pkr = SMCResults(
            num_steps=tf.convert_to_tensor(max_num_steps,
                                           dtype=tf.int32,
                                           name='num_steps'),
            inverse_temperature=inverse_temperature,
            log_marginal_likelihood=tf.zeros_like(inverse_temperature),
            particle_info=particle_info)

        def update_weights_temperature(inverse_temperature,
                                       likelihood_log_prob):
            """Calculate the next inverse temperature and update weights."""
            likelihood_diff = likelihood_log_prob - tf.reduce_max(
                likelihood_log_prob, axis=0)

            def _body_fn(new_beta, upper_beta, lower_beta, eff_size,
                         log_weights):
                """One iteration of the temperature and weight update."""
                new_beta = (lower_beta + upper_beta) / 2.0
                log_weights = (new_beta -
                               inverse_temperature) * likelihood_diff
                log_weights_norm = tf.math.log_softmax(log_weights, axis=0)
                eff_size = tf.cast(
                    tf.exp(-tf.math.reduce_logsumexp(2 * log_weights_norm,
                                                     axis=0)), tf.int32)
                upper_beta = tf.where(
                    eff_size < effective_sample_size_threshold, new_beta,
                    upper_beta)
                lower_beta = tf.where(
                    eff_size < effective_sample_size_threshold, lower_beta,
                    new_beta)
                return new_beta, upper_beta, lower_beta, eff_size, log_weights

            def _cond_fn(new_beta, upper_beta, lower_beta, eff_size, *_):  # pylint: disable=unused-argument
                # TODO(junpenglao): revisit threshold below to be dtype specific.
                threshold = 1e-6
                return (tf.math.reduce_any(upper_beta - lower_beta > threshold)
                        & tf.math.reduce_any(
                            eff_size != effective_sample_size_threshold))

            (new_beta, upper_beta, lower_beta, eff_size,
             log_weights) = tf.while_loop(  # pylint: disable=unused-variable
                 cond=_cond_fn,
                 body=_body_fn,
                 loop_vars=(tf.zeros_like(inverse_temperature),
                            tf.fill(ps.shape(inverse_temperature),
                                    tf.constant(2, inverse_temperature.dtype)),
                            inverse_temperature,
                            tf.zeros_like(inverse_temperature, dtype=tf.int32),
                            tf.zeros_like(likelihood_diff)),
                 parallel_iterations=parallel_iterations)

            log_weights = tf.where(new_beta < 1., log_weights,
                                   (1. - inverse_temperature) *
                                   likelihood_diff)
            marginal_loglike_ = reduce_logmeanexp(
                (new_beta - inverse_temperature) * likelihood_log_prob, axis=0)
            new_inverse_temperature = tf.clip_by_value(new_beta, 0., 1.)

            return marginal_loglike_, new_inverse_temperature, log_weights

        def mutate(current_state, log_scalings, num_steps,
                   inverse_temperature):
            """Mutate the state using a Transition kernel."""
            with tf.name_scope('mutate_states'):
                scalings = tf.exp(log_scalings)
                kernel = make_kernel_fn(
                    make_tempered_target_log_prob_fn(prior_log_prob_fn,
                                                     likelihood_log_prob_fn,
                                                     inverse_temperature),
                    current_state, scalings)
                pkr = kernel.bootstrap_results(current_state)
                kernel_log_accept_ratio, _ = gather_mh_like_result(pkr)

                def mutate_onestep(i, seed, state, pkr, log_accept_prob_sum):
                    iter_seed, next_seed = (samplers.split_seed(seed)
                                            if is_seeded else (None, seed))

                    one_step_kwargs = dict(seed=iter_seed) if is_seeded else {}
                    next_state, next_kernel_results = kernel.one_step(
                        state, pkr, **one_step_kwargs)
                    kernel_log_accept_ratio, _ = gather_mh_like_result(pkr)
                    log_accept_prob = tf.minimum(kernel_log_accept_ratio, 0.)
                    log_accept_prob_sum = log_add_exp(log_accept_prob_sum,
                                                      log_accept_prob)
                    return [
                        i + 1, next_seed, next_state, next_kernel_results,
                        log_accept_prob_sum
                    ]

                (
                    _, _, next_state, next_kernel_results, log_accept_prob_sum
                ) = tf.while_loop(
                    cond=lambda i, *args: i < num_steps,
                    body=mutate_onestep,
                    loop_vars=(
                        tf.zeros([], dtype=tf.int32),
                        seed,
                        current_state,
                        pkr,
                        # we accumulate the acceptance probability in log space.
                        tf.fill(
                            ps.shape(kernel_log_accept_ratio),
                            tf.constant(-np.inf,
                                        kernel_log_accept_ratio.dtype))),
                    parallel_iterations=parallel_iterations)
                _, kernel_target_log_prob = gather_mh_like_result(
                    next_kernel_results)
                avg_log_accept_prob_per_particle = log_accept_prob_sum - tf.math.log(
                    tf.cast(num_steps + 1, log_accept_prob_sum.dtype))
                return (next_state, avg_log_accept_prob_per_particle,
                        kernel_target_log_prob)

        # One SMC step.
        def smc_body_fn(stage, state, smc_kernel_result):
            """Run one stage of SMC with constant temperature."""
            (new_marginal, new_inv_temperature,
             log_weights) = update_weights_temperature(
                 smc_kernel_result.inverse_temperature,
                 smc_kernel_result.particle_info.likelihood_log_prob)
            # TODO(b/152412213) Use a tf.scan to better collect debug info.
            if PRINT_DEBUG:
                tf.print(
                    'Stage:', stage, 'Beta:', new_inv_temperature, 'n_steps:',
                    smc_kernel_result.num_steps, 'accept:',
                    tf.exp(
                        reduce_logmeanexp(
                            smc_kernel_result.particle_info.log_accept_prob,
                            axis=0)), 'scaling:',
                    tf.exp(
                        reduce_logmeanexp(
                            smc_kernel_result.particle_info.log_scalings,
                            axis=0)))
            (resampled_state,
             resampled_particle_info), _, _ = weighted_resampling.resample(
                 particles=(state, smc_kernel_result.particle_info),
                 log_weights=log_weights,
                 resample_fn=resample_fn,
                 seed=seed)
            next_num_steps, next_log_scalings = tuning_fn(
                smc_kernel_result.num_steps,
                resampled_particle_info.log_scalings,
                resampled_particle_info.log_accept_prob)
            # Skip tuning at stage 0.
            next_num_steps = tf.where(stage == 0, smc_kernel_result.num_steps,
                                      next_num_steps)
            next_log_scalings = tf.where(stage == 0,
                                         resampled_particle_info.log_scalings,
                                         next_log_scalings)
            next_num_steps = tf.clip_by_value(next_num_steps, min_num_steps,
                                              max_num_steps)

            next_state, log_accept_prob, tempered_log_prob = mutate(
                resampled_state, next_log_scalings, next_num_steps,
                new_inv_temperature)
            next_pkr = SMCResults(
                num_steps=next_num_steps,
                inverse_temperature=new_inv_temperature,
                log_marginal_likelihood=(
                    new_marginal + smc_kernel_result.log_marginal_likelihood),
                particle_info=ParticleInfo(
                    log_accept_prob=log_accept_prob,
                    log_scalings=next_log_scalings,
                    tempered_log_prob=tempered_log_prob,
                    likelihood_log_prob=likelihood_log_prob_fn(*next_state),
                ))
            return stage + 1, next_state, next_pkr

        (n_stage, final_state, final_kernel_results) = tf.while_loop(
            cond=lambda i, state, pkr: (  # pylint: disable=g-long-lambda
                (i < max_stage) & tf.reduce_any(pkr.inverse_temperature < 1.)),
            body=smc_body_fn,
            loop_vars=(tf.zeros([],
                                dtype=tf.int32), current_state, current_pkr),
            parallel_iterations=parallel_iterations)
        if unwrap_state_list:
            final_state = final_state[0]
        return n_stage, final_state, final_kernel_results
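A hedged usage sketch of `sample_sequential_monte_carlo` as exposed under `tfp.experimental.mcmc`; the toy model (standard normal prior, unit-variance Gaussian likelihood) and all names below are illustrative.

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

observations = tf.constant([2.9, 3.1, 3.0])

def prior_log_prob_fn(x):
    return tfd.Normal(0., 1.).log_prob(x)

def likelihood_log_prob_fn(x):
    # Sum the per-observation log-likelihoods for each particle.
    return tf.reduce_sum(
        tfd.Normal(x[..., tf.newaxis], 1.).log_prob(observations), axis=-1)

num_particles = 1000
initial_particles = tfd.Normal(0., 1.).sample(num_particles, seed=42)

n_stage, posterior_particles, _ = (
    tfp.experimental.mcmc.sample_sequential_monte_carlo(
        prior_log_prob_fn,
        likelihood_log_prob_fn,
        initial_particles,
        seed=42))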
Example #3
def trace_scan(loop_fn,
               initial_state,
               elems,
               trace_fn,
               trace_criterion_fn=None,
               static_trace_allocation_size=None,
               parallel_iterations=10,
               name=None):
  """A simplified version of `tf.scan` that has configurable tracing.

  This function repeatedly calls `loop_fn(state, elem)`, where `state` is the
  `initial_state` during the first iteration, and the return value of `loop_fn`
  for every iteration thereafter. `elem` is a slice of `elems` along the
  first dimension, accessed in order. Additionally, it calls `trace_fn` on the
  return value of `loop_fn`. The `Tensor`s in return values of `trace_fn` are
  stacked and returned from this function, such that the first dimension of
  those `Tensor`s matches the size of `elems`.

  Args:
    loop_fn: A callable that takes in a `Tensor` or a nested collection of
      `Tensor`s with the same structure as `initial_state`, a slice of `elems`
      and returns the same structure as `initial_state`.
    initial_state: A `Tensor` or a nested collection of `Tensor`s passed to
      `loop_fn` in the first iteration.
    elems: A `Tensor` that is split along the first dimension and each element
      of which is passed to `loop_fn`.
    trace_fn: A callable that takes in the return value of `loop_fn` and returns
      a `Tensor` or a nested collection of `Tensor`s.
    trace_criterion_fn: Optional callable that takes in the return value of
      `loop_fn` and returns a boolean `Tensor` indicating whether to trace it.
      If `None`, all steps are traced.
      Default value: `None`.
    static_trace_allocation_size: Optional Python `int` size of trace to
      allocate statically. This should be an upper bound on the number of steps
      traced and is used only when the length cannot be
      statically inferred (for example, if a `trace_criterion_fn` is specified).
      It is primarily intended for contexts where static shapes are required,
      such as in XLA-compiled code.
      Default value: `None`.
    parallel_iterations: Passed to the internal `tf.while_loop`.
    name: Name scope used in this function. Default: 'trace_scan'.

  Returns:
    final_state: The final return value of `loop_fn`.
    trace: The same structure as the return value of `trace_fn`, but with each
      `Tensor` being a stack of the corresponding `Tensors` in the return value
      of `trace_fn` for each slice of `elems`.
  """
  with tf.name_scope(name or 'trace_scan'), tf1.variable_scope(
      tf1.get_variable_scope()) as vs:
    if vs.caching_device is None and not tf.executing_eagerly():
      vs.set_caching_device(lambda op: op.device)

    initial_state = tf.nest.map_structure(
        lambda x: tf.convert_to_tensor(x, name='initial_state'),
        initial_state, expand_composites=True)
    elems = tf.convert_to_tensor(elems, name='elems')

    length = ps.size0(elems)

    # This is a TensorArray in part because of XLA, which had trouble with
    # non-statically known indices: `elems[i]` errored, but
    # `elems_array.read(i)` worked.
    elems_array = tf.TensorArray(
        elems.dtype, size=length, element_shape=elems.shape[1:])
    elems_array = elems_array.unstack(elems)

    # Initialize trace arrays.
    if trace_criterion_fn is None:
      dynamic_size, initial_size = tf.is_tensor(length), length
    elif static_trace_allocation_size is not None:
      dynamic_size, initial_size = False, static_trace_allocation_size
    elif JAX_MODE or (not tf.executing_eagerly() and
                      control_flow_util.GraphOrParentsInXlaContext(
                          tf1.get_default_graph())):
      dynamic_size, initial_size = False, length
    else:
      dynamic_size, initial_size = True, 0
    initial_trace = trace_fn(initial_state)
    flat_initial_trace = tf.nest.flatten(initial_trace, expand_composites=True)
    trace_arrays = []
    for trace_elt in flat_initial_trace:
      trace_arrays.append(
          tf.TensorArray(
              trace_elt.dtype,
              size=initial_size,
              dynamic_size=dynamic_size,
              element_shape=trace_elt.shape))

    # Helper for writing a (structured) state to (structured) arrays.
    def trace_one_step(num_steps_traced, trace_arrays, state):
      return [ta.write(num_steps_traced, x) for ta, x in
              zip(trace_arrays,
                  tf.nest.flatten(trace_fn(state), expand_composites=True))]

    def _body(i, state, num_steps_traced, trace_arrays):
      elem = elems_array.read(i)
      state = loop_fn(state, elem)

      trace_arrays, num_steps_traced = ps.cond(
          trace_criterion_fn(state) if trace_criterion_fn else True,
          lambda: (trace_one_step(num_steps_traced, trace_arrays, state),  # pylint: disable=g-long-lambda
                   num_steps_traced + 1),
          lambda: (trace_arrays, num_steps_traced))

      return i + 1, state, num_steps_traced, trace_arrays

    _, final_state, _, trace_arrays = tf.while_loop(
        cond=lambda i, *_: i < length,
        body=_body,
        loop_vars=(0, initial_state, 0, trace_arrays),
        parallel_iterations=parallel_iterations)

    # unflatten
    stacked_trace = tf.nest.pack_sequence_as(
        initial_trace, [ta.stack() for ta in trace_arrays],
        expand_composites=True)

    # Restore the static length if we know it.
    static_length = tf.TensorShape(None if dynamic_size else initial_size)

    def _merge_static_length(x):
      tensorshape_util.set_shape(x, static_length.concatenate(x.shape[1:]))
      return x

    stacked_trace = tf.nest.map_structure(
        _merge_static_length, stacked_trace, expand_composites=True)
    return final_state, stacked_trace
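A brief hedged sketch of calling `trace_scan` as defined above (it is an internal TFP utility, so it is shown invoked directly rather than through a public import path); the cumulative-sum example and dict-valued `trace_fn` are illustrative.

import tensorflow as tf

final_state, trace = trace_scan(
    loop_fn=lambda state, elem: state + elem,
    initial_state=tf.constant(0.),
    elems=tf.constant([1., 2., 3., 4.]),
    trace_fn=lambda state: {'total': state, 'total_sq': state ** 2})
# final_state == 10.; trace['total'] == [1., 3., 6., 10.]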
Example #4
    def _sample_n(self, n, seed=None):
        power = tf.convert_to_tensor(self.power)
        shape = tf.concat([[n], tf.shape(power)], axis=0)

        seed = samplers.sanitize_seed(seed, salt='zipf')

        minval_u = self._hat_integral(0.5, power=power) + 1.
        maxval_u = self._hat_integral(dtype_util.max(tf.int64) - 0.5,
                                      power=power)

        def loop_body(should_continue, k, seed):
            """Resample the non-accepted points."""
            u_seed, next_seed = samplers.split_seed(seed)
            # The range of U is chosen so that the resulting sample K lies in
            # [0, tf.int64.max). The final sample, if accepted, is K + 1.
            u = samplers.uniform(shape,
                                 minval=minval_u,
                                 maxval=maxval_u,
                                 dtype=power.dtype,
                                 seed=u_seed)
            # set_shape needed here because of b/139013403
            tensorshape_util.set_shape(u, should_continue.shape)

            # Sample the point X from the continuous density h(x) \propto x^(-power).
            x = self._hat_integral_inverse(u, power=power)

            # Rejection-inversion requires a `hat` function, h(x) such that
            # \int_{k - .5}^{k + .5} h(x) dx >= pmf(k + 1) for points k in the
            # support. A natural hat function for us is h(x) = x^(-power).
            #
            # After sampling X from h(x), suppose it lies in the interval
            # (K - .5, K + .5) for integer K. Then the corresponding K is accepted
            # if it lies to the left of x_K, where x_K is defined by:
            #   \int_{x_K}^{K + .5} h(x) dx = H(x_K) - H(K + .5) = pmf(K + 1),
            # where H(x) = \int_x^inf h(x) dx.

            # Solving for x_K, we find that x_K = H_inverse(H(K + .5) + pmf(K + 1)).
            # Or, the acceptance condition is X <= H_inverse(H(K + .5) + pmf(K + 1)).
            # Since X = H_inverse(U), this simplifies to U <= H(K + .5) + pmf(K + 1).

            # Update the non-accepted points.
            # Since X \in (K - .5, K + .5), the sample K is chosen as floor(X + 0.5).
            k = tf.where(should_continue, tf.floor(x + 0.5), k)
            accept = (u <= self._hat_integral(k + .5, power=power) +
                      tf.exp(self._log_prob(k + 1, power=power)))

            return [should_continue & (~accept), k, next_seed]

        should_continue, samples, _ = tf.while_loop(
            cond=lambda should_continue, *ignore: tf.reduce_any(should_continue
                                                                ),
            body=loop_body,
            loop_vars=[
                tf.ones(shape, dtype=tf.bool),  # should_continue
                tf.zeros(shape, dtype=power.dtype),  # k
                seed,  # seed
            ],
            maximum_iterations=self.sample_maximum_iterations,
        )
        samples = samples + 1.

        if self.validate_args and dtype_util.is_integer(self.dtype):
            samples = distribution_util.embed_check_integer_casting_closed(
                samples, target_dtype=self.dtype, assert_positive=True)

        samples = tf.cast(samples, self.dtype)

        if self.validate_args:
            npdt = dtype_util.as_numpy_dtype(self.dtype)
            v = npdt(
                dtype_util.min(npdt) if dtype_util.is_integer(npdt) else np.nan
            )
            samples = tf.where(should_continue, v, samples)

        return samples
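A hedged usage sketch of this rejection-inversion sampler through the public `Zipf` distribution (assuming the standard `tfp.distributions` namespace); the parameter value and sample size are illustrative.

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

zipf = tfd.Zipf(power=2.5)
samples = zipf.sample(1000, seed=7)  # integer-valued samples >= 1
mean_estimate = tf.reduce_mean(tf.cast(samples, tf.float32))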
Example #5
def bessel_iv_ratio(v, z, name=None):
    """Computes `I_{v} (z) / I_{v - 1} (z)` in a numerically stable way.

  Let I(v, z) be the modified Bessel function of the first kind. This computes
  the ratio of I(v, z) / I(v - 1, z). This can be more numerically stable
  and faster than computing the ratio directly.

  This uses a continued fraction approximation attributed to Gauss for
  computing this quantity in the limit where z <= v, and a continued fraction
  approximation attributed to Perron for z > v.

  Args:
    v: value for which `I_{v}(z) / I_{v - 1}(z)` should be computed. Expect
      v > 0.
    z: value for which `I_{v}(z) / I_{v - 1}(z)` should be computed. Expect
      z > 0.
    name: A name for the operation (optional).
      Default value: `None` (i.e., 'bessel_iv_ratio').

  Returns:
    I(v, z) / I(v - 1, z).

  #### References
  [1]: Walter Gautschi and Josef Slavik. On the Computation of Modified
       Bessel Function Ratios. http://www.jstor.com/stable/2006491
  """
    with tf.name_scope(name or 'bessel_iv_ratio'):
        dtype = dtype_util.common_dtype([v, z], tf.float32)
        v = tf.convert_to_tensor(v, dtype=dtype)
        z = tf.convert_to_tensor(z, dtype=dtype)

        np_finfo = np.finfo(dtype_util.as_numpy_dtype(dtype))
        tolerance = tf.cast(np_finfo.resolution, dtype=dtype)

        safe_to_use_perron = z > v

        def gauss_term_fn(iteration_count, v, z):
            """Terms for the Gauss continued fraction."""
            return tf.math.square(z) / 4. / ((v + iteration_count - 1) *
                                             (v + iteration_count))

        # The Gauss continued fraction converges faster for z < v.
        # For z > v, set z to something much less than v.
        safe_z_less_v = tf.where(safe_to_use_perron, v / 1000., z)

        # We use forward recurrence for the Gauss continued fraction.
        # This is so that we can do early termination.
        # There are a few reasons why this doesn't overflow:
        # * All partial numerators / denominators are positive.
        # * Partial numerators approach zero as 1 / n**2, where
        #   n is the iteration count.
        # * All partial numerators are less than 1.
        # Combined with the recurrence, this ensures no overflow
        # as the number of iterations -> infinity.
        gauss_cf = _compute_general_continued_fraction(
            # Use a max of 200 steps. Almost always we will be much less
            # than this.
            200,
            [v, safe_z_less_v],
            tolerance=tolerance,
            partial_numerator_fn=gauss_term_fn)
        # Add the zeroth term for the Gauss continued fraction.
        gauss_cf = tf.math.reciprocal((1. + gauss_cf) * 2. * v / z)

        # For the Perron CF we use the backward recurrence. This is because
        # generally the backward recurrence is more numerically stable
        # than forward recurrence, especially with negative terms.
        # We use a flat 50 steps. Anecdotally, for z > v, convergence is
        # much faster than that.

        # The Perron continued fraction converges much faster for z >> v.
        # For z < v, set z to something much greater than v.
        safe_z_greater_v = tf.where(~safe_to_use_perron, 1000. * v, z)

        def perron_term_fn(iteration_count, v, z):
            """Terms for the Perron continued fraction."""
            return -0.5 * z * (v + iteration_count - 0.5) / (
                (v + z +
                 (iteration_count - 1.) / 2.) * (v + z + iteration_count / 2.))

        total_perron_iteration_count = 50

        def _backward_cf_one_step(iteration_count, cf):
            cf = perron_term_fn(total_perron_iteration_count - iteration_count,
                                v, safe_z_greater_v) / (1. + cf)
            return [iteration_count + 1., cf]

        # For the Perron CF, we omit the first numerator because it
        # has a different form.

        _, perron_cf = tf.while_loop(
            cond=lambda i, _: i < total_perron_iteration_count - 1,
            body=_backward_cf_one_step,
            # Use 50 iterations. Empirically, the Perron continued fraction
            # converges much faster than this.
            loop_vars=[
                tf.cast(0., dtype=dtype),
                tf.zeros_like(safe_z_greater_v)
            ])
        first_term = -0.5 * z * (v + 0.5) / ((v + z / 2.) * (v + z + 0.5))

        perron_cf = first_term / (1. + perron_cf)

        # Add the zeroth term for the Perron continued fraction.
        perron_zeroth_term = (z + 2 * v) / z
        perron_cf = tf.math.reciprocal(perron_zeroth_term * (1. + perron_cf))
        result = tf.where(safe_to_use_perron, perron_cf, gauss_cf)

        def grad(dy):
            """Computes the derivative of the ratio elementwise with respect to z.

      For shorthand, let `I(v) = I(v, z)`, `R(v) = I(v, z) / I(v - 1, z)`

      ```
      R'(v) = (I'(v)I(v - 1) - I(v)I'(v - 1)) / I(v - 1) ** 2
             = 0.5 * ((I(v - 1) + I(v + 1))I(v - 1) - I(v)(
                  I(v) + I(v - 2))) / I(v - 1) ** 2
             = 0.5 * (1. + I(v + 1) / I(v - 1) - (I(v) / I(v - 1)) ** 2 - (
                  I(v) / I(v - 1)) * (I(v - 2) / I(v - 1)))
             = 0.5 * (1. + R(v + 1) * R(v) - R(v) ** 2 - R(v) / R(v - 1))
             = 0.5 * (1. + R(v) * (R(v + 1) - R(v) - 1. / R(v - 1)))
      ```
      To avoid computing R(v - 1) when v <= 1 (which is not valid),
      we can rewrite `I(v - 2) = 2 (v - 1) / z * I(v - 1) + I(v)`.
      Thus the last term becomes:
      ```
      -1. / R(v - 1) = -I(v - 2) / I(v - 1) = -2 (v - 1) / z - R(v)
      ```

      Args:
        dy: A Tensor with type `float32` or `float64`.

      Returns:
        A Tensor with same shape and dtype as `z`.
      """
            grad_z = 0.5 * (1. + result *
                            (bessel_iv_ratio(v + 1., z) - 2. * result - 2. *
                             (v - 1) / z)) * dy

            # We don't have an easily computable gradient with respect to v at the
            # moment, so ignore that for now.
            _, grad_z = _fix_gradient_for_broadcasting(v, z,
                                                       tf.ones_like(grad_z),
                                                       grad_z)
            return None, grad_z

        return result, grad
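A hedged numerical cross-check of `bessel_iv_ratio` against SciPy, assuming the function is exposed as `tfp.math.bessel_iv_ratio` (as in recent TFP releases); the test values and tolerance are illustrative.

import numpy as np
import tensorflow_probability as tfp
from scipy import special

v, z = 2.0, 3.0
tfp_ratio = tfp.math.bessel_iv_ratio(v, z)
# I_v(z) / I_{v - 1}(z) via exponentially scaled Bessel functions.
scipy_ratio = special.ive(v, z) / special.ive(v - 1, z)
np.testing.assert_allclose(tfp_ratio, scipy_ratio, rtol=1e-5)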
Example #6
def _sample_multinomial_as_iterated_binomial(
    num_samples, num_classes, probs, num_trials, dtype, seed):
  """Sample a multinomial by drawing one binomial sample per class.

  The batch shape is given by broadcasting num_trials with
  remove_last_dimension(probs).

  The loop over binomial samples is a `tf.while_loop`, thus supporting a dynamic
  number of classes.

  Args:
    num_samples: Singleton integer Tensor: number of multinomial samples to
      draw.
    num_classes: Singleton integer Tensor: number of classes.
    probs: Floating Tensor with last dimension `num_classes`, of normalized
      probabilities per class.
    num_trials: Tensor of number of categorical trials each multinomial consists
      of.  num_trials[..., tf.newaxis] must broadcast with probs.
    dtype: dtype at which to emit samples.
    seed: Random seed.

  Returns:
    samples: Tensor of given dtype and shape [num_samples] + batch_shape +
      [num_classes].
  """
  with tf.name_scope('draw_sample'):
    # `convert_to_tensor(num_classes)` here to avoid unstacking inside
    # `split_seed`.  We can't take advantage of the Python-list code path anyway
    # because the index at which we will take the seed is a Tensor.
    seeds = samplers.split_seed(
        seed, n=tf.convert_to_tensor(num_classes),
        salt='multinomial_draw_sample')

    def fn(i, num_trials, consumed_prob, accum):
      """Sample the counts for one class using binomial."""
      probs_here = tf.gather(probs, i, axis=-1)
      binomial_probs = tf.clip_by_value(probs_here / (1. - consumed_prob), 0, 1)
      seed_here = tf.gather(seeds, i, axis=0)
      binom = binomial.Binomial(total_count=num_trials, probs=binomial_probs)
      # Not passing `num_samples` to `binom.sample`, as it is already reflected
      # in `num_trials.shape`.
      sample = binom.sample(seed=seed_here)
      accum = accum.write(i, tf.cast(sample, dtype=dtype))
      return i + 1, num_trials - sample, consumed_prob + probs_here, accum

    num_trials = tf.cast(num_trials, probs.dtype)
    # Pre-broadcast with probs
    num_trials = num_trials + tf.zeros_like(probs[..., 0])
    # Pre-enlarge for different output samples
    num_trials = _replicate_along_left(num_trials, num_samples)
    i = tf.constant(0)
    consumed_prob = tf.zeros_like(probs[..., 0])
    accum = tf.TensorArray(
        dtype, size=num_classes, element_shape=num_trials.shape)
    _, num_trials_left, _, accum = tf.while_loop(
        cond=lambda index, _0, _1, _2: tf.less(index, num_classes - 1),
        body=fn,
        loop_vars=(i, num_trials, consumed_prob, accum))
    # Force the last iteration to put all the trials into the last bucket,
    # because probs[..., -1] / (1. - consumed_prob) might numerically not be 1.
    # Also saves one iteration around the while_loop and one run of the binomial
    # sampler.
    accum = accum.write(num_classes - 1, tf.cast(num_trials_left, dtype=dtype))
    # This stop_gradient is necessary to prevent spurious zero gradients coming
    # from b/138796859, and a spurious gradient through num_trials_left.
    results = tf.stop_gradient(accum.stack())
    return distribution_util.move_dimension(results, 0, -1)
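The conditional-binomial (chain rule) decomposition used above can be illustrated with a small NumPy sketch, independent of the TFP helper; the toy probabilities and names are illustrative.

import numpy as np

rng = np.random.default_rng(0)
probs = np.array([0.2, 0.3, 0.5])
num_trials = 100

remaining, consumed_prob, counts = num_trials, 0.0, []
for p in probs[:-1]:
    # Sample this class conditioned on the trials not yet consumed.
    count = rng.binomial(remaining, min(p / (1.0 - consumed_prob), 1.0))
    counts.append(count)
    remaining -= count
    consumed_prob += p
counts.append(remaining)  # the last class absorbs the leftover trials
# `counts` is one draw from Multinomial(num_trials, probs).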
Example #7
def make_convolution_transpose_fn_with_subkernels(filter_shape,
                                                  strides,
                                                  padding,
                                                  rank=2,
                                                  dilations=None,
                                                  dtype=tf.int32,
                                                  validate_args=False,
                                                  name=None):
    """Like `tf.nn.conv2d` except applies batch of kernels to batch of `x`."""
    with tf.name_scope(name or 'make_convolution_transpose_fn_with_dilation'):

        # Enable v2 control flow to avoid None gradients through TensorArray.
        tf.compat.v1.enable_control_flow_v2()

        if tf.get_static_value(rank) != 2:
            raise NotImplementedError(
                'Argument `rank` currently only supports `2`; '
                'saw "{}".'.format(rank))
        [
            filter_shape,
            rank,
            strides,
            padding,
            dilations,
        ] = prepare_conv_args(filter_shape,
                              rank=rank,
                              strides=strides,
                              padding=padding,
                              dilations=dilations,
                              is_transpose=True,
                              validate_args=validate_args)

        sh, sw = strides
        fh, fw = filter_shape
        dh, dw = dilations

        # Determine maximum filter height and filter width of sub-kernels.
        sub_fh = (fh - 1) // sh + 1
        sub_fw = (fw - 1) // sw + 1

        def loop_body(i_, kernels_ind):
            i = i_ // sw
            j = i_ % sw

            i_ind = ps.range(i * fw,
                             ps.maximum(i, fh) * fw,
                             delta=sh * fw,
                             dtype=dtype)
            j_ind = ps.range(j, ps.maximum(j, fw), delta=sw, dtype=dtype)

            last_j = sw - (fw - j - 1) % sw - 1
            last_i = sh - (fh - i - 1) % sh - 1
            pos = last_i * sw + last_j

            nc = cartesian_add([i_ind, j_ind])
            kernels_ind = kernels_ind.write(
                pos, ps.reverse(ps.reverse(nc, [0]), [1]))
            return i_ + 1, kernels_ind

        kernels_ind = tf.TensorArray(dtype=dtype,
                                     infer_shape=False,
                                     size=sh * sw)

        _, kernels_ind = tf.while_loop(lambda i, _: i < sh * sw, loop_body,
                                       [0, kernels_ind])

        tot_pad_top, tot_pad_bottom = _get_transpose_conv_dilated_padding(
            fh, stride=sh, dilation=dh, padding=padding)
        tot_pad_left, tot_pad_right = _get_transpose_conv_dilated_padding(
            fw, stride=sw, dilation=dw, padding=padding)

        pad_bottom = (tot_pad_bottom - 1) // sh + 1
        pad_top = (tot_pad_top - 1) // sh + 1
        pad_right = (tot_pad_right - 1) // sw + 1
        pad_left = (tot_pad_left - 1) // sw + 1
        padding_vals = ((pad_top, pad_bottom), (pad_left, pad_right))

        truncate_top = pad_top * sh - tot_pad_top
        truncate_left = pad_left * sw - tot_pad_left

        def op(x, kernel):
            input_dtype = dtype_util.common_dtype([x, kernel],
                                                  dtype_hint=tf.float32)
            x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
            kernel = tf.convert_to_tensor(kernel,
                                          dtype=input_dtype,
                                          name='kernel')

            batch_shape, event_shape = ps.split(ps.shape(x),
                                                num_or_size_splits=[-1, 3])
            xh, xw, c_in = ps.unstack(event_shape, num=3)

            kernel_shape = ps.shape(kernel)
            c_out = kernel_shape[-1]
            kernel_batch = kernel_shape[:-2]
            assertions = _maybe_validate_input_shapes(
                kernel_shape,
                channels_in=c_in,
                filter_height=fh,
                filter_width=fw,
                validate_args=validate_args)

            with tf.control_dependencies(assertions):
                # If the kernel does not have batch shape, fall back to
                # `conv2d_transpose` (unless dilations > 1, which is not implemented in
                # `conv2d_transpose`).
                if (tf.get_static_value(ps.rank(kernel)) == 2
                        and all(d == 1 for d in dilations)):
                    return _call_conv2d_transpose(x, kernel, filter_shape,
                                                  strides, padding, dilations,
                                                  c_out, batch_shape,
                                                  event_shape)

                n = ps.maximum(0, ps.rank(x) - 3)
                paddings = ps.pad(padding_vals,
                                  paddings=[[n, 1], [0, 0]],
                                  constant_values=0)
                x_pad = tf.pad(x, paddings=paddings, constant_values=0)

                ex_h = xh + tf.reduce_sum(padding_vals[0]) - sub_fh + 1
                ex_w = xw + tf.reduce_sum(padding_vals[1]) - sub_fw + 1

                def loop_body(i, outputs):
                    subkernel_ind = kernels_ind.read(i)
                    fh_, fw_ = ps.unstack(ps.shape(subkernel_ind), num=2)
                    eh = ex_h + fh_ - 1
                    ew = ex_w + fw_ - 1

                    subkernel_ind = ps.reshape(ps.reshape(
                        subkernel_ind * c_in, shape=[-1])[:, tf.newaxis] +
                                               ps.range(c_in),
                                               shape=[-1])

                    k = tf.gather(kernel, subkernel_ind, axis=-2)
                    ind, shape = im2row_index([eh, ew, c_in],
                                              block_shape=(fh_, fw_),
                                              slice_step=(1, 1),
                                              dilations=dilations)
                    x_i = x_pad[..., :eh, :ew, :]
                    x_i_shape = ps.shape(x_i)
                    flat_shape = ps.pad(x_i_shape[:-3],
                                        paddings=[[0, 1]],
                                        constant_values=-1)
                    flat_x = tf.reshape(x_i, flat_shape)
                    x_ = tf.gather(flat_x, ind, axis=-1)
                    im_x = tf.reshape(
                        x_, ps.concat([x_i_shape[:-3], shape], axis=0))
                    outputs = outputs.write(
                        i,
                        tf.matmul(
                            im_x,
                            tf.reshape(
                                k,
                                ps.concat([
                                    kernel_batch, [1, fh_ * fw_ * c_in, c_out]
                                ],
                                          axis=0))))
                    return i + 1, outputs

                outputs = tf.TensorArray(dtype=input_dtype, size=sh * sw)

                _, outputs = tf.while_loop(lambda i, _: i < sh * sw, loop_body,
                                           [0, outputs])

                y = outputs.concat()

                m = tf.reduce_prod(ps.shape(y)[:-3])
                y_ = tf.reshape(y,
                                shape=ps.concat([[m], ps.shape(y)[-3:]],
                                                axis=0))
                y2 = tf.batch_to_space(y_,
                                       strides,
                                       crops=tf.zeros([2, 2], dtype=tf.int64))
                broadcast_batch_shape = ps.broadcast_shape(
                    batch_shape, kernel_batch)
                y2 = tf.reshape(
                    y2,
                    ps.concat([broadcast_batch_shape,
                               ps.shape(y2)[-3:]],
                              axis=0))

                out_height = _deconv_output_length(xh,
                                                   filter_size=fh,
                                                   padding=padding,
                                                   output_padding=None,
                                                   stride=sh,
                                                   dilation=dh)
                out_width = _deconv_output_length(xw,
                                                  filter_size=fw,
                                                  padding=padding,
                                                  output_padding=None,
                                                  stride=sw,
                                                  dilation=dw)

                return y2[..., truncate_top:truncate_top + out_height,
                          truncate_left:truncate_left + out_width, :]

        return op
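The sub-kernel indices and per-subkernel outputs above are accumulated with the write-into-a-`tf.TensorArray`-inside-`tf.while_loop` pattern; a minimal self-contained sketch of that pattern (names are illustrative) follows.

import tensorflow as tf

def squares(n):
    ta = tf.TensorArray(tf.int32, size=n)
    def body(i, ta):
        return [i + 1, ta.write(i, i * i)]
    _, ta = tf.while_loop(lambda i, _: i < n, body, [0, ta])
    return ta.stack()

squares(5)  # [0, 1, 4, 9, 16]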
Example #8
    def _solve(
        self,
        ode_fn,
        initial_time,
        initial_state,
        solution_times,
        jacobian_fn=None,
        jacobian_sparsity=None,
        batch_ndims=None,
        previous_solver_internal_state=None,
    ):
        # Static assertions
        del jacobian_fn, jacobian_sparsity  # not used by DormandPrince
        if batch_ndims is not None and batch_ndims != 0:
            raise NotImplementedError(
                'For homogeneous batching use `batch_ndims=0`.')
        solution_times_by_solver = isinstance(solution_times,
                                              base.ChosenBySolver)

        with tf.name_scope(self._name):
            # (2) Convert to tensors, determine dtypes.
            get_dtype = lambda x: x.dtype
            error_if_wrong_dtype = functools.partial(
                util.error_if_not_real_or_complex, identifier='initial_state')

            initial_state = tf.nest.map_structure(tf.convert_to_tensor,
                                                  initial_state)
            tf.nest.map_structure(error_if_wrong_dtype, initial_state)

            state_dtypes = tf.nest.map_structure(get_dtype, initial_state)
            common_state_dtype = dtype_util.common_dtype(initial_state)
            real_dtype = dtype_util.real_dtype(common_state_dtype)

            initial_time = tf.cast(initial_time, real_dtype)
            max_num_steps = self._max_num_steps
            max_ode_fn_evals = self._max_num_steps
            if max_num_steps is not None:
                max_num_steps = tf.convert_to_tensor(max_num_steps,
                                                     dtype=tf.int32)
                max_ode_fn_evals = max_num_steps * self.ODE_FN_EVALS_PER_STEP
            step_size = tf.convert_to_tensor(self._first_step_size,
                                             dtype=real_dtype)
            rtol = tf.convert_to_tensor(tf.cast(self._rtol, real_dtype))
            atol = tf.convert_to_tensor(tf.cast(self._atol, real_dtype))
            safety = tf.convert_to_tensor(self._safety_factor,
                                          dtype=real_dtype)
            # Use i(d)factor notation for increasing and decreasing factors.
            ifactor, dfactor = self._max_step_size_factor, self._min_step_size_factor
            ifactor = tf.convert_to_tensor(ifactor, dtype=real_dtype)
            dfactor = tf.convert_to_tensor(dfactor, dtype=real_dtype)

            solver_internal_state = previous_solver_internal_state
            if solver_internal_state is None:
                initial_derivative = ode_fn(initial_time, initial_state)
                initial_derivative = tf.nest.map_structure(
                    tf.convert_to_tensor, initial_derivative)
                solver_internal_state = _RungeKuttaSolverInternalState(
                    current_state=initial_state,
                    current_derivative=initial_derivative,
                    last_step_start=initial_time,
                    current_time=initial_time,
                    step_size=step_size,
                    interpolating_coefficients=[initial_state] * self.ORDER)

            num_solution_times = 0
            if solution_times_by_solver:
                final_time = tf.cast(solution_times.final_time, real_dtype)
                times_array = tf.TensorArray(real_dtype,
                                             size=num_solution_times,
                                             dynamic_size=True,
                                             element_shape=tf.TensorShape([]))
            else:
                solution_times = tf.cast(solution_times, real_dtype)
                util.error_if_not_vector(solution_times, 'solution_times')
                num_solution_times = tf.size(solution_times)
                times_array = tf.TensorArray(
                    real_dtype,
                    size=num_solution_times,
                    dynamic_size=False,
                    element_shape=[]).unstack(solution_times)

            solutions_arrays = [
                tf.TensorArray(dtype=component_dtype,
                               size=num_solution_times,
                               dynamic_size=solution_times_by_solver)
                for component_dtype in tf.nest.flatten(state_dtypes)
            ]
            solutions_arrays = tf.nest.pack_sequence_as(
                initial_state, solutions_arrays)

            rk_step = functools.partial(self._step,
                                        max_ode_fn_evals=max_ode_fn_evals,
                                        ode_fn=ode_fn,
                                        atol=atol,
                                        rtol=rtol,
                                        safety=safety,
                                        ifactor=ifactor,
                                        dfactor=dfactor)
            advance_to_solution_time = functools.partial(
                _advance_to_solution_time,
                times_array=solution_times,
                step_fn=rk_step,
                validate_args=self._validate_args)

            assert_ops = self._assert_ops(
                ode_fn=ode_fn,
                initial_time=initial_time,
                initial_state=initial_state,
                solution_times=solution_times,
                previous_solver_state=previous_solver_internal_state,
                rtol=rtol,
                atol=atol,
                first_step_size=step_size,
                safety_factor=safety,
                min_step_size_factor=ifactor,
                max_step_size_factor=dfactor,
                max_num_steps=max_num_steps,
                solution_times_by_solver=solution_times_by_solver)
            with tf.control_dependencies(assert_ops):
                ode_evals_by_now = 1 if self._validate_args else 0
                ode_evals_by_now += 1 if solver_internal_state is None else 0
                diagnostics = _DopriDiagnostics(
                    num_ode_fn_evaluations=ode_evals_by_now,
                    num_jacobian_evaluations=0,
                    num_matrix_factorizations=0,
                    status=0)

                if solution_times_by_solver:
                    r = _dense_solutions_to_final_time(
                        final_time=final_time,
                        solver_state=solver_internal_state,
                        diagnostics=diagnostics,
                        step_fn=rk_step,
                        ode_fn=ode_fn,
                        times_array=times_array,
                        solutions_arrays=solutions_arrays,
                        validate_args=self._validate_args)
                    solver_internal_state, diagnostics, times_array, solutions_arrays = r
                else:

                    def iterate_cond(time_id, *_):
                        return time_id < num_solution_times

                    [_, solver_internal_state, diagnostics, solutions_arrays
                     ] = tf.while_loop(iterate_cond,
                                       advance_to_solution_time, [
                                           0, solver_internal_state,
                                           diagnostics, solutions_arrays
                                       ],
                                       back_prop=False)

                times = times_array.stack()
                stack_components = lambda x: x.stack()
                states = tf.nest.map_structure(stack_components,
                                               solutions_arrays)
                return base.Results(
                    times=times,
                    states=states,
                    diagnostics=diagnostics,
                    solver_internal_state=solver_internal_state)
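A hedged usage sketch of the solver whose `_solve` method appears above, via the public `tfp.math.ode.DormandPrince` API; the toy ODE (exponential decay) is illustrative.

import tensorflow as tf
import tensorflow_probability as tfp

ode_fn = lambda t, y: -y  # dy/dt = -y, so y(t) = exp(-t) * y(0)
results = tfp.math.ode.DormandPrince().solve(
    ode_fn,
    initial_time=0.,
    initial_state=tf.constant(1., dtype=tf.float64),
    solution_times=[0.5, 1.0, 2.0])
# results.states approximates exp(-[0.5, 1.0, 2.0])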
Example #9
def _dense_solutions_to_final_time(final_time,
                                   solver_state,
                                   diagnostics,
                                   step_fn,
                                   ode_fn,
                                   times_array,
                                   solutions_arrays,
                                   validate_args=False):
    """Integrates `solver_state` to `final_time`.

  Performs integration of the `solver_state` to `final_time` while saving
  solutions at all intermediate time steps. This corresponds to the expected
  behavior of the `ChosenBySolver` option. The solution at `final_time` is obtained
  by interpolation and is set as a final state of the solver.

  Args:
    final_time: Floating `Tensor` representing the final time of integration.
    solver_state: `_DopriSolverInternalState` - initial solver state.
    diagnostics: `_DopriDiagnostics` - info on the current `_solve` call.
    step_fn: Partial `Dopri._step` method that performs a single step updating
      the `solver_state` and `diagnostics`.
    ode_fn: Callable(t, y) -> dy_dt.
    times_array: `TensorArray` where time values are recorded.
    solutions_arrays: `TensorArray`s where solutions are recorded.
    validate_args: Python `bool` indicating whether to validate inputs.
      Default value: False.

  Returns:
    solver_state: `_RungeKuttaSolverInternalState` holding final solver state.
    diagnostics: `_DopriDiagnostics` holding diagnostic values.
    times_array: `TensorArray` with recorded solution times.
    solutions_arrays: `TensorArray`s with solution values at the times recorded
      in `times_array`.
  """
    def step_and_record(solver_state, diagnostics, solutions_arrays,
                        times_array):
        y = solver_state.current_state
        time_id = times_array.size()
        solutions_arrays = _write_solution_components(y, solutions_arrays,
                                                      time_id)
        times_array = times_array.write(time_id, solver_state.current_time)
        solver_state, diagnostics = step_fn(solver_state, diagnostics)
        return (solver_state, diagnostics, solutions_arrays, times_array)

    def step_cond(solver_internal_state, *_):
        return solver_internal_state.current_time <= final_time

    [solver_state, diagnostics, solutions_arrays, times_array] = tf.while_loop(
        step_cond,
        step_and_record,
        [solver_state, diagnostics, solutions_arrays, times_array],
        back_prop=False)
    # Interpolate the solution at the final time, update the state and write results.
    y, coefficients = _interpolate_solution_at(final_time, solver_state,
                                               validate_args)
    dy_dt = ode_fn(final_time, y)
    dy_dt = tf.nest.map_structure(tf.convert_to_tensor, dy_dt)

    time_id = times_array.size()
    times_array = times_array.write(time_id, final_time)
    solutions_arrays = _write_solution_components(y, solutions_arrays, time_id)
    solver_state = _RungeKuttaSolverInternalState(
        current_state=y,
        current_derivative=dy_dt,
        last_step_start=solver_state.last_step_start,
        current_time=final_time,
        step_size=solver_state.step_size,
        interpolating_coefficients=coefficients)
    return solver_state, diagnostics, times_array, solutions_arrays
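# A minimal, self-contained sketch of the same "step and record" pattern used
# in `_dense_solutions_to_final_time`, but with a toy explicit-Euler step in
# place of the private Dopri machinery. Everything below is illustrative.
import tensorflow as tf

def euler_record(final_time=1.0, dt=0.1):
    ode_fn = lambda t, y: -y  # dy/dt = -y

    times_array = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
    states_array = tf.TensorArray(tf.float32, size=0, dynamic_size=True)

    def step_cond(t, y, times_array, states_array):
        return t <= final_time

    def step_and_record(t, y, times_array, states_array):
        # Record the current state, then advance one Euler step.
        i = times_array.size()
        times_array = times_array.write(i, t)
        states_array = states_array.write(i, y)
        return t + dt, y + dt * ode_fn(t, y), times_array, states_array

    _, _, times_array, states_array = tf.while_loop(
        step_cond, step_and_record,
        (tf.constant(0.0), tf.constant(1.0), times_array, states_array))
    return times_array.stack(), states_array.stack()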
Example #10
def iterative_mergesort(y, permutation, name=None):
  """Non-recusive mergesort that counts exchanges.

  Args:
    y: a `Tensor` of shape `[n]` containing values to be sorted.
    permutation: `Tensor` of shape `[n]` with original ordering.
    name: Optional Python `str` name for ops created by this method.
      Default value: `None` (i.e., 'iterative_mergesort').

  Returns:
    exchanges: `int32` scalar counting the number of exchanges required to
      produce a sorted permutation.
    permutation: `tf.int32` `Tensor` containing the ordering of `y` values
      that sorts them.
  """

  with tf.name_scope(name or 'iterative_mergesort'):
    y = tf.convert_to_tensor(y, name='y')
    permutation = tf.convert_to_tensor(
        permutation, name='permutation', dtype=tf.int32)
    shape = permutation.shape
    tensorshape_util.assert_is_compatible_with(y.shape, shape)
    n = ps.size(y)

    def outer_body(k, exchanges, permutation):
      # The outer body progressively merges lists as k grows by powers of 2,
      # tracking the total swaps required in exchanges as the new permutation is
      # built in place.
      y_ordered = tf.gather(y, permutation)

      def middle_body(left, exchanges, permutation):
        # The middle body walks over the sublists of size k, advancing the
        # left edge until the end of the input is reached.
        right = left + k
        end = tf.minimum(right + k, n)

        # See explanation here
        # https://www.geeksforgeeks.org/counting-inversions/.

        def inner_body(i, j, x, np, p):
          # The [left, right) and [right, end) lists are merged in sorted
          # order, with i and j tracking the advance through each range. x
          # records the number of order (bubble-sort equivalent) swaps that
          # happen with each insertion, and np represents the size of the
          # output permutation that's been filled in using the p tensor.
          y_less = y_ordered[i] <= y_ordered[j]
          element = tf.where(y_less, [permutation[i]], [permutation[j]])
          new_p = tf.concat([p[0:np], element, p[np + 1:n]], axis=0)
          tensorshape_util.set_shape(new_p, p.shape)
          return (tf.where(y_less, i + 1, i), tf.where(y_less, j, j + 1),
                  tf.where(y_less, x, x + right - i), np + 1, new_p)

        i_j_x_np_p = (left, right, exchanges, 0, tf.zeros([n], dtype=tf.int32))
        (i, j, exchanges, np, p) = tf.while_loop(
            cond=lambda i, j, x, np, p: tf.math.logical_and(i < right, j < end),
            body=inner_body,
            loop_vars=i_j_x_np_p)
        permutation = tf.concat([
            permutation[0:left], p[0:np], permutation[i:right],
            permutation[j:end], permutation[end:n]
        ],
                                axis=0)
        tensorshape_util.set_shape(permutation, shape)
        return left + 2 * k, exchanges, permutation

      _, exchanges, permutation = tf.while_loop(
          cond=lambda left, exchanges, permutation: left < n - k,
          body=middle_body,
          loop_vars=(0, exchanges, permutation))
      k *= 2
      return k, exchanges, permutation

    _, exchanges, permutation = tf.while_loop(
        cond=lambda k, exchanges, permutation: k < n,
        body=outer_body,
        loop_vars=(1, 0, permutation))
    return exchanges, permutation
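# Hypothetical usage sketch of `iterative_mergesort` above: count the number
# of exchanges (inversions) needed to sort a small array.
import tensorflow as tf

y = tf.constant([3., 1., 2.])
permutation = tf.range(3, dtype=tf.int32)
exchanges, sorted_perm = iterative_mergesort(y, permutation)
# Expect exchanges == 2 (the pairs (3, 1) and (3, 2) are out of order) and
# sorted_perm == [1, 2, 0], so that tf.gather(y, sorted_perm) is non-decreasing.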
Example #11
            def grad_fn(*dresults, **kwargs):
                """Adjoint sensitivity method to compute gradients."""
                dresults = tf.nest.pack_sequence_as(results, dresults)
                dstates = dresults.states
                # The signature grad_fn(*dresults, variables=None) is not valid Python 2
                # so use kwargs instead.
                variables = kwargs.pop('variables', [])
                assert not kwargs  # This assert should never fail.
                # TODO(b/138304303): Support complex types.
                with tf.name_scope('{}Gradients'.format(self._name)):
                    get_dtype = lambda x: x.dtype

                    def error_if_complex(dtype):
                        if dtype.is_complex:
                            raise NotImplementedError(
                                'The adjoint sensitivity method does '
                                'not support complex dtypes.')

                    state_dtypes = tf.nest.map_structure(
                        get_dtype, initial_state)
                    tf.nest.map_structure(error_if_complex, state_dtypes)
                    common_state_dtype = dtype_util.common_dtype(initial_state)
                    real_dtype = dtype_util.real_dtype(common_state_dtype)

                    # We add initial_time to ensure that we know where to stop.
                    result_times = tf.concat(
                        [[tf.cast(initial_time, real_dtype)], results.times],
                        0)
                    num_result_times = tf.size(result_times)

                    # The first two components correspond to the reverse and
                    # adjoint states; the last component is the adjoint state
                    # for the variables.
                    terminal_augmented_state = tuple([
                        rk_util.nest_constant(initial_state, 0.0),
                        rk_util.nest_constant(initial_state, 0.0),
                        tuple(
                            rk_util.nest_constant(variable, 0.0)
                            for variable in variables)
                    ])

                    # The XLA compiler does not compile code which slices/indexes using
                    # integer `Tensor`s. `TensorArray`s are used to get around this.
                    result_time_array = tf.TensorArray(
                        results.times.dtype,
                        clear_after_read=False,
                        size=num_result_times,
                        element_shape=[]).unstack(result_times)

                    # TensorArray shape should not include time dimension, hence shape[1:]
                    result_state_arrays = [
                        tf.TensorArray(  # pylint: disable=g-complex-comprehension
                            dtype=component.dtype,
                            size=num_result_times - 1,
                            element_shape=component.shape[1:]).unstack(
                                component)
                        for component in tf.nest.flatten(results.states)
                    ]
                    result_state_arrays = tf.nest.pack_sequence_as(
                        results.states, result_state_arrays)
                    dresult_state_arrays = [
                        tf.TensorArray(  # pylint: disable=g-complex-comprehension
                            dtype=component.dtype,
                            size=num_result_times - 1,
                            element_shape=component.shape[1:]).unstack(
                                component)
                        for component in tf.nest.flatten(dstates)
                    ]
                    dresult_state_arrays = tf.nest.pack_sequence_as(
                        results.states, dresult_state_arrays)

                    def augmented_ode_fn(backward_time, augmented_state):
                        """Dynamics function for the augmented system.

            Describes a differential equation that evolves the augmented state
            backwards in time to compute gradients using the adjoint method.
            Augmented state consists of 3 components `(state, adjoint_state,
            vars)` all evaluated at time `backward_time`:

            state: represents the solution of user provided `ode_fn`. The
              structure coincides with the `initial_state`.
            adjoint_state: represents the solution of adjoint sensitivity
              differential equation as discussed below. Has the same structure
              and shape as `state`.
            vars: represent the solution of the adjoint equation for variable
              gradients. Represented as a `Tuple(Tensor, ...)` with as many
              tensors as there are `variables`.

            Adjoint sensitivity equation describes the gradient of the solution
            with respect to the value of the solution at previous time t. Its
            dynamics are given by
            d/dt[adj(t)] = -1 * adj(t) @ jacobian(ode_fn(t, z), z)
            Which is computed as:
            d/dt[adj(t)]_i = -1 * sum_j(adj(t)_j * d/dz_i[ode_fn(t, z)_j)]
            d/dt[adj(t)]_i = -1 * d/dz_i[sum_j(no_grad_adj_j * ode_fn(t, z)_j)]
            where in the last line we moved adj(t)_j under derivative by
            removing gradient from it.

            Adjoint equation for the gradient with respect to every
            `tf.Variable` theta follows:
            d/dt[grad_theta(t)] = -1 * adj(t) @ jacobian(ode_fn(t, z), theta)
            = -1 * d/d theta_i[sum_j(no_grad_adj_j * ode_fn(t, z)_j)]

            Args:
              backward_time: Floating `Tensor` representing current time.
              augmented_state: `Tuple(state, adjoint_state, variable_grads)`

            Returns:
              negative_derivatives: Structure of `Tensor`s equal to backwards
                time derivative of the `state` component.
              adjoint_ode: Structure of `Tensor`s equal to backwards time
                derivative of the `adjoint_state` component.
              adjoint_variables_ode: Structure of `Tensor`s equal to backwards
                time derivative of the `vars` component.
            """
                        # The negative sign disappears after the change of variables.
                        # The ODE solver cannot handle the case initial_time > final_time
                        # and hence a change of variables backward_time = -time is used.
                        time = -backward_time
                        state, adjoint_state, _ = augmented_state

                        with tf.GradientTape() as tape:
                            tape.watch(variables)
                            tape.watch(state)
                            derivatives = ode_fn(time, state)
                            adjoint_no_grad = tf.nest.map_structure(
                                tf.stop_gradient, adjoint_state)
                            negative_derivatives = rk_util.weighted_sum(
                                [-1.0], [derivatives])

                            def dot_prod(tensor_a, tensor_b):
                                return tf.reduce_sum(tensor_a * tensor_b)

                            # See docstring for details.
                            adjoint_dot_derivatives = tf.nest.map_structure(
                                dot_prod, adjoint_no_grad, derivatives)
                            adjoint_dot_derivatives = tf.squeeze(
                                tf.add_n(
                                    tf.nest.flatten(adjoint_dot_derivatives)))

                        adjoint_ode, adjoint_variables_ode = tape.gradient(
                            adjoint_dot_derivatives, (state, tuple(variables)),
                            unconnected_gradients=tf.UnconnectedGradients.ZERO)
                        return negative_derivatives, adjoint_ode, adjoint_variables_ode

                    def reverse_to_result_time(n, augmented_state, _):
                        """Integrates the augmented system backwards in time."""
                        lower_bound_of_integration = result_time_array.read(n)
                        upper_bound_of_integration = result_time_array.read(n -
                                                                            1)
                        _, adjoint_state, adjoint_variable_state = augmented_state
                        initial_state = _read_solution_components(
                            result_state_arrays, input_state_structure, n - 1)
                        initial_adjoint = _read_solution_components(
                            dresult_state_arrays, input_state_structure, n - 1)
                        initial_adjoint_state = rk_util.weighted_sum(
                            [1.0, 1.0], [adjoint_state, initial_adjoint])
                        initial_augmented_state = (initial_state,
                                                   initial_adjoint_state,
                                                   adjoint_variable_state)
                        augmented_results = self._solve(
                            ode_fn=augmented_ode_fn,
                            initial_time=-lower_bound_of_integration,
                            initial_state=initial_augmented_state,
                            solution_times=[-upper_bound_of_integration],
                            batch_ndims=batch_ndims)
                        # Results added an extra time dim of size 1, squeeze it.
                        select_result = lambda x: tf.squeeze(x, [0])
                        result_state = augmented_results.states
                        result_state = tf.nest.map_structure(
                            select_result, result_state)
                        status = augmented_results.diagnostics.status
                        return n - 1, result_state, status

                    _, augmented_state, _ = tf.while_loop(
                        lambda n, _, status: (n >= 1) & tf.equal(status, 0),
                        reverse_to_result_time,
                        (num_result_times - 1, terminal_augmented_state, 0),
                    )
                    _, adjoint_state, adjoint_variables = augmented_state
                    return adjoint_state, list(adjoint_variables)
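# Hedged sketch of what the adjoint machinery above provides at the public API
# level (assuming `tfp.math.ode`): differentiating an ODE solution with respect
# to a parameter. The solver's custom gradient integrates the augmented system
# backwards in time, as implemented by `grad_fn` / `augmented_ode_fn`.
import tensorflow as tf
import tensorflow_probability as tfp

a = tf.Variable(0.5)

def ode_fn(t, y):
    return -a * y  # solution: y(t) = y0 * exp(-a * t)

with tf.GradientTape() as tape:
    results = tfp.math.ode.DormandPrince().solve(
        ode_fn,
        initial_time=0.,
        initial_state=tf.constant(1.),
        solution_times=[1.0])
    y_final = results.states[-1]

# Analytically d y(1) / d a = -exp(-a); the tape recovers it via the adjoint
# sensitivity method.
grad = tape.gradient(y_final, a)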
Example #12
def kendalls_tau(y_true, y_pred, name=None):
  """Computes Kendall's Tau for two ordered lists.

  Kendall's Tau measures the correlation between ordinal rankings. This
  implementation is similar to the one used in scipy.stats.kendalltau.
  The provided values may be of any type that is sortable, with the
  argsort indices indicating the true or proposed ordinal sequence.

  Args:
    y_true: a `Tensor` of shape `[n]` containing the true ordinal ranking.
    y_pred: a `Tensor` of shape `[n]` containing the predicted ordering of the
      same N items.
    name: Optional Python `str` name for ops created by this method.
      Default value: `None` (i.e., 'kendalls_tau').

  Returns:
    kendalls_tau: Kendall's Tau, the 1945 tau-b formulation that ignores
      ordering of ties, as a `float32` scalar Tensor.
  """
  with tf.name_scope(name or 'kendalls_tau'):
    in_type = dtype_util.common_dtype([y_true, y_pred], dtype_hint=tf.float32)
    y_true = tf.convert_to_tensor(y_true, name='y_true', dtype=in_type)
    y_pred = tf.convert_to_tensor(y_pred, name='y_pred', dtype=in_type)
    tensorshape_util.assert_is_compatible_with(y_true.shape, y_pred.shape)
    assertions = [
        assert_util.assert_rank(y_true, 1),
        assert_util.assert_greater(
            ps.size(y_true), 1, 'Ordering requires at least 2 elements.')
    ]
    with tf.control_dependencies(assertions):
      lexa = lexicographical_indirect_sort(y_true, y_pred)

    # For notation, see "A Computer Method for Calculating Kendall's Tau with
    # Ungrouped Data" by William R. Knight, Journal of the American Statistical
    # Association, Vol. 61, No. 314, Part 1 (Jun., 1966), pp. 436-439.
    # https://www.jstor.org/stable/2282833

    def jointly_tied_pairs_body(first, t, i):
      not_equal = tf.math.logical_or(
          tf.not_equal(y_true[lexa[first]], y_true[lexa[i]]),
          tf.not_equal(y_pred[lexa[first]], y_pred[lexa[i]]))
      return (tf.where(not_equal, i, first),
              tf.where(not_equal, t + ((i - first) * (i - first - 1)) // 2,
                       t), i + 1)

    n = ps.size0(y_true)
    first, t, _ = tf.while_loop(
        cond=lambda first, t, i: i < n,
        body=jointly_tied_pairs_body,
        loop_vars=(0, 0, 1))
    t += ((n - first) * (n - first - 1)) // 2

    def ties_y_true_body(first, v, i):
      not_equal = tf.not_equal(y_true[lexa[first]], y_true[lexa[i]])
      return (tf.where(not_equal, i, first),
              tf.where(not_equal, v + ((i - first) * (i - first - 1)) // 2,
                       v), i + 1)

    first, v, _ = tf.while_loop(
        cond=lambda first, v, i: i < n,
        body=ties_y_true_body,
        loop_vars=(0, 0, 1))
    v += ((n - first) * (n - first - 1)) // 2

    # count exchanges
    exchanges, newperm = iterative_mergesort(y_pred, lexa)

    def ties_in_y_pred_body(first, u, i):
      not_equal = tf.not_equal(y_pred[newperm[first]], y_pred[newperm[i]])
      return (tf.where(not_equal, i, first),
              tf.where(not_equal, u + ((i - first) * (i - first - 1)) // 2,
                       u), i + 1)

    first, u, _ = tf.while_loop(
        cond=lambda first, u, i: i < n,
        body=ties_in_y_pred_body,
        loop_vars=(0, 0, 1))
    u += ((n - first) * (n - first - 1)) // 2
    n0 = (n * (n - 1)) // 2
    assertions = [
        assert_util.assert_less(v, tf.cast(n0, tf.int32),
                                'All ranks are ties for y_true.'),
        assert_util.assert_less(u, tf.cast(n0, tf.int32),
                                'All ranks are ties for y_pred.')
    ]
    with tf.control_dependencies(assertions):
      return (tf.cast(n0 - (u + v - t), tf.float32) -
              2.0 * tf.cast(exchanges, tf.float32)) / tf.math.sqrt(
                  tf.cast(n0 - v, tf.float32) * tf.cast(n0 - u, tf.float32))
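# Hypothetical usage sketch of `kendalls_tau` above, on two small rankings.
import tensorflow as tf

y_true = tf.constant([1., 2., 3., 4., 5.])
y_pred = tf.constant([1., 3., 2., 4., 5.])  # one adjacent pair swapped
tau = kendalls_tau(y_true, y_pred)
# With n0 = 10 pairs, one discordant pair and no ties, tau == 0.8, matching
# scipy.stats.kendalltau on the same inputs.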
Example #13
def _sample_with_shrinkage(x_initial,
                           target_log_prob,
                           log_slice_heights,
                           step_size,
                           lower_bounds,
                           upper_bounds,
                           seed,
                           name=None):
    """Samples from the slice by applying shrinkage for rejected points.

  Implements the one dimensional slice sampling algorithm of Neal (2003), with
  the doubling algorithm (Neal 2003 P715 Fig. 4), which doubles the size of the
  interval at each iteration, and shrinkage (Neal 2003 P716 Fig. 5), which
  reduces the width of the slice when a selected point is rejected by setting
  the relevant bound to that value. Randomly sampled points are checked for
  two criteria: that they lie within the slice and that they pass the
  acceptability check (Neal 2003 P717 Fig. 6), which tests that the new state
  could have generated the previous one.

  Args:
    x_initial: A tensor of any shape. The initial positions of the chains. This
      function assumes that all the dimensions of `x_initial` are batch
      dimensions (i.e. the event shape is `[]`).
    target_log_prob: Callable accepting a tensor like `x_initial` and returning
      a tensor containing the log density at that point of the same shape.
    log_slice_heights: Tensor of the same shape and dtype as the return value
      of `target_log_prob` when applied to `x_initial`. The log of the height of
      the chosen slice.
    step_size: A tensor of shape and dtype compatible with `x_initial`. The min
      interval size in the doubling algorithm.
    lower_bounds: Tensor of same shape and dtype as `x_initial`. Slice lower
      bounds for each chain.
    upper_bounds: Tensor of same shape and dtype as `x_initial`. Slice upper
      bounds for each chain.
    seed: Tensor seed pair. The random seed.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'find_slice_bounds').

  Returns:
    x_proposed: A tensor of the same shape and dtype as `x_initial`. The next
      proposed state of the chain.
  """
    with tf.name_scope(name or 'sample_with_shrinkage'):
        seed = samplers.sanitize_seed(seed)
        # Keeps track of whether an acceptable sample has been found for the chain.
        found = tf.zeros_like(x_initial, dtype=tf.bool)
        cond = lambda found, *ignored_args: ~tf.reduce_all(found)
        x_next = tf.identity(x_initial)
        x_initial_shape = ps.shape(x_initial)
        x_initial_dtype = dtype_util.base_dtype(x_initial.dtype)

        def _body(found, seed, left, right, x_next):
            """Iterates until every chain has found a suitable next state."""
            proportions_seed, next_seed = samplers.split_seed(seed)
            proportions = samplers.uniform(x_initial_shape,
                                           dtype=x_initial_dtype,
                                           seed=proportions_seed)
            x_proposed = tf.where(~found, left + proportions * (right - left),
                                  x_next)
            accept_res = _test_acceptance(x_initial,
                                          target_log_prob=target_log_prob,
                                          decided=found,
                                          log_slice_heights=log_slice_heights,
                                          x_proposed=x_proposed,
                                          step_size=step_size,
                                          lower_bounds=left,
                                          upper_bounds=right)
            boundary_test = log_slice_heights < target_log_prob(x_proposed)
            can_accept = boundary_test & accept_res
            next_found = found | can_accept
            # Note that it might seem that we are moving the left and right end
            # points even if the point has been accepted (which is contrary to
            # the stated algorithm in Neal). However, this is harmless: the
            # endpoints of chains that have already been accepted are never
            # used again, so it does not matter what we do with them.
            next_left = tf.where(x_proposed < x_initial, x_proposed, left)
            next_right = tf.where(x_proposed >= x_initial, x_proposed, right)
            return (next_found, next_seed, next_left, next_right, x_proposed)

        return tf.while_loop(cond=cond,
                             body=_body,
                             loop_vars=(found, seed, lower_bounds,
                                        upper_bounds, x_next))[-1]
Example #14
def _test_acceptance(x_initial,
                     target_log_prob,
                     decided,
                     log_slice_heights,
                     x_proposed,
                     step_size,
                     lower_bounds,
                     upper_bounds,
                     name=None):
    """Ensures the chosen point does not violate reversibility.

    Implements Fig 6 of Neal 2003 page 717, which checks that the path from the
    existing point to the new point would also have been possible in reverse.
    This is done by checking that the algorithm would not have been terminated
    before reaching the old point.

  Args:
    x_initial: A tensor of any shape and real dtype. The initial positions of
      the chains. This function assumes that all the dimensions of `x_initial`
      are batch dimensions (i.e. the event shape is `[]`).
    target_log_prob: Callable accepting a tensor like `x_initial` and returning
      a tensor containing the log density at that point of the same shape.
    decided: A `tf.bool` tensor of the same shape as `x_initial`. Indicates
      whether the acceptance has already been decided. A point is tested only
      if `decided` for that point is False.
    log_slice_heights: Tensor of the same shape and dtype as the return value
      of `target_log_prob` when applied to `x_initial`. The log of the height of
      the chosen slice.
    x_proposed: A tensor of the same shape and dtype as `x_initial`. The
      proposed points.
    step_size: A tensor of shape and dtype compatible with `x_initial`. The min
      interval size in the doubling algorithm.
    lower_bounds: Tensor of same shape and dtype as `x_initial`. Slice lower
      bounds for each chain.
    upper_bounds: Tensor of same shape and dtype as `x_initial`. Slice upper
      bounds for each chain.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'find_slice_bounds').

  Returns:
    acceptable: A boolean tensor of same shape as `x_initial` indicating whether
      the proposed points are acceptable for reversibility or not.
  """
    with tf.name_scope(name or 'test_acceptance'):
        d = tf.zeros_like(x_initial, dtype=tf.bool)

        # Keeps track of points for which the loop has "effectively terminated".
        # Termination occurs when either the interval width has shrunk to the
        # minimum value (step_size) or the point has already been rejected.
        def cond(_, decided, *ignored_args):  # pylint: disable=unused-argument
            # Continue until all the points have been decided.
            return ~tf.reduce_all(decided)

        acceptable = tf.ones_like(x_initial, dtype=tf.bool)

        def body(acceptable, decided, left, right, d):
            """Checks reversibility as described on P717 of Neal 2003."""
            midpoint = (left + right) / 2
            divided = (((x_initial < midpoint) & (x_proposed >= midpoint)) |
                       ((x_proposed < midpoint) & (x_initial >= midpoint)))
            next_d = d | divided
            next_right = tf.where(x_proposed < midpoint, midpoint, right)
            next_left = tf.where(x_proposed >= midpoint, midpoint, left)
            left_test = (log_slice_heights >= target_log_prob(next_left))
            right_test = (log_slice_heights >= target_log_prob(next_right))
            unacceptable = next_d & left_test & right_test
            # Logic here: For points which have not already been decided,
            # and are unacceptable, set acceptable to False. For others, let them
            # be as they were.
            now_decided = ~decided & unacceptable
            next_acceptable = tf.where(now_decided, ~unacceptable, acceptable)
            # Decided if (a) was already decided, or
            # (b) the new width is less than 1.1 step_size, or
            # (c) was marked unacceptable.
            next_decided = (decided |
                            (next_right - next_left <= 1.1 * step_size)
                            | now_decided)
            return (next_acceptable, next_decided, next_left, next_right,
                    next_d)

        return tf.while_loop(cond=cond,
                             body=body,
                             loop_vars=(acceptable, decided, lower_bounds,
                                        upper_bounds, d))[0]
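# Hedged sketch (separate from the private helpers above): the public
# `tfp.mcmc.SliceSampler` kernel that this shrinkage / acceptance-test
# machinery supports internally. Assumes the standard tfp.mcmc API; the target
# and all constants below are illustrative only.
import tensorflow as tf
import tensorflow_probability as tfp

kernel = tfp.mcmc.SliceSampler(
    target_log_prob_fn=lambda x: -0.5 * tf.square(x),  # unnormalized N(0, 1)
    step_size=1.0,
    max_doublings=5)
samples = tfp.mcmc.sample_chain(
    num_results=200,
    num_burnin_steps=50,
    current_state=tf.zeros([]),
    kernel=kernel,
    trace_fn=None,
    seed=[0, 1])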
Example #15
def fn():
    # Minimal example: a tf.while_loop that increments a counter until it
    # reaches 10000. `parallel_iterations` only bounds scheduling parallelism;
    # it does not change the result of this sequential loop.
    x = np.asarray(0)
    c = lambda x: x < 10000
    b = lambda x: [x + 1]
    return tf.while_loop(c, b, [x], parallel_iterations=20)
Example #16
def sample_annealed_importance_chain(num_steps,
                                     proposal_log_prob_fn,
                                     target_log_prob_fn,
                                     current_state,
                                     make_kernel_fn,
                                     parallel_iterations=10,
                                     seed=None,
                                     name=None):
    """Runs annealed importance sampling (AIS) to estimate normalizing constants.

  This function uses an MCMC transition operator (e.g., Hamiltonian Monte Carlo)
  to sample from a series of distributions that slowly interpolates between
  an initial 'proposal' distribution:

  `exp(proposal_log_prob_fn(x) - proposal_log_normalizer)`

  and the target distribution:

  `exp(target_log_prob_fn(x) - target_log_normalizer)`,

  accumulating importance weights along the way. The product of these
  importance weights gives an unbiased estimate of the ratio of the
  normalizing constants of the initial distribution and the target
  distribution:

  `E[exp(ais_weights)] = exp(target_log_normalizer - proposal_log_normalizer)`.

  Note: When running in graph mode, `proposal_log_prob_fn` and
  `target_log_prob_fn` are called exactly three times (although this may be
  reduced to two times in the future).

  Args:
    num_steps: Integer number of Markov chain updates to run. More
      iterations means more expense, but smoother annealing between q
      and p, which in turn means exponentially lower variance for the
      normalizing constant estimator.
    proposal_log_prob_fn: Python callable that returns the log density of the
      initial distribution.
    target_log_prob_fn: Python callable which takes an argument like
      `current_state` (or `*current_state` if it's a list) and returns its
      (possibly unnormalized) log-density under the target distribution.
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s). The first `r` dimensions index
      independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`.
    make_kernel_fn: Python `callable` which returns a `TransitionKernel`-like
      object. Must take one argument representing the `TransitionKernel`'s
      `target_log_prob_fn`. The `target_log_prob_fn` argument represents the
      `TransitionKernel`'s target log distribution.  Note:
      `sample_annealed_importance_chain` creates a new `target_log_prob_fn`
      which is an interpolation between the supplied `target_log_prob_fn` and
      `proposal_log_prob_fn`; it is this interpolated function which is used as
      an argument to `make_kernel_fn`.
    parallel_iterations: The number of iterations allowed to run in parallel.
      It must be a positive integer. See `tf.while_loop` for more details.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'sample_annealed_importance_chain').

  Returns:
    next_state: `Tensor` or Python list of `Tensor`s representing the
      state(s) of the Markov chain(s) at the final iteration. Has same shape as
      input `current_state`.
    ais_weights: Tensor with the estimated weight(s). Has shape matching
      `target_log_prob_fn(current_state)`.
    kernel_results: `collections.namedtuple` of internal calculations used to
      advance the chain.

  #### Examples

  ##### Estimate the normalizing constant of a log-gamma distribution.

  ```python
  tfd = tfp.distributions

  # Run 100 AIS chains in parallel
  num_chains = 100
  dims = 20
  dtype = np.float32

  proposal = tfd.MultivariateNormalDiag(
     loc=tf.zeros([dims], dtype=dtype))

  target = tfd.TransformedDistribution(
    distribution=tfd.Sample(
        tfd.Gamma(concentration=dtype(2), rate=dtype(3)),
        sample_shape=[dims]),
    bijector=tfp.bijectors.Invert(tfp.bijectors.Exp()))

  chains_state, ais_weights, kernels_results = (
      tfp.mcmc.sample_annealed_importance_chain(
          num_steps=1000,
          proposal_log_prob_fn=proposal.log_prob,
          target_log_prob_fn=target.log_prob,
          current_state=proposal.sample(num_chains),
          make_kernel_fn=lambda tlp_fn: tfp.mcmc.HamiltonianMonteCarlo(
            target_log_prob_fn=tlp_fn,
            step_size=0.2,
            num_leapfrog_steps=2)))

  log_estimated_normalizer = (tf.reduce_logsumexp(ais_weights)
                              - np.log(num_chains))
  log_true_normalizer = tf.math.lgamma(2.) - 2. * tf.math.log(3.)
  ```

  ##### Estimate marginal likelihood of a Bayesian regression model.

  ```python
  tfd = tfp.distributions

  def make_prior(dims, dtype):
    return tfd.MultivariateNormalDiag(
        loc=tf.zeros(dims, dtype))

  def make_likelihood(weights, x):
    return tfd.MultivariateNormalDiag(
        loc=tf.tensordot(weights, x, axes=[[0], [-1]]))

  # Run 100 AIS chains in parallel
  num_chains = 100
  dims = 10
  dtype = np.float32

  # Make training data.
  x = np.random.randn(num_chains, dims).astype(dtype)
  true_weights = np.random.randn(dims).astype(dtype)
  y = np.dot(x, true_weights) + np.random.randn(num_chains)

  # Setup model.
  prior = make_prior(dims, dtype)
  def target_log_prob_fn(weights):
    return prior.log_prob(weights) + make_likelihood(weights, x).log_prob(y)

  proposal = tfd.MultivariateNormalDiag(
      loc=tf.zeros(dims, dtype))

  weight_samples, ais_weights, kernel_results = (
      tfp.mcmc.sample_annealed_importance_chain(
        num_steps=1000,
        proposal_log_prob_fn=proposal.log_prob,
        target_log_prob_fn=target_log_prob_fn,
        current_state=tf.zeros([num_chains, dims], dtype),
        make_kernel_fn=lambda tlp_fn: tfp.mcmc.HamiltonianMonteCarlo(
          target_log_prob_fn=tlp_fn,
          step_size=0.1,
          num_leapfrog_steps=2)))
  log_normalizer_estimate = (tf.reduce_logsumexp(ais_weights)
                             - np.log(num_chains))
  ```

  """
    is_seeded = seed is not None
    seed = samplers.sanitize_seed(seed, salt='mcmc.sample_ais_chain')

    with tf.name_scope(name or 'sample_annealed_importance_chain'):
        num_steps = tf.convert_to_tensor(value=num_steps,
                                         dtype=tf.int32,
                                         name='num_steps')
        if mcmc_util.is_list_like(current_state):
            current_state = [
                tf.convert_to_tensor(s, name='current_state')
                for s in current_state
            ]
        else:
            current_state = tf.convert_to_tensor(value=current_state,
                                                 name='current_state')

        def _make_convex_combined_log_prob_fn(iter_):
            def _fn(*args):
                p = tf.identity(proposal_log_prob_fn(*args),
                                name='proposal_log_prob')
                t = tf.identity(target_log_prob_fn(*args),
                                name='target_log_prob')
                dtype = dtype_util.base_dtype(p.dtype)
                beta = tf.cast(iter_ + 1, dtype) / tf.cast(num_steps, dtype)
                return tf.identity(beta * t + (1. - beta) * p,
                                   name='convex_combined_log_prob')

            return _fn

        def _loop_body(iter_, seed, ais_weights, current_state,
                       kernel_results):
            """Closure which implements `tf.while_loop` body."""
            iter_seed, next_seed = samplers.split_seed(
                seed,
                salt='ais_chain.seeded_one_step') if is_seeded else (seed,
                                                                     seed)

            x = (current_state
                 if mcmc_util.is_list_like(current_state) else [current_state])
            proposal_log_prob = proposal_log_prob_fn(*x)
            target_log_prob = target_log_prob_fn(*x)
            ais_weights += ((target_log_prob - proposal_log_prob) /
                            tf.cast(num_steps, ais_weights.dtype))
            kernel = make_kernel_fn(_make_convex_combined_log_prob_fn(iter_))
            # TODO(b/147676843): Should we warn if the kernel is not calibrated?
            one_step_kwargs = dict(seed=iter_seed) if is_seeded else {}
            next_state, inner_results = kernel.one_step(
                current_state, kernel_results.inner_results, **one_step_kwargs)
            kernel_results = AISResults(
                proposal_log_prob=proposal_log_prob,
                target_log_prob=target_log_prob,
                inner_results=inner_results,
            )
            return [
                iter_ + 1, next_seed, ais_weights, next_state, kernel_results
            ]

        def _bootstrap_results(init_state):
            """Creates first version of `previous_kernel_results`."""
            kernel = make_kernel_fn(_make_convex_combined_log_prob_fn(iter_=0))
            inner_results = kernel.bootstrap_results(init_state)
            mh_results = _find_inner_mh_results(inner_results)

            convex_combined_log_prob = mh_results.accepted_results.target_log_prob
            dtype = dtype_util.as_numpy_dtype(convex_combined_log_prob.dtype)
            shape = tf.shape(convex_combined_log_prob)
            proposal_log_prob = tf.fill(shape,
                                        dtype(np.nan),
                                        name='bootstrap_proposal_log_prob')
            target_log_prob = tf.fill(shape,
                                      dtype(np.nan),
                                      name='bootstrap_target_log_prob')

            return AISResults(
                proposal_log_prob=proposal_log_prob,
                target_log_prob=target_log_prob,
                inner_results=inner_results,
            )

        previous_kernel_results = _bootstrap_results(current_state)
        inner_results = previous_kernel_results.inner_results
        mh_results = _find_inner_mh_results(inner_results)

        ais_weights = tf.zeros(
            shape=tf.broadcast_dynamic_shape(
                tf.shape(mh_results.proposed_results.target_log_prob),
                tf.shape(mh_results.accepted_results.target_log_prob)),
            dtype=mh_results.proposed_results.target_log_prob.dtype)

        [_, _, ais_weights, current_state, kernel_results] = tf.while_loop(
            cond=lambda iter_, *args: iter_ < num_steps,
            body=_loop_body,
            loop_vars=[
                np.int32(0),  # iter_
                seed,
                ais_weights,
                current_state,
                previous_kernel_results,
            ],
            parallel_iterations=parallel_iterations)

        return [current_state, ais_weights, kernel_results]
Example #17
def option_price_binomial(*,
                          volatilities,
                          strikes,
                          expiries,
                          spots,
                          discount_rates=None,
                          dividend_rates=None,
                          is_call_options=None,
                          is_american=None,
                          num_steps=100,
                          dtype=None,
                          name=None):
    """Computes the BS price for a batch of European or American options.

  Uses the Cox-Ross-Rubinstein version of the binomial tree method to compute
  the price of American or European options. Supports batching of the options
  and allows mixing of European and American style exercises in a batch.
  For more information about the binomial tree method and the
  Cox-Ross-Rubinstein method in particular see the references below.

  #### Example

  ```python
  # Prices 5 options with a mix of Call/Put, American/European features
  # in a single batch.
  dtype = np.float64
  spots = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=dtype)
  strikes = np.array([3.0, 3.0, 3.0, 3.0, 3.0], dtype=dtype)
  volatilities = np.array([0.1, 0.22, 0.32, 0.01, 0.4], dtype=dtype)
  is_call_options = np.array([True, True, False, False, False])
  is_american = np.array([False, True, True, False, True])
  discount_rates = np.array(0.035, dtype=dtype)
  dividend_rates = np.array([0.02, 0.0, 0.07, 0.01, 0.0], dtype=dtype)
  expiries = np.array(1.0, dtype=dtype)

  prices = option_price_binomial(
      volatilities=volatilities,
      strikes=strikes,
      expiries=expiries,
      spots=spots,
      discount_rates=discount_rates,
      dividend_rates=dividend_rates,
      is_call_options=is_call_options,
      is_american=is_american,
      dtype=dtype)
  # Prints [0., 0.0098847, 0.41299509, 0., 0.06046989]
  ```

  #### References

  [1] Hull, John C., Options, Futures and Other Derivatives. Pearson, 2018.
  [2] Wikipedia contributors. Binomial Options Pricing Model. Available at:
    https://en.wikipedia.org/wiki/Binomial_options_pricing_model

  Args:
    volatilities: Real `Tensor` of any shape and dtype. The volatilities to
      expiry of the options to price.
    strikes: A real `Tensor` of the same dtype and compatible shape as
      `volatilities`. The strikes of the options to be priced.
    expiries: A real `Tensor` of same dtype and compatible shape as
      `volatilities`. The expiry of each option. The units should be such that
      `expiry * volatility**2` is dimensionless.
    spots: A real `Tensor` of any shape that broadcasts to the shape of the
      `volatilities`. The current spot price of the underlying.
    discount_rates: An optional real `Tensor` of same dtype as the
      `volatilities`. The risk free discount rate. If None the rate is assumed
      to be 0.
      Default value: None, equivalent to discount rates = 0.
    dividend_rates: An optional real `Tensor` of same dtype as the
      `volatilities`. The dividend rate. If None the rate is assumed to be 0.
      Default value: None, equivalent to dividend rates = 0.
    is_call_options: A boolean `Tensor` of a shape compatible with
      `volatilities`. Indicates whether the option is a call (if True) or a put
      (if False). If not supplied, call options are assumed.
      Default value: None, equivalent to is_call_options = True.
    is_american: A boolean `Tensor` of a shape compatible with `volatilities`.
      Indicates whether the option exercise style is American (if True) or
      European (if False). If not supplied, European style exercise is assumed.
      Default value: None, equivalent to is_american = False.
    num_steps: A positive scalar int32 `Tensor`. The size of the time
      discretization to use.
      Default value: 100.
    dtype: Optional `tf.DType`. If supplied, the dtype to be used for conversion
      of any supplied non-`Tensor` arguments to `Tensor`.
      Default value: None which maps to the default dtype inferred by TensorFlow
        (float32).
    name: str. The name for the ops created by this function.
      Default value: None which is mapped to the default name `option_price`.

  Returns:
    A `Tensor` of the same shape as the inferred batch shape of the input data.
    The Black Scholes price of the options computed on a binomial tree.
  """
    with tf.name_scope(name or 'crr_option_price'):
        strikes = tf.convert_to_tensor(strikes, dtype=dtype, name='strikes')
        dtype = strikes.dtype
        volatilities = tf.convert_to_tensor(volatilities,
                                            dtype=dtype,
                                            name='volatilities')
        expiries = tf.convert_to_tensor(expiries, dtype=dtype, name='expiries')
        spots = tf.convert_to_tensor(spots, dtype=dtype, name='spots')

        if discount_rates is None:
            discount_rates = tf.zeros_like(volatilities)
        else:
            discount_rates = tf.convert_to_tensor(discount_rates,
                                                  dtype=dtype,
                                                  name='discount_rates')
        if dividend_rates is None:
            dividend_rates = tf.zeros_like(volatilities)
        else:
            dividend_rates = tf.convert_to_tensor(dividend_rates,
                                                  dtype=dtype,
                                                  name='dividend_rates')
        if is_call_options is None:
            is_call_options = tf.ones_like(volatilities,
                                           dtype=tf.bool,
                                           name='is_call_options')
        else:
            is_call_options = tf.convert_to_tensor(is_call_options,
                                                   dtype=tf.bool,
                                                   name='is_call_options')
        if is_american is None:
            is_american = tf.zeros_like(volatilities,
                                        dtype=tf.bool,
                                        name='is_american')
        else:
            is_american = tf.convert_to_tensor(is_american,
                                               dtype=tf.bool,
                                               name='is_american')

        num_steps = tf.cast(num_steps, dtype=dtype)
        dt = expiries / num_steps

        # CRR choices for the up and down move multipliers
        ln_up = volatilities * tf.math.sqrt(dt)
        ln_dn = -ln_up

        # Prepares the spot grid.
        grid_idx = tf.range(num_steps + 1)
        # Stores the grid as shape [input_batch, N + 1] where N = num_steps.
        log_spot_grid_1 = tf.expand_dims(tf.math.log(spots) +
                                         ln_up * num_steps,
                                         axis=-1)
        log_spot_grid_2 = tf.expand_dims(ln_dn - ln_up, axis=-1) * grid_idx
        log_spot_grid = log_spot_grid_1 + log_spot_grid_2

        # The new dimension is added to ensure that the batch shape is at the front.
        payoff_fn = _get_payoff_fn(tf.expand_dims(strikes, axis=-1),
                                   tf.expand_dims(is_call_options, axis=-1))
        value_mod_fn = _get_value_modifier(
            tf.expand_dims(is_american, axis=-1), payoff_fn)

        # Shape [batch shape, num time steps + 1]
        values_grid = payoff_fn(tf.math.exp(log_spot_grid))

        p_up = tf.math.exp((discount_rates - dividend_rates) * dt + ln_up) - 1
        p_up /= tf.math.exp(2 * ln_up) - 1
        p_up = tf.expand_dims(p_up, axis=-1)
        p_dn = 1 - p_up
        discount_factors = tf.expand_dims(tf.math.exp(-discount_rates * dt),
                                          axis=-1)
        ln_up = tf.expand_dims(ln_up, axis=-1)

        def one_step_back(current_values, current_log_spot_grid):
            next_values = (current_values[..., 1:] * p_dn +
                           current_values[..., :-1] * p_up)
            next_log_spot_grid = current_log_spot_grid[..., :-1] - ln_up
            next_values = value_mod_fn(next_values,
                                       tf.math.exp(next_log_spot_grid))
            return discount_factors * next_values, next_log_spot_grid

        def should_continue(current_values, current_log_spot_grid):
            del current_values, current_log_spot_grid
            return True

        batch_shape = values_grid.shape[:-1]
        pv, _ = tf.while_loop(
            should_continue,
            one_step_back, (values_grid, log_spot_grid),
            maximum_iterations=tf.cast(num_steps, dtype=tf.int32),
            shape_invariants=(tf.TensorShape(batch_shape + [None]),
                              tf.TensorShape(batch_shape + [None])))
        return tf.squeeze(pv, axis=-1)
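# A hand-worked single-step CRR example (plain NumPy, independent of the
# function above) to make the backward-induction recursion in `one_step_back`
# concrete. All numbers are illustrative.
import numpy as np

sigma, r, expiry, strike, spot = 0.2, 0.05, 1.0, 1.0, 1.0
dt = expiry / 1  # a one-step tree
up, dn = np.exp(sigma * np.sqrt(dt)), np.exp(-sigma * np.sqrt(dt))
p_up = (np.exp(r * dt) - dn) / (up - dn)  # risk-neutral up-move probability
call_up = max(spot * up - strike, 0.0)
call_dn = max(spot * dn - strike, 0.0)
price = np.exp(-r * dt) * (p_up * call_up + (1.0 - p_up) * call_dn)
# `price` is the discounted risk-neutral expectation over one step; the
# function above applies this contraction `num_steps` times.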
Example #18
def _solve(
    time_direction_fn,
    start_time,
    end_time,
    coord_grid,
    values_grid,
    num_steps=None,
    start_step_count=0,
    time_step=None,
    one_step_fn=None,
    boundary_conditions=None,
    values_transform_fn=None,
    second_order_coeff_fn=None,
    first_order_coeff_fn=None,
    zeroth_order_coeff_fn=None,
    inner_second_order_coeff_fn=None,
    inner_first_order_coeff_fn=None,
    maximum_steps=None,
    swap_memory=True,
    name=None):
  """Common code for solve_backward and solve_forward."""
  if (num_steps is None) == (time_step is None):
    raise ValueError('Exactly one of num_steps or time_step'
                     ' should be supplied.')
  coord_grid = [
      tf.convert_to_tensor(dim_grid, dtype=values_grid.dtype)
      for dim_grid in coord_grid
  ]

  n_dims = len(coord_grid)
  if one_step_fn is None:
    if n_dims == 1:
      one_step_fn = oscillation_damped_crank_nicolson_step()
    else:
      one_step_fn = douglas_adi_step(theta=0.5)

  if boundary_conditions is None:

    def zero_dirichlet(t, grid):
      del t, grid
      return 1, None, tf.constant(0, dtype=values_grid.dtype)

    boundary_conditions = [(zero_dirichlet, zero_dirichlet)] * n_dims

  with tf.compat.v1.name_scope(
      name,
      default_name='solve',
      values=[
          start_time,
          end_time,
          coord_grid,
          values_grid,
          num_steps,
          time_step,
      ]):
    time_step_fn, est_max_steps = _get_time_steps_info(start_time, end_time,
                                                       num_steps, time_step,
                                                       time_direction_fn)
    if est_max_steps is None and maximum_steps is not None:
      est_max_steps = maximum_steps

    def loop_cond(should_stop, time, x_grid, f_grid, steps_performed):
      del time, x_grid, f_grid, steps_performed
      return tf.logical_not(should_stop)

    def loop_body(should_stop, time, x_grid, f_grid, steps_performed):
      """Propagates the grid in time."""
      del should_stop
      next_should_stop, t_next = time_step_fn(time)
      next_xs, next_fs = one_step_fn(
          time=time,
          next_time=t_next,
          coord_grid=x_grid,
          value_grid=f_grid,
          boundary_conditions=boundary_conditions,
          second_order_coeff_fn=second_order_coeff_fn,
          first_order_coeff_fn=first_order_coeff_fn,
          zeroth_order_coeff_fn=zeroth_order_coeff_fn,
          inner_second_order_coeff_fn=inner_second_order_coeff_fn,
          inner_first_order_coeff_fn=inner_first_order_coeff_fn,
          num_steps_performed=steps_performed)

      if values_transform_fn is not None:
        next_xs, next_fs = values_transform_fn(t_next, next_xs, next_fs)
      return next_should_stop, t_next, next_xs, next_fs, steps_performed + 1

    # If the start time is already equal to end time, no stepping is needed.
    # solve_backward, solve_forward already took care of the case when end_time
    # is on the "wrong side" of start_time.
    should_already_stop = (start_time == end_time)
    initial_args = (should_already_stop, start_time, coord_grid, values_grid,
                    start_step_count)
    (_, final_time, final_coords, final_values,
     steps_performed) = tf.while_loop(
         loop_cond,
         loop_body,
         initial_args,
         swap_memory=swap_memory,
         maximum_iterations=est_max_steps)
    return final_values, final_coords, final_time, steps_performed
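# A stripped-down sketch of the stepping pattern used in `_solve` above: a
# `should_stop` flag produced by a time-step function drives the loop, with
# `maximum_iterations` as a safety bound. The toy step below simply decays the
# grid values and stands in for `one_step_fn`; it is illustrative only.
import tensorflow as tf

def toy_solve(values, start_time=0.0, end_time=1.0, time_step=0.1):
    def time_step_fn(t):
        t_next = tf.minimum(t + time_step, end_time)
        return t_next >= end_time, t_next

    def loop_cond(should_stop, t, f):
        return tf.logical_not(should_stop)

    def loop_body(should_stop, t, f):
        next_should_stop, t_next = time_step_fn(t)
        return next_should_stop, t_next, f * tf.exp(-(t_next - t))

    _, final_time, final_values = tf.while_loop(
        loop_cond, loop_body,
        (tf.constant(start_time == end_time), tf.constant(start_time), values),
        maximum_iterations=100)
    return final_values, final_time

final_values, final_time = toy_solve(tf.ones([5]))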
Example #19
def make_convolution_transpose_fn_with_subkernels_matrix(
        filter_shape,
        strides,
        padding,
        rank=2,
        dilations=None,
        dtype=tf.int32,
        validate_args=False,
        name=None):
    """Like `tf.nn.conv2d` except applies batch of kernels to batch of `x`."""
    with tf.name_scope(name or 'make_convolution_transpose_fn_with_dilation'):

        if tf.get_static_value(rank) != 2:
            raise NotImplementedError(
                'Argument `rank` currently only supports `2`; '
                'saw "{}".'.format(rank))

        strides = tf.get_static_value(strides)
        if not isinstance(strides, int):
            raise ValueError(
                'Argument `strides` must be a statically known integer. '
                'Saw: {}'.format(strides))

        [
            filter_shape,
            rank,
            _,
            padding,
            dilations,
        ] = prepare_conv_args(filter_shape,
                              rank=rank,
                              strides=strides,
                              padding=padding,
                              dilations=dilations,
                              is_transpose=True,
                              validate_args=validate_args)

        fh, fw = filter_shape
        dh, dw = dilations

        # Determine maximum filter height and filter width of sub-kernels.
        sub_fh = (fh - 1) // strides + 1
        sub_fw = (fw - 1) // strides + 1

        def loop_body(i_, event_ind):
            i = i_ // strides
            j = i_ % strides

            i_ind = ps.range(i * fw,
                             ps.maximum(i, fh) * fw,
                             delta=strides * fw,
                             dtype=dtype)
            j_ind = ps.range(j, ps.maximum(j, fw), delta=strides, dtype=dtype)

            nc = cartesian_add([i_ind, j_ind])
            ind = ps.reverse(ps.reshape(nc, shape=[-1]), axis=[0])

            k = ps.reshape(cartesian_add([
                ps.range(ps.shape(nc)[0] * sub_fw, delta=sub_fw, dtype=dtype),
                ps.range(ps.shape(nc)[1], dtype=dtype)
            ]),
                           shape=[-1])
            last_j = strides - (fw - j - 1) % strides - 1
            last_i = strides - (fh - i - 1) % strides - 1
            kernel_ind = ps.stack(
                [k, ps.ones_like(k) * last_i * strides + last_j], axis=1)
            event_ind = ps.tensor_scatter_nd_update(event_ind, ind[...,
                                                                   tf.newaxis],
                                                    kernel_ind)

            return i_ + 1, event_ind

        event_ind = ps.zeros((fh * fw, 2), dtype=dtype)
        _, event_ind = tf.while_loop(lambda i, _: i < strides**2, loop_body,
                                     [tf.zeros([], dtype=dtype), event_ind])

        tot_pad_top, tot_pad_bottom = _get_transpose_conv_dilated_padding(
            fh, stride=strides, dilation=dh, padding=padding)
        tot_pad_left, tot_pad_right = _get_transpose_conv_dilated_padding(
            fw, stride=strides, dilation=dw, padding=padding)

        pad_bottom = (tot_pad_bottom - 1) // strides + 1
        pad_top = (tot_pad_top - 1) // strides + 1
        pad_right = (tot_pad_right - 1) // strides + 1
        pad_left = (tot_pad_left - 1) // strides + 1
        padding_vals = ((pad_top, pad_bottom), (pad_left, pad_right))

        truncate_top = pad_top * strides - tot_pad_top
        truncate_left = pad_left * strides - tot_pad_left

        def op(x, kernel):
            input_dtype = dtype_util.common_dtype([x, kernel],
                                                  dtype_hint=tf.float32)
            x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
            kernel = tf.convert_to_tensor(kernel,
                                          dtype=input_dtype,
                                          name='kernel')

            batch_shape, event_shape = ps.split(ps.shape(x),
                                                num_or_size_splits=[-1, 3])
            xh, xw, c_in = ps.unstack(event_shape, num=3)

            kernel_shape = ps.shape(kernel)
            c_out = kernel_shape[-1]
            kernel_batch = kernel_shape[:-2]
            assertions = _maybe_validate_input_shapes(
                kernel_shape,
                channels_in=c_in,
                filter_height=fh,
                filter_width=fw,
                validate_args=validate_args)

            with tf.control_dependencies(assertions):

                # If the kernel does not have batch shape, fall back to
                # `conv2d_transpose` (unless dilations > 1, which is not implemented in
                # `conv2d_transpose`).
                if (tf.get_static_value(ps.rank(kernel)) == 2
                        and all(d == 1 for d in dilations)):
                    return _call_conv2d_transpose(x,
                                                  kernel=kernel,
                                                  filter_shape=filter_shape,
                                                  strides=(strides, ) * rank,
                                                  padding=padding,
                                                  dilations=dilations,
                                                  c_out=c_out,
                                                  batch_shape=batch_shape,
                                                  event_shape=event_shape)

                n = ps.maximum(0, ps.rank(x) - 3)
                paddings = ps.pad(padding_vals,
                                  paddings=[[n, 1], [0, 0]],
                                  constant_values=0)

                x_pad = tf.pad(x, paddings=paddings, constant_values=0)
                x_pad_shape = ps.shape(x_pad)[:-3]
                flat_shape = ps.pad(x_pad_shape,
                                    paddings=[[0, 1]],
                                    constant_values=-1)
                flat_x = tf.reshape(x_pad, shape=flat_shape)

                idx, s = im2row_index(
                    (xh + tf.reduce_sum(padding_vals[0]),
                     xw + tf.reduce_sum(padding_vals[1]), c_in),
                    block_shape=(sub_fh, sub_fw),
                    slice_step=(1, 1),
                    dilations=dilations)

                x_ = tf.gather(flat_x, indices=idx, axis=-1)
                im_x = tf.reshape(x_,
                                  shape=ps.concat([x_pad_shape, s], axis=0))

                # Add channels to subkernel indices
                idx_event = event_ind * [[c_in, 1]]
                idx_event_channels = (idx_event[tf.newaxis] + tf.stack(
                    [ps.range(c_in),
                     tf.zeros(
                         (c_in, ), dtype=dtype)], axis=-1)[:, tf.newaxis, :])
                idx_event = tf.squeeze(tf.batch_to_space(idx_event_channels,
                                                         block_shape=[c_in],
                                                         crops=[[0, 0]]),
                                       axis=0)
                idx_event_broadcast = tf.broadcast_to(
                    idx_event,
                    shape=ps.concat(
                        [kernel_batch, ps.shape(idx_event)], axis=0))

                # Add cartesian product of batch indices, since scatter_nd can only be
                # applied to leading dimensions.
                idx_batch = tf.stack(tf.meshgrid(*[
                    ps.range(b_, delta=1, dtype=dtype)
                    for b_ in tf.unstack(kernel_batch)
                ],
                                                 indexing='ij'),
                                     axis=ps.size(kernel_batch))

                idx_batch = tf.cast(idx_batch,
                                    dtype=dtype)  # empty tensor is float

                idx_batch_broadcast = idx_batch[..., tf.newaxis, :] + tf.zeros(
                    (ps.shape(idx_event)[0], 1), dtype=dtype)
                idx_kernel = tf.concat(
                    [idx_batch_broadcast, idx_event_broadcast], axis=-1)

                kernel_mat = tf.scatter_nd(
                    idx_kernel,
                    updates=kernel,
                    shape=ps.cast(ps.concat([
                        kernel_batch,
                        [sub_fh * sub_fw * c_in, strides**2, c_out]
                    ],
                                            axis=0),
                                  dtype=dtype))

                kernel_mat = tf.reshape(
                    kernel_mat,
                    shape=ps.concat(
                        [ps.shape(kernel_mat)[:-2], [strides**2 * c_out]],
                        axis=0))

                kernel_mat = kernel_mat[..., tf.newaxis, :, :]
                out = tf.matmul(im_x, kernel_mat)
                broadcast_batch_shape = ps.broadcast_shape(
                    batch_shape, kernel_batch)

                if strides > 1:
                    tot_size = tf.reduce_prod(broadcast_batch_shape)
                    flat_out = tf.reshape(out,
                                          shape=ps.concat([[tot_size],
                                                           ps.shape(out)[-3:]],
                                                          axis=0))
                    out = tf.nn.depth_to_space(flat_out, block_size=strides)

                out_height = _deconv_output_length(xh,
                                                   filter_size=fh,
                                                   padding=padding,
                                                   output_padding=None,
                                                   stride=strides,
                                                   dilation=dh)
                out_width = _deconv_output_length(xw,
                                                  filter_size=fw,
                                                  padding=padding,
                                                  output_padding=None,
                                                  stride=strides,
                                                  dilation=dw)

                out = out[..., truncate_top:truncate_top + out_height,
                          truncate_left:truncate_left + out_width, :]
                out = tf.reshape(
                    out,
                    shape=ps.concat([
                        broadcast_batch_shape, [out_height, out_width, c_out]
                    ],
                                    axis=0))
                return out

        return op
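The output spatial sizes above are delegated to `_deconv_output_length`, which is not shown in this snippet. The sketch below is an assumption for illustration only: it mirrors the standard Keras-style deconvolution length formula rather than the exact helper used here.

def deconv_output_length_sketch(input_length, filter_size, padding, stride, dilation=1):
    # Effective filter size once dilation is taken into account.
    filter_size = filter_size + (filter_size - 1) * (dilation - 1)
    if padding == 'VALID':
        # Each input pixel is spread `stride` apart; the filter overhang is added once.
        return input_length * stride + max(filter_size - stride, 0)
    if padding == 'SAME':
        # 'SAME' transposed convolution scales the spatial size by the stride.
        return input_length * stride
    raise ValueError('Unsupported padding: %s' % padding)

# E.g. a 5-pixel row upsampled with a 3-tap kernel and stride 2:
# deconv_output_length_sketch(5, 3, 'VALID', 2) == 11
# deconv_output_length_sketch(5, 3, 'SAME', 2) == 10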
Example #20
0
def sinkhorn_iterations(x,
                        y,
                        a,
                        b,
                        power=2.0,
                        epsilon=1e-3,
                        epsilon_0=1e-1,
                        epsilon_decay=0.95,
                        threshold=1e-2,
                        inner_num_iter=5,
                        max_iterations=2000):
    """Runs the Sinkhorn's algorithm from (x, a) to (y, b).

  Args:
   x: Tensor<float>[batch, n]: the input point clouds.
   y: Tensor<float>[batch, m]: the target point clouds.
   a: Tensor<float>[batch, n]: the weight of each input point. The sum of all
     elements of b must match that of a to converge.
   b: Tensor<float>[batch, m]: the weight of each target point. The sum of all
     elements of b must match that of a to converge.
   power: (float) the power of the distance for the cost function.
   epsilon: (float) the level of entropic regularization wanted.
   epsilon_0: (float) the initial level of entropic regularization.
   epsilon_decay: (float) a multiplicative factor applied at each iteration
     until reaching the epsilon value.
   threshold: (float) the relative threshold on the Sinkhorn error to stop the
     Sinkhorn iterations.
   inner_num_iter: (int32) the Sinkhorn error is not recomputed at every
     iteration but only once every inner_num_iter iterations, to avoid
     computational overhead.
   max_iterations: (int32) the maximum number of Sinkhorn iterations.

  Returns:
   A 6-tuple containing: the values of the conjugate variables f and g, the
   final value of the entropic parameter epsilon, the cost matrix, its
   derivative, and the number of iterations.
  """
    max_outer_iterations = max_iterations // inner_num_iter
    loga = tf.math.log(a)
    logb = tf.math.log(b)
    cost, d_cost = cost_fn(x, y, power)

    def body_fn(f, g, eps, num_iter):
        for _ in range(inner_num_iter):
            g = eps * logb + softmin(cost, f, g, eps, axis=1) + g
            f = eps * loga + softmin(cost, f, g, eps, axis=2) + f
            eps = tf.math.maximum(eps * epsilon_decay, epsilon)
        return [f, g, eps, num_iter + inner_num_iter]

    def cond_fn(f, g, eps, num_iter):
        return tf.math.reduce_all([
            tf.math.less(num_iter, max_iterations),
            tf.math.reduce_any([
                tf.math.greater(eps, epsilon),
                tf.math.greater(error(cost, f, g, eps, b), threshold)
            ])
        ])

    f, g, eps, iterations = tf.while_loop(
        cond_fn,
        body_fn, [
            tf.zeros_like(loga),
            tf.zeros_like(logb),
            tf.cast(epsilon_0, dtype=x.dtype),
            tf.constant(0, dtype=tf.int32)
        ],
        parallel_iterations=1,
        maximum_iterations=max_outer_iterations + 1)

    return f, g, eps, cost, d_cost, iterations
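The `body_fn` updates above are the standard log-domain Sinkhorn updates written with a `softmin` helper (the `axis` values are offset by one relative to the sketch below because `cost` carries a leading batch dimension). A self-contained NumPy sketch of the same updates on a toy problem, with a hypothetical `softmin_np` standing in for the module's `softmin` and a squared-distance cost standing in for `cost_fn`:

import numpy as np

def softmin_np(cost, f, g, eps, axis):
    # -eps * logsumexp(-(cost - f - g) / eps): a soft minimum along `axis`.
    h = (cost - f[:, None] - g[None, :]) / eps
    return -eps * np.log(np.sum(np.exp(-h), axis=axis))

# Toy 1-D point clouds with uniform weights.
x = np.array([0.0, 1.0, 2.0]); a = np.full(3, 1. / 3)
y = np.array([0.5, 1.5]);      b = np.full(2, 1. / 2)
cost = (x[:, None] - y[None, :]) ** 2      # power=2 cost matrix
f, g, eps = np.zeros_like(x), np.zeros_like(y), 1e-1

for _ in range(200):
    # Same alternating updates as `body_fn` (without the epsilon decay schedule).
    g = eps * np.log(b) + softmin_np(cost, f, g, eps, axis=0) + g
    f = eps * np.log(a) + softmin_np(cost, f, g, eps, axis=1) + f

# At convergence the implied transport plan has marginals close to (a, b).
plan = np.exp((f[:, None] + g[None, :] - cost) / eps)
print(plan.sum(axis=1), plan.sum(axis=0))  # approximately a and b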
Example #21
0
        def op(x, kernel):
            input_dtype = dtype_util.common_dtype([x, kernel],
                                                  dtype_hint=tf.float32)
            x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
            kernel = tf.convert_to_tensor(kernel,
                                          dtype=input_dtype,
                                          name='kernel')

            batch_shape, event_shape = ps.split(ps.shape(x),
                                                num_or_size_splits=[-1, 3])
            xh, xw, c_in = ps.unstack(event_shape, num=3)

            kernel_shape = ps.shape(kernel)
            c_out = kernel_shape[-1]
            kernel_batch = kernel_shape[:-2]
            assertions = _maybe_validate_input_shapes(
                kernel_shape,
                channels_in=c_in,
                filter_height=fh,
                filter_width=fw,
                validate_args=validate_args)

            with tf.control_dependencies(assertions):
                # If the kernel does not have batch shape, fall back to
                # `conv2d_transpose` (unless dilations > 1, which is not implemented in
                # `conv2d_transpose`).
                if (tf.get_static_value(ps.rank(kernel)) == 2
                        and all(d == 1 for d in dilations)):
                    return _call_conv2d_transpose(x, kernel, filter_shape,
                                                  strides, padding, dilations,
                                                  c_out, batch_shape,
                                                  event_shape)

                n = ps.maximum(0, ps.rank(x) - 3)
                paddings = ps.pad(padding_vals,
                                  paddings=[[n, 1], [0, 0]],
                                  constant_values=0)
                x_pad = tf.pad(x, paddings=paddings, constant_values=0)

                ex_h = xh + tf.reduce_sum(padding_vals[0]) - sub_fh + 1
                ex_w = xw + tf.reduce_sum(padding_vals[1]) - sub_fw + 1

                def loop_body(i, outputs):
                    subkernel_ind = kernels_ind.read(i)
                    fh_, fw_ = ps.unstack(ps.shape(subkernel_ind), num=2)
                    eh = ex_h + fh_ - 1
                    ew = ex_w + fw_ - 1

                    subkernel_ind = ps.reshape(ps.reshape(
                        subkernel_ind * c_in, shape=[-1])[:, tf.newaxis] +
                                               ps.range(c_in),
                                               shape=[-1])

                    k = tf.gather(kernel, subkernel_ind, axis=-2)
                    ind, shape = im2row_index([eh, ew, c_in],
                                              block_shape=(fh_, fw_),
                                              slice_step=(1, 1),
                                              dilations=dilations)
                    x_i = x_pad[..., :eh, :ew, :]
                    x_i_shape = ps.shape(x_i)
                    flat_shape = ps.pad(x_i_shape[:-3],
                                        paddings=[[0, 1]],
                                        constant_values=-1)
                    flat_x = tf.reshape(x_i, flat_shape)
                    x_ = tf.gather(flat_x, ind, axis=-1)
                    im_x = tf.reshape(
                        x_, ps.concat([x_i_shape[:-3], shape], axis=0))
                    outputs = outputs.write(
                        i,
                        tf.matmul(
                            im_x,
                            tf.reshape(
                                k,
                                ps.concat([
                                    kernel_batch, [1, fh_ * fw_ * c_in, c_out]
                                ],
                                          axis=0))))
                    return i + 1, outputs

                outputs = tf.TensorArray(dtype=input_dtype, size=sh * sw)

                _, outputs = tf.while_loop(lambda i, _: i < sh * sw, loop_body,
                                           [0, outputs])

                y = outputs.concat()

                m = tf.reduce_prod(ps.shape(y)[:-3])
                y_ = tf.reshape(y,
                                shape=ps.concat([[m], ps.shape(y)[-3:]],
                                                axis=0))
                y2 = tf.batch_to_space(y_,
                                       strides,
                                       crops=tf.zeros([2, 2], dtype=tf.int64))
                broadcast_batch_shape = ps.broadcast_shape(
                    batch_shape, kernel_batch)
                y2 = tf.reshape(
                    y2,
                    ps.concat([broadcast_batch_shape,
                               ps.shape(y2)[-3:]],
                              axis=0))

                out_height = _deconv_output_length(xh,
                                                   filter_size=fh,
                                                   padding=padding,
                                                   output_padding=None,
                                                   stride=sh,
                                                   dilation=dh)
                out_width = _deconv_output_length(xw,
                                                  filter_size=fw,
                                                  padding=padding,
                                                  output_padding=None,
                                                  stride=sw,
                                                  dilation=dw)

                return y2[..., truncate_top:truncate_top + out_height,
                          truncate_left:truncate_left + out_width, :]
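Both this variant (via `tf.batch_to_space`) and the previous one (via `tf.nn.depth_to_space`) assemble the strided output by interleaving the `strides**2` sub-kernel results into an upsampled spatial grid. A minimal illustration of that interleaving, using `depth_to_space` with four channel groups and `block_size=2`:

import tensorflow as tf

# Four sub-outputs stacked along channels are woven into a 2x larger grid.
x = tf.reshape(tf.range(16, dtype=tf.float32), [1, 2, 2, 4])
print(tf.nn.depth_to_space(x, block_size=2)[0, ..., 0])
# [[ 0.  1.  4.  5.]
#  [ 2.  3.  6.  7.]
#  [ 8.  9. 12. 13.]
#  [10. 11. 14. 15.]]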
Example #22
0
    def rejection_sample_with_gradient(concentration):
        """Performs rejection sampling for standardized von Mises.

    A nested function is required because @tf.custom_gradient does not handle
    non-tensor inputs such as dtype. Instead, they are captured by the outer
    scope.

    Arguments:
      concentration: The concentration parameter of the distribution.

    Returns:
      Differentiable samples of standardized von Mises.
    """
        r = 1. + tf.sqrt(1. + 4. * concentration**2)
        rho = (r - tf.sqrt(2. * r)) / (2. * concentration)

        s_exact = (1. + rho**2) / (2. * rho)

        # For low concentration, s becomes numerically unstable.
        # To fix that, we use an approximation. Here is the derivation.
        # First-order Taylor expansion at conc = 0 gives
        #   sqrt(1 + 4 concentration^2) ~= 1 + (2 concentration)^2 / 2.
        # Therefore, r ~= 2 + 2 concentration^2. By plugging this into rho, we have
        #   rho ~= concentration + 1 / concentration - sqrt(1 + 1 / concentration^2).
        # Let's expand the last term at concentration=0 up to the linear term:
        #   sqrt(1 + 1 / concentration^2) ~= 1 / concentration + concentration / 2
        # Thus, rho ~= concentration / 2. Finally,
        #   s = 1 / (2 rho) + rho / 2 ~= 1 / concentration + concentration / 4.
        # Since concentration is small, we drop the second term and simply use
        #   s ~= 1 / concentration.
        s_approximate = 1. / concentration

        # To compute the cutoff, we compute s_exact using mpmath with 30 decimal
        # digits precision and compare that to the s_exact and s_approximate
        # computed with dtype. Then, the cutoff is the largest concentration for
        # which abs(s_exact - s_exact_mpmath) > abs(s_approximate - s_exact_mpmath).
        s_concentration_cutoff_dict = {
            tf.float16: 1.8e-1,
            tf.float32: 2e-2,
            tf.float64: 1.2e-4,
        }
        s_concentration_cutoff = s_concentration_cutoff_dict[dtype]

        s = tf.where(concentration > s_concentration_cutoff, s_exact,
                     s_approximate)

        def loop_body(done, u, w, seed):
            """Resample the non-accepted points."""
            # We resample u each time completely. Only its sign is used outside the
            # loop, which is random.
            u_seed, v_seed, next_seed = samplers.split_seed(seed, n=3)
            u = samplers.uniform(shape,
                                 minval=-1.,
                                 maxval=1.,
                                 dtype=dtype,
                                 seed=u_seed)
            z = tf.cos(np.pi * u)
            # Update the non-accepted points.
            w = tf.where(done, w, (1. + s * z) / (s + z))
            y = concentration * (s - w)

            v = samplers.uniform(shape,
                                 minval=0.,
                                 maxval=1.,
                                 dtype=dtype,
                                 seed=v_seed)
            accept = (y * (2. - y) >= v) | (tf.math.log(y / v) + 1. >= y)

            return done | accept, u, w, next_seed

        _, u, w, _ = tf.while_loop(
            cond=lambda done, *_: ~tf.reduce_all(done),
            body=loop_body,
            loop_vars=(
                tf.zeros(shape, dtype=tf.bool, name='done'),
                tf.zeros(shape, dtype=dtype, name='u'),
                tf.zeros(shape, dtype=dtype, name='w'),
                seed,
            ),
            # The expected number of iterations depends on concentration.
            # It monotonically increases from one iteration for concentration = 0 to
            # sqrt(2 pi / e) ~= 1.52 iterations for concentration = +inf [1].
            # We use a limit of 100 iterations to avoid infinite loops
            # for very large / nan concentration.
            maximum_iterations=100,
        )

        x = tf.sign(u) * tf.math.acos(w)

        def grad(dy):
            """The gradient of the von Mises samples w.r.t. concentration."""
            broadcast_concentration = tf.broadcast_to(concentration,
                                                      prefer_static.shape(x))
            _, dcdf_dconcentration = value_and_gradient(
                lambda conc: von_mises_cdf(x, conc), broadcast_concentration)
            inv_prob = tf.exp(-broadcast_concentration * (tf.cos(x) - 1.)) * (
                (2. * np.pi) * tf.math.bessel_i0e(broadcast_concentration))
            # Compute the implicit reparameterization gradient [2],
            # dz/dconc = -(dF(z; conc) / dconc) / p(z; conc)
            ret = dy * (-inv_prob * dcdf_dconcentration)
            # Sum over the sample dimensions. Assume that they are always the first
            # ones.
            num_sample_dimensions = (tf.rank(broadcast_concentration) -
                                     tf.rank(concentration))
            return tf.reduce_sum(ret, axis=tf.range(num_sample_dimensions))

        return x, grad
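A quick NumPy check (float32, illustrative only) of the instability that motivates the cutoff table above: evaluating `s_exact` directly loses precision and eventually divides by zero as the concentration shrinks, while the `1 / concentration` approximation stays well behaved.

import numpy as np

conc = np.array([1e-1, 2e-2, 1e-3, 1e-5], dtype=np.float32)
r = 1. + np.sqrt(1. + 4. * conc**2)
rho = (r - np.sqrt(2. * r)) / (2. * conc)
s_exact = (1. + rho**2) / (2. * rho)   # may emit a divide-by-zero warning at 1e-5
s_approx = 1. / conc
print(s_exact)    # increasingly inaccurate, eventually non-finite
print(s_approx)   # e.g. 1e5 for concentration 1e-5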
Example #23
0
    def _sample_n(self, n, seed=None):
        seed = seed_stream.SeedStream(seed, salt='vom_mises_fisher')
        # The sampling strategy relies on the fact that vMF variates are symmetric
        # about the mean direction. Accordingly, if we have a sampling strategy for
        # the away-from-mean angle, then we can uniformly sample the remaining
        # dimensions on the S^{dim-2} sphere for free, and rotate these samples from a
        # (1, 0, 0, ..., 0)-mode distribution into the target orientation.
        #
        # This is easy to imagine on the 1-sphere (S^1; in 2-D space): sample a
        # von-Mises distributed `x` value in [-1, 1], then uniformly select what
        # amounts to a "up" or "down" additional degree of freedom after unit
        # normalizing, followed by a final rotation to the desired mean direction
        # from a basis of (1, 0).
        #
        # On S^2 (in 3-D), selecting a vMF `x` identifies a circle in `yz` on the
        # unit sphere over which the distribution is uniform, in particular the
        # circle where x = \hat{x} intersects the unit sphere. We pick a point on
        # that circle, then rotate to the desired mean direction from a basis of
        # (1, 0, 0).
        event_dim = (tf.compat.dimension_value(self.event_shape[0])
                     or self._event_shape_tensor()[0])

        sample_batch_shape = tf.concat([[n], self._batch_shape_tensor()],
                                       axis=0)
        dim = tf.cast(event_dim - 1, self.dtype)
        if event_dim == 3:
            samples_dim0 = self._sample_3d(n, seed=seed)
        else:
            # Wood'94 provides a rejection algorithm to sample the x coordinate.
            # Wood'94 definition of b:
            # b = (-2 * kappa + tf.sqrt(4 * kappa**2 + dim**2)) / dim
            # https://stats.stackexchange.com/questions/156729 suggests:
            b = dim / (2 * self.concentration +
                       tf.sqrt(4 * self.concentration**2 + dim**2))
            # TODO(bjp): Integrate any useful numerical tricks from hyperspherical VAE
            #     https://github.com/nicola-decao/s-vae-tf/
            x = (1 - b) / (1 + b)
            c = self.concentration * x + dim * tf.math.log1p(-x**2)
            beta = beta_lib.Beta(dim / 2, dim / 2)

            def cond_fn(w, should_continue):
                del w
                return tf.reduce_any(input_tensor=should_continue)

            def body_fn(w, should_continue):
                z = beta.sample(sample_shape=sample_batch_shape, seed=seed())
                w = tf.where(should_continue,
                             (1 - (1 + b) * z) / (1 - (1 - b) * z), w)
                w = tf.debugging.check_numerics(w, 'w')
                should_continue = tf.logical_and(
                    should_continue,
                    self.concentration * w + dim * tf.math.log1p(-x * w) - c <
                    tf.math.log(
                        tf.random.uniform(sample_batch_shape,
                                          seed=seed(),
                                          dtype=self.dtype)))
                return w, should_continue

            w = tf.zeros(sample_batch_shape, dtype=self.dtype)
            should_continue = tf.ones(sample_batch_shape, dtype=tf.bool)
            samples_dim0 = tf.while_loop(cond=cond_fn,
                                         body=body_fn,
                                         loop_vars=(w, should_continue))[0]
            samples_dim0 = samples_dim0[..., tf.newaxis]
        if not self._allow_nan_stats:
            # Verify samples are w/in -1, 1, with useful error output tensors (top
            # value rather than all values).
            with tf.control_dependencies([
                    assert_util.assert_less_equal(
                        samples_dim0,
                        dtype_util.as_numpy_dtype(self.dtype)(1.01),
                        data=[tf.nn.top_k(tf.reshape(samples_dim0, [-1]))[0]]),
                    assert_util.assert_greater_equal(
                        samples_dim0,
                        dtype_util.as_numpy_dtype(self.dtype)(-1.01),
                        data=[
                            -tf.nn.top_k(tf.reshape(-samples_dim0, [-1]))[0]
                        ])
            ]):
                samples_dim0 = tf.identity(samples_dim0)
        samples_otherdims_shape = tf.concat(
            [sample_batch_shape, [event_dim - 1]], axis=0)
        unit_otherdims = tf.nn.l2_normalize(tf.random.normal(
            samples_otherdims_shape, seed=seed(), dtype=self.dtype),
                                            axis=-1)
        samples = tf.concat(
            [
                samples_dim0,  # we must avoid sqrt(1 - (>1)**2)
                tf.sqrt(tf.maximum(1 - samples_dim0**2, 0.)) * unit_otherdims
            ],
            axis=-1)
        samples = tf.nn.l2_normalize(samples, axis=-1)
        if not self._allow_nan_stats:
            samples = tf.debugging.check_numerics(samples, 'samples')

        # Runtime assert that samples are unit length.
        if not self._allow_nan_stats:
            worst, idx = tf.nn.top_k(
                tf.reshape(tf.abs(1 - tf.linalg.norm(tensor=samples, axis=-1)),
                           [-1]))
            with tf.control_dependencies([
                    assert_util.assert_near(
                        dtype_util.as_numpy_dtype(self.dtype)(0),
                        worst,
                        data=[
                            worst, idx,
                            tf.gather(tf.reshape(samples, [-1, event_dim]),
                                      idx)
                        ],
                        atol=1e-4,
                        summarize=100)
            ]):
                samples = tf.identity(samples)
        # The samples generated are symmetric around a mode at (1, 0, 0, ...., 0).
        # Now, we move the mode to `self.mean_direction` using a rotation matrix.
        if not self._allow_nan_stats:
            # Assert that the basis vector rotates to the mean direction, as expected.
            basis = tf.cast(
                tf.concat([[1.], tf.zeros([event_dim - 1])], axis=0),
                self.dtype)
            with tf.control_dependencies([
                    assert_util.assert_less(
                        tf.linalg.norm(tensor=self._rotate(basis) -
                                       self.mean_direction,
                                       axis=-1),
                        dtype_util.as_numpy_dtype(self.dtype)(1e-5))
            ]):
                return self._rotate(samples)
        return self._rotate(samples)
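The assembly step near the end (the `tf.concat` of `samples_dim0` with the scaled `unit_otherdims`) is easy to check in isolation: given a first coordinate `w` in [-1, 1] and a uniform unit vector on S^{dim-2}, the concatenated point lies on the unit sphere with mode (1, 0, ..., 0). A small NumPy sketch with placeholder draws only; the real `w` comes from the rejection sampler above, and a rotation to `mean_direction` would follow:

import numpy as np

rng = np.random.default_rng(0)
event_dim, n = 5, 4

# Placeholder for the away-from-mean coordinate produced by the sampler.
w = rng.uniform(-1., 1., size=(n, 1))

# Uniform directions on S^{event_dim - 2}: normalized Gaussian draws.
v = rng.normal(size=(n, event_dim - 1))
v /= np.linalg.norm(v, axis=-1, keepdims=True)

# Same assembly as the `tf.concat` step above.
samples = np.concatenate([w, np.sqrt(np.maximum(1. - w**2, 0.)) * v], axis=-1)
print(np.linalg.norm(samples, axis=-1))  # all ~1: points on the unit sphere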
Example #24
0
def _build_discount_curve(bond_cashflows, bond_cashflow_times, present_values,
                          pv_settle_times, initial_discount_rates,
                          discount_tolerance, maximum_iterations):
    """Estimates the discount curve.

  The procedure is recursive and as follows:
  1. Assume some initial set of discount rates/discount factors.
    Set this as the current yield curve.
  2. From the current yield curve, interpolate to get the discount rates
    for each time at which bond_cashflows occur.
  3. Using these discounts and the known bond prices, compute the discount
    rate to expiry of each bond by inverting the bond pricing formula as
    follows. We know that the bond price satisfies (`P` is the present value,
    `r_i` is the discount rate to time `t_i`, `c_i` is the cashflow occurring at
    time `t_i`.):

    ```None
      P e^{-r_0 t_0} = c_1 e^{-r_1 t_1} + ... + c_n e^{-r_n t_n}        (A)

    ```
    Assuming we have estimated r_0, r_1, r_2, ..., r_{n-1}, we can invert the
    above equation to calculate r_n. We write this in a suggestive form
    suitable for the implementation below.

    ```None
      -c_n z_n = -P z_0 + c_1 z_1 + c_2 z_2 + ... + c_{n-1} z_{n-1}     (B)

    ```
    where

    ```None
      z_i = e^{-r_i t_i}      (C)

    ```
    The RHS of Eq. (B) looks like the PV of cashflows
    `[-P, c_1, c_2, ... c_{n-1}]` paid out at times `[t_0, t_1, ..., t_{n-1}]`.

    Concatenate these "synthetic" cashflow times for each bond:

    `Ts = [t1_0, t1_1, ... t1_{n1-1}] + [t2_0, t2_1, ... t2_{n2-1}] ...`

    Also concatenate the synthetic bond cashflows as:

    `Cs = [-P1, c1_1, ..., c1_{n1-1}] + [-P2, c2_1, ..., c2_{n2-1}] ...`

    Then compute `Rs = InterpolateRates[Ts], Zs = exp(-Rs * Ts)`

    Let `Zns = [z_n1, z_n2, ... ], Cns = [c1_n, c2_n, ...]` be the discount
    factors to expiry and the final cashflow of each bond.
    We can derive `Zns = - SegmentSum(Cs * Zs) / Cns`.

    From that, we get Rns = -log(Zns) / Tns.
    Using this as the next guess for the discount rates, we repeat the
    procedure from Step (1) until convergence.

  Args:
    bond_cashflows: List of `Tensor`s. Each `Tensor` must be of rank 1 and of
      the same real dtype. They may be of different sizes. Each `Tensor`
      represents the bond cashflows defining a particular bond. The elements of
      the list are the bonds to be used to build the curve.
    bond_cashflow_times: List of `Tensor`s. The list must be of the same length
      as the `bond_cashflows` and each `Tensor` in the list must be of the same
      length as the `Tensor` at the same index in the `bond_cashflows` list.
      Each `Tensor` must be of rank 1 and of the same dtype as the `Tensor`s in
      `bond_cashflows` and contain strictly positive and increasing values. The
      times of the bond cashflows for the bonds must be in ascending order.
    present_values: List containing scalar `Tensor`s of the same dtype as
      elements of `bond_cashflows`. The length of the list must be the same as
      the length of `bond_cashflows`. The market price (i.e the all-in or dirty
      price) of the bond cashflows supplied in the `bond_cashflows`.
    pv_settle_times: List containing scalar `Tensor`s of the same dtype as
      elements of `bond_cashflows`. The length of the list must be the same as
      the length of `bond_cashflows`. The settlement times for the present
      values, i.e. the time from when the bond is traded to the time that the
      purchase price is actually delivered.
    initial_discount_rates: Rank 1 `Tensor` of same shape and dtype as
      `pv_settle_times`. The initial guess for the discount rates to bond expiry
      times.
    discount_tolerance: Positive scalar `Tensor` of same dtype as
      `initial_discount_rates`. The absolute tolerance for terminating the
      iterations used to fit the rate curve. The iterations are stopped when the
      estimated discounts at the expiry times of the bond cashflows change by an
      amount smaller than `discount_tolerance` in an iteration.
    maximum_iterations: Positive scalar `tf.int32` `Tensor`. The maximum number
      of iterations permitted.

  Returns:
    curve_builder_result: An instance of `CurveBuilderResult` containing the
      following attributes.
      times: Rank 1 real `Tensor`. Times for the computed discount rates.
      discount_rates: Rank 1 `Tensor` of the same dtype as `times`.
        The inferred discount rates.
      discount_factor: Rank 1 `Tensor` of the same dtype as `times`.
        The inferred discount factors.
      initial_discount_rates: Rank 1 `Tensor` of the same dtype as `times`. The
        initial guess for the discount rates.
      converged: Scalar boolean `Tensor`. Whether the procedure converged.
        The procedure is said to have converged when the maximum absolute
        difference in the discount factors from one iteration to the next falls
        below the `discount_tolerance`.
      failed: Scalar boolean `Tensor`. Whether the procedure failed. Procedure
        may fail either because a NaN value was encountered for the discount
        rates or the discount factors.
      iterations: Scalar `tf.int32` `Tensor`. Number of iterations performed.
  """
    calc_bond_cashflows = []  # Cs
    calc_times = []  # Ts
    expiry_times = []  # Tns
    expiry_bond_cashflows = []  # Cns
    calc_groups = []
    num_bonds = len(bond_cashflows)
    for i in range(num_bonds):
        calc_bond_cashflows.extend([[-present_values[i]],
                                    bond_cashflows[i][:-1]])
        calc_times.extend([[pv_settle_times[i]], bond_cashflow_times[i][:-1]])
        expiry_times.append(bond_cashflow_times[i][-1])
        expiry_bond_cashflows.append(bond_cashflows[i][-1])
        calc_groups.append(tf.fill(tf.shape(bond_cashflows[i]), i))

    calc_bond_cashflows = tf.concat(calc_bond_cashflows, axis=0)
    calc_times = tf.concat(calc_times, axis=0)
    expiry_times = tf.stack(expiry_times, axis=0)
    expiry_bond_cashflows = tf.stack(expiry_bond_cashflows, axis=0)
    calc_groups = tf.concat(calc_groups, axis=0)

    def one_step(converged, failed, iteration, expiry_discounts):
        """One step of the iteration."""
        expiry_rates = -tf.math.log(expiry_discounts) / expiry_times
        failed = tf.math.reduce_any(
            tf.math.is_nan(expiry_rates) | tf.math.is_nan(expiry_discounts))
        calc_rates = monotone_convex.interpolate_yields(calc_times,
                                                        expiry_times,
                                                        yields=expiry_rates)
        calc_discounts = tf.math.exp(-calc_rates * calc_times)
        next_expiry_discounts = -tf.math.segment_sum(
            calc_bond_cashflows * calc_discounts,
            calc_groups) / expiry_bond_cashflows
        discount_diff = tf.math.abs(next_expiry_discounts - expiry_discounts)
        converged = (~tf.math.reduce_any(tf.math.is_nan(discount_diff)) &
                     (tf.math.reduce_max(discount_diff) < discount_tolerance))
        return converged, failed, iteration + 1, next_expiry_discounts

    def cond(converged, failed, iteration, expiry_discounts):
        del expiry_discounts, iteration
        # Note we do not need to check iteration count here because that
        # termination mode is imposed by the maximum_iterations parameter in the
        # while loop.
        return ~tf.math.logical_or(converged, failed)

    initial_discount_factors = tf.math.exp(-initial_discount_rates *
                                           expiry_times)
    initial_vals = (False, False, 0, initial_discount_factors)
    loop_result = tf.while_loop(cond,
                                one_step,
                                initial_vals,
                                maximum_iterations=maximum_iterations)
    discount_factors = loop_result[-1]
    discount_rates = -tf.math.log(discount_factors) / expiry_times
    results = CurveBuilderResult(times=expiry_times,
                                 discount_rates=discount_rates,
                                 discount_factors=discount_factors,
                                 initial_discount_rates=initial_discount_rates,
                                 converged=loop_result[0],
                                 failed=loop_result[1],
                                 iterations=loop_result[2])
    return results
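One bootstrap step of the procedure described in the docstring can be followed on a toy example. The sketch below is illustrative only: `np.interp` stands in for the monotone-convex interpolation (`monotone_convex.interpolate_yields`) used in the real code, and the two bonds are made up.

import numpy as np

# Two toy bonds: cashflows, cashflow times, dirty prices, settlement times.
bond_cashflows = [np.array([5., 105.]), np.array([4., 4., 104.])]
bond_times = [np.array([1., 2.]), np.array([1., 2., 3.])]
present_values = [101., 102.]
pv_settle_times = [0., 0.]

expiry_times = np.array([t[-1] for t in bond_times])          # Tns
expiry_cashflows = np.array([c[-1] for c in bond_cashflows])  # Cns
expiry_rates = np.array([0.03, 0.03])  # current guess

# Synthetic cashflows Cs at times Ts, plus a segment id per bond.
calc_cashflows = np.concatenate(
    [np.concatenate([[-p], c[:-1]]) for p, c in zip(present_values, bond_cashflows)])
calc_times = np.concatenate(
    [np.concatenate([[s], t[:-1]]) for s, t in zip(pv_settle_times, bond_times)])
groups = np.concatenate([np.full(len(c), i) for i, c in enumerate(bond_cashflows)])

# Interpolate rates, discount, segment-sum: Zns = -SegmentSum(Cs * Zs) / Cns.
calc_rates = np.interp(calc_times, expiry_times, expiry_rates)
calc_discounts = np.exp(-calc_rates * calc_times)
zns = -np.array([np.sum((calc_cashflows * calc_discounts)[groups == i])
                 for i in range(len(bond_cashflows))]) / expiry_cashflows
print(-np.log(zns) / expiry_times)  # next guess for the expiry discount rates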
Example #25
0
def _compute_general_continued_fraction(max_iterations,
                                        numerator_denominator_args_list,
                                        tolerance=None,
                                        partial_numerator_fn=None,
                                        partial_denominator_fn=None,
                                        dtype=tf.float32,
                                        name=None):
    """Compute a general continued fraction.

  Given at least one of `partial_numerator_fn` and `partial_denominator_fn`,
  compute the continued fraction associated with it via the forward recurrence.

  Let `a_i = partial_numerator_fn` and `b_i = partial_denominator_fn`. Then,
  this evaluates the infinite continued fraction:

  ```result = a_1 / (b_1 + a_2 / (b_2 + a_3 / (b_3 + ...)))```.

  If `partial_numerator_fn` or `partial_denominator_fn` are not given, then
  `a_i` (respectively `b_i`) are assumed to be 1. However, at least one of the
  two must be given.

  NOTE: Use this with caution. Forward recursion doesn't have numerical
  stability guarantees, compared to backward recursion.


  Args:
    max_iterations: Integer `Tensor` specifying the maximum number of terms to
      use.
    numerator_denominator_args_list: Arguments to pass in to
      `partial_numerator_fn` and `partial_denominator_fn`.
    tolerance: Float `Tensor` specifying the maximum acceptable tolerance
      between convergents. If unset, convergence is dictated by the number
      of iterations.
      Default value: `None`.
    partial_numerator_fn: Python callable that takes in as its first argument
      the current iteration count (an integer >= 1), and a list of *args, and
      returns a `Tensor`. These are used as partial numerators for the
      continued fraction.
      Default value: `None`.
    partial_denominator_fn: Python callable that takes in as its first argument
      the current iteration count (an integer >= 1), and a list of *args, and
      returns a `Tensor`. These are used as partial denominators for the
      continued fraction.
      Default value: `None`.
    dtype: The default dtype of the continued fraction. Default: `float32`.
    name: A name for the operation (optional).
      Default value: `None` (i.e., 'continued_fraction').

  Returns:
    Continued fraction computed to `max_iterations` iterations and/or
    up to absolute error `tolerance`.

  #### References
  [1]: Walter Gautschi and Josef Slavik. On the Computation of Modified
       Bessel Function Ratios. http://www.jstor.com/stable/2006491
  """
    with tf.name_scope(name or 'continued_fraction'):
        dtype = dtype_util.common_dtype(numerator_denominator_args_list, dtype)

        if (partial_numerator_fn is None) and (partial_denominator_fn is None):
            raise ValueError('Expect one of `partial_numerator_fn` and '
                             '`partial_denominator_fn` to be set.')

        def _continued_fraction_one_step(unused_should_stop, numerator,
                                         previous_numerator, denominator,
                                         previous_denominator,
                                         iteration_count):
            partial_denominator = 1.
            if partial_denominator_fn:
                partial_denominator = partial_denominator_fn(
                    iteration_count, *numerator_denominator_args_list)
            new_numerator = partial_denominator * numerator
            new_denominator = partial_denominator * denominator

            partial_numerator = 1.
            if partial_numerator_fn:
                partial_numerator = partial_numerator_fn(
                    iteration_count, *numerator_denominator_args_list)
            new_numerator = new_numerator + partial_numerator * previous_numerator
            new_denominator = (new_denominator +
                               partial_numerator * previous_denominator)

            should_stop_next = iteration_count > max_iterations

            if tolerance is not None:
                # We can use a more efficient computation when the partial numerators
                # are 1.
                if partial_numerator_fn is None:
                    # We now want to compute the relative error between the fraction at
                    # this iteration, vs. the previous iteration.
                    # Let h_i be the numerator and k_i the denominator, and a_i be the
                    # i-th term.
                    # h_i / k_i - h_{i-1} / k_{i-1} =
                    # (h_i * k_{i - 1} - h_{i - 1} * k_i) / (k_i * k_{i - 1}) =
                    # ((a_i h_{i - 1} + h_{i - 2}) * k_{i - 1} -
                    # (a_i k_{i - 1} + k_{i - 2}) * h_{i - 1}) / (k_i * k_{i - 1}) =
                    # -(h_{i - 1} * k_{i - 2} - h_{i - 2} * k_{i - 1}) / (k_i * k_{i - 1})
                    # This suggests we should prove something about the numerator
                    # inductively, and indeed
                    # (h_i * k_{i - 1} - h_{i - 1} * k_i) = (-1)**i
                    delta = tf.math.reciprocal(new_denominator * denominator)
                # We actually need to compute the difference of fractions.
                else:
                    delta = new_numerator / new_denominator - numerator / denominator

                converged = tf.math.abs(delta) <= tolerance
                should_stop_next = tf.reduce_all(converged) | should_stop_next
            return (should_stop_next, new_numerator, numerator,
                    new_denominator, denominator, iteration_count + 1.)

        # This is to infer the correct shape of tensors
        if partial_denominator_fn:
            term = partial_denominator_fn(1., *numerator_denominator_args_list)
        else:
            term = partial_numerator_fn(1., *numerator_denominator_args_list)

        zeroth_numerator = tf.ones_like(term, dtype=dtype)
        zeroth_denominator = tf.zeros_like(term, dtype=dtype)
        first_numerator = tf.zeros_like(term, dtype=dtype)
        first_denominator = tf.ones_like(term, dtype=dtype)

        results = tf.while_loop(cond=lambda stop, *_: ~stop,
                                body=_continued_fraction_one_step,
                                loop_vars=(False, first_numerator,
                                           zeroth_numerator, first_denominator,
                                           zeroth_denominator,
                                           tf.cast(1., dtype=dtype)))
        return results[1] / results[3]
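The while loop above implements the forward (Wallis) recurrence h_i = b_i h_{i-1} + a_i h_{i-2}, k_i = b_i k_{i-1} + a_i k_{i-2} and returns h / k. A standalone sketch of the same recurrence, without the TFP helpers; with all partial numerators and denominators equal to 1, the fraction 1 / (1 + 1 / (1 + ...)) converges to (sqrt(5) - 1) / 2:

import math

def continued_fraction_forward(a_fn, b_fn, num_terms):
    # (h_{-1}, k_{-1}) = (1, 0) and (h_0, k_0) = (0, 1), matching the
    # zeroth/first initial values fed to the while loop above.
    h_prev, k_prev = 1., 0.
    h, k = 0., 1.
    for i in range(1, num_terms + 1):
        a, b = a_fn(i), b_fn(i)
        h, h_prev = b * h + a * h_prev, h
        k, k_prev = b * k + a * k_prev, k
    return h / k

print(continued_fraction_forward(lambda i: 1., lambda i: 1., 40))
print((math.sqrt(5.) - 1.) / 2.)  # both ~0.6180339887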
Example #26
0
def minimize(objective_function,
             initial_simplex=None,
             initial_vertex=None,
             step_sizes=None,
             objective_at_initial_simplex=None,
             objective_at_initial_vertex=None,
             batch_evaluate_objective=False,
             func_tolerance=1e-8,
             position_tolerance=1e-8,
             parallel_iterations=1,
             max_iterations=None,
             reflection=None,
             expansion=None,
             contraction=None,
             shrinkage=None,
             name=None):
    """Minimum of the objective function using the Nelder Mead simplex algorithm.

  Performs an unconstrained minimization of a (possibly non-smooth) function
  using the Nelder Mead simplex method. Nelder Mead method does not support
  univariate functions. Hence the dimensions of the domain must be 2 or greater.
  For details of the algorithm, see
  [Press, Teukolsky, Vetterling and Flannery(2007)][1].

  Points in the domain of the objective function may be represented as a
  `Tensor` of general shape but with rank at least 1. The algorithm proceeds
  by modifying a full rank simplex in the domain. The initial simplex may
  either be specified by the user or can be constructed using a single vertex
  supplied by the user. In the latter case, if `v0` is the supplied vertex,
  the simplex is the convex hull of the set:

  ```None
  S = {v0} + {v0 + step_i * e_i}
  ```

  Here `e_i` is a vector which is `1` along the `i`-th axis and zero elsewhere
  and `step_i` is a characteristic length scale along the `i`-th axis. If the
  step size is not supplied by the user, a unit step size is used in every axis.
  Alternately, a single step size may be specified which is used for every
  axis. The most flexible option is to supply a bespoke step size for every
  axis.

  ### Usage:

  The following example demonstrates the usage of the Nelder Mead minimization
  on a two dimensional problem with the minimum located at a non-differentiable
  point.

  ```python
    # The objective function
    def sqrt_quadratic(x):
      return tf.sqrt(tf.reduce_sum(x ** 2, axis=-1))

    start = tf.constant([6.0, -21.0])  # Starting point for the search.
    optim_results = tfp.optimizer.nelder_mead_minimize(
        sqrt_quadratic, initial_vertex=start, func_tolerance=1e-8,
        batch_evaluate_objective=True)

    # Check that the search converged
    assert(optim_results.converged)
    # Check that the argmin is close to the actual value.
    np.testing.assert_allclose(optim_results.position, np.array([0.0, 0.0]),
                                atol=1e-7)
    # Print out the total number of function evaluations it took.
    print("Function evaluations: %d" % optim_results.num_objective_evaluations)
  ```

  ### References:
  [1]: William Press, Saul Teukolsky, William Vetterling and Brian Flannery.
    Numerical Recipes in C++, third edition. pp. 502-507. (2007).
    http://numerical.recipes/cpppages/chap0sel.pdf

  [2]: Jeffrey Lagarias, James Reeds, Margaret Wright and Paul Wright.
    Convergence properties of the Nelder-Mead simplex method in low dimensions,
    Siam J. Optim., Vol 9, No. 1, pp. 112-147. (1998).
    http://www.math.kent.edu/~reichel/courses/Opt/reading.material.2/nelder.mead.pdf

  [3]: Fuchang Gao and Lixing Han. Implementing the Nelder-Mead simplex
    algorithm with adaptive parameters. Computational Optimization and
    Applications, Vol 51, Issue 1, pp 259-277. (2012).
    https://pdfs.semanticscholar.org/15b4/c4aa7437df4d032c6ee6ce98d6030dd627be.pdf

  Args:
    objective_function:  A Python callable that accepts a point as a
      real `Tensor` and returns a `Tensor` of real dtype containing
      the value of the function at that point. The function
      to be minimized. If `batch_evaluate_objective` is `True`, the callable
      may be evaluated on a `Tensor` of shape `[n+1] + s ` where `n` is
      the dimension of the problem and `s` is the shape of a single point
      in the domain (so `n` is the size of a `Tensor` representing a
      single point).
      In this case, the expected return value is a `Tensor` of shape `[n+1]`.
      Note that this method does not support univariate functions so the problem
      dimension `n` must be strictly greater than 1.
    initial_simplex: (Optional) `Tensor` of real dtype. The initial simplex to
      start the search. If supplied, should be a `Tensor` of shape `[n+1] + s`
      where `n` is the dimension of the problem and `s` is the shape of a
      single point in the domain. Each row (i.e. the `Tensor` with a given
      value of the first index) is interpreted as a vertex of a simplex and
      hence the rows must be affinely independent. If not supplied, an axes
      aligned simplex is constructed using the `initial_vertex` and
      `step_sizes`. Only one and at least one of `initial_simplex` and
      `initial_vertex` must be supplied.
    initial_vertex: (Optional) `Tensor` of real dtype and any shape that can
      be consumed by the `objective_function`. A single point in the domain that
      will be used to construct an axes aligned initial simplex.
    step_sizes: (Optional) `Tensor` of real dtype and shape broadcasting
      compatible with `initial_vertex`. Supplies the simplex scale along each
      axis. Only used if `initial_simplex` is not supplied. See description
      above for details on how step sizes and initial vertex are used to
      construct the initial simplex.
    objective_at_initial_simplex: (Optional) Rank `1` `Tensor` of real dtype.
      The value of the objective function at the
      initial simplex. May be supplied only if `initial_simplex` is
      supplied. If not supplied, it will be computed.
    objective_at_initial_vertex: (Optional) Scalar `Tensor` of real dtype. The
      value of the objective function at the initial vertex. May be supplied
      only if the `initial_vertex` is also supplied.
    batch_evaluate_objective: (Optional) Python `bool`. If True, the objective
      function will be evaluated on all the vertices of the simplex packed
      into a single tensor. If False, the objective will be mapped across each
      vertex separately. Evaluating the objective function in a batch allows
      use of vectorization and should be preferred if the objective function
      allows it.
    func_tolerance: (Optional) Scalar `Tensor` of real dtype. The algorithm
      stops if the absolute difference between the largest and the smallest
      function value on the vertices of the simplex is below this number.
    position_tolerance: (Optional) Scalar `Tensor` of real dtype. The
      algorithm stops if the largest absolute difference between the
      coordinates of the vertices is below this threshold.
    parallel_iterations: (Optional) Positive integer. The number of iterations
      allowed to run in parallel.
    max_iterations: (Optional) Scalar positive `Tensor` of dtype `int32`.
      The maximum number of iterations allowed. If `None` then no limit is
      applied.
    reflection: (Optional) Positive Scalar `Tensor` of same dtype as
      `initial_vertex`. This parameter controls the scaling of the reflected
      vertex. See, [Press et al(2007)][1] for details. If not specified,
      uses the dimension dependent prescription of [Gao and Han(2012)][3].
    expansion: (Optional) Positive Scalar `Tensor` of same dtype as
      `initial_vertex`. Should be greater than `1` and `reflection`. This
      parameter controls the expanded scaling of a reflected vertex.
      See, [Press et al(2007)][1] for details. If not specified, uses the
      dimension dependent prescription of [Gao and Han(2012)][3].
    contraction: (Optional) Positive scalar `Tensor` of same dtype as
      `initial_vertex`. Must be between `0` and `1`. This parameter controls
      the contraction of the reflected vertex when the objective function at
      the reflected point fails to show sufficient decrease.
      See, [Press et al(2007)][1] for more details. If not specified, uses
      the dimension dependent prescription of [Gao and Han(2012)][3].
    shrinkage: (Optional) Positive scalar `Tensor` of same dtype as
      `initial_vertex`. Must be between `0` and `1`. This parameter is the scale
      by which the simplex is shrunk around the best point when the other
      steps fail to produce improvements.
      See, [Press et al(2007)][1] for more details. If not specified, uses
      the dimension dependent prescription of [Gao and Han(2012)][3].
    name: (Optional) Python str. The name prefixed to the ops created by this
      function. If not supplied, the default name 'minimize' is used.

  Returns:
    optimizer_results: A namedtuple containing the following items:
      converged: Scalar boolean tensor indicating whether the minimum was
        found within tolerance.
      num_objective_evaluations: The total number of objective
        evaluations performed.
      position: A `Tensor` containing the last argument value found
        during the search. If the search converged, then
        this value is the argmin of the objective function.
      objective_value: A tensor containing the value of the objective
        function at the `position`. If the search
        converged, then this is the (local) minimum of
        the objective function.
      final_simplex: The last simplex constructed before stopping.
      final_objective_values: The objective function evaluated at the
        vertices of the final simplex.
      initial_simplex: The starting simplex.
      initial_objective_values: The objective function evaluated at the
        vertices of the initial simplex.
      num_iterations: The number of iterations of the main algorithm body.

  Raises:
    ValueError: If any of the following conditions hold
      1. If none or more than one of `initial_simplex` and `initial_vertex` are
        supplied.
      2. If `initial_simplex` and `step_sizes` are both specified.
  """
    with tf.name_scope(name or 'minimize'):
        (dim, _, simplex, objective_at_simplex,
         num_evaluations) = _prepare_args(objective_function, initial_simplex,
                                          initial_vertex, step_sizes,
                                          objective_at_initial_simplex,
                                          objective_at_initial_vertex,
                                          batch_evaluate_objective)
        domain_dtype = simplex.dtype
        (reflection, expansion, contraction,
         shrinkage) = _resolve_parameters(dim, reflection, expansion,
                                          contraction, shrinkage, domain_dtype)

        closure_kwargs = dict(
            objective_function=objective_function,
            dim=dim,
            func_tolerance=func_tolerance,
            position_tolerance=position_tolerance,
            batch_evaluate_objective=batch_evaluate_objective,
            reflection=reflection,
            expansion=expansion,
            contraction=contraction,
            shrinkage=shrinkage)

        def _loop_body(_, iterations, simplex, objective_at_simplex,
                       num_evaluations):
            (converged, next_simplex, next_objective,
             evaluations) = nelder_mead_one_step(simplex, objective_at_simplex,
                                                 **closure_kwargs)

            return (converged, iterations + 1, next_simplex, next_objective,
                    num_evaluations + evaluations)

        initial_args = (False, 0, simplex, objective_at_simplex,
                        num_evaluations)

        # Loop until either we have converged or if the max iterations are supplied
        # then until we have converged or exhausted the available iteration budget.
        def _is_converged(converged, num_iterations, *ignored_args):  # pylint:disable=unused-argument
            # It is important to ensure that not_converged is a tensor. If
            # converged is not a tensor but a Python bool, then the overloaded
            # op '~' acts as bitwise complement so ~True = -2 and ~False = -1.
            # In that case, the loop will never terminate.
            not_converged = tf.logical_not(converged)
            return (not_converged if max_iterations is None else
                    (not_converged & (num_iterations < max_iterations)))

        (converged, num_iterations, final_simplex, final_objective_values,
         final_evaluations) = tf.while_loop(
             cond=_is_converged,
             body=_loop_body,
             loop_vars=initial_args,
             parallel_iterations=parallel_iterations)
        order = tf.argsort(final_objective_values,
                           direction='ASCENDING',
                           stable=True)
        best_index = order[0]
        # The explicit cast to Tensor below is done to avoid returning a mixture
        # of Python types and Tensors which cause problems with session.run.
        # In the eager mode, converged may remain a Python bool. Trying to evaluate
        # the whole tuple in one evaluate call will raise an exception because
        # of the presence of non-tensors. This is very annoying so we explicitly
        # cast those arguments to Tensors.
        return NelderMeadOptimizerResults(
            converged=tf.convert_to_tensor(converged),
            num_objective_evaluations=final_evaluations,
            position=final_simplex[best_index],
            objective_value=final_objective_values[best_index],
            final_simplex=final_simplex,
            final_objective_values=final_objective_values,
            num_iterations=tf.convert_to_tensor(num_iterations),
            initial_simplex=simplex,
            initial_objective_values=objective_at_simplex)
  def _sample_paths(self,
                    times,
                    num_samples,
                    random_type,
                    skip,
                    seed,
                    normal_draws=None,
                    times_grid=None,
                    validate_args=False):
    """Returns a sample of paths from the process."""
    # Note: all the notations below are the same as in [1].
    num_requested_times = tf.shape(times)[0]
    params = [self._mean_reversion, self._volatility]
    if self._corr_matrix is not None:
      params = params + [self._corr_matrix]
    times, keep_mask = _prepare_grid(
        times, times_grid, *params)
    # Add zeros as a starting location
    dt = times[1:] - times[:-1]
    if dt.shape.is_fully_defined():
      steps_num = dt.shape.as_list()[-1]
    else:
      steps_num = tf.shape(dt)[-1]
      # TODO(b/148133811): Re-enable Sobol test when TF 2.2 is released.
      if random_type == random.RandomType.SOBOL:
        raise ValueError('Sobol sequence for Euler sampling is temporarily '
                         'unsupported when `time_step` or `times` have a '
                         'non-constant value')
    if normal_draws is None:
      # In order to use low-discrepancy random_type we need to generate the
      # sequence of independent random normals upfront. We also precompute
      # random numbers for stateless random type in order to ensure independent
      # samples for multiple function calls with different seeds.
      if random_type in (random.RandomType.SOBOL,
                         random.RandomType.HALTON,
                         random.RandomType.HALTON_RANDOMIZED,
                         random.RandomType.STATELESS,
                         random.RandomType.STATELESS_ANTITHETIC):
        normal_draws = utils.generate_mc_normal_draws(
            num_normal_draws=self._dim, num_time_steps=steps_num,
            num_sample_paths=num_samples, random_type=random_type,
            seed=seed,
            dtype=self._dtype, skip=skip)
      else:
        normal_draws = None
    else:
      if validate_args:
        draws_times = tf.shape(normal_draws)[0]
        asserts = tf.assert_equal(
            draws_times, tf.shape(times)[0] - 1,  # We have added `0` to `times`
            message='`tf.shape(normal_draws)[0]` should be equal to the '
                    'number of all `times` plus the number of all jumps of '
                    'the piecewise constant parameters.')
        with tf.compat.v1.control_dependencies([asserts]):
          normal_draws = tf.identity(normal_draws)
    # The below is OK because we support exact discretization with piecewise
    # constant mr and vol.
    mean_reversion = self._mean_reversion(times)
    volatility = self._volatility(times)
    if self._corr_matrix is not None:
      corr_matrix = _get_parameters(
          times + tf.math.reduce_min(dt) / 2, self._corr_matrix)[0]
      corr_matrix_root = tf.linalg.cholesky(corr_matrix)
    else:
      corr_matrix_root = None

    exp_x_t = self._conditional_mean_x(times, mean_reversion, volatility)
    var_x_t = self._conditional_variance_x(times, mean_reversion, volatility)
    if self._dim == 1:
      mean_reversion = tf.expand_dims(mean_reversion, axis=0)

    cond_fn = lambda i, *args: i < tf.size(dt)
    def body_fn(i, written_count,
                current_x,
                rate_paths):
      """Simulate hull-white process to the next time point."""
      if normal_draws is None:
        normals = random.mv_normal_sample(
            (num_samples,),
            mean=tf.zeros((self._dim,), dtype=mean_reversion.dtype),
            random_type=random_type, seed=seed)
      else:
        normals = normal_draws[i]

      if corr_matrix_root is not None:
        normals = tf.linalg.matvec(corr_matrix_root[i], normals)
      vol_x_t = tf.math.sqrt(tf.nn.relu(tf.transpose(var_x_t)[i]))
      # If numerically `vol_x_t == 0`, the gradient of `vol_x_t` becomes `NaN`.
      # To prevent this, we explicitly set `vol_x_t` to a zero tensor at zero
      # values so that the gradient is set to zero at these values.
      vol_x_t = tf.where(vol_x_t > 0.0, vol_x_t, 0.0)
      next_x = (tf.math.exp(-tf.transpose(mean_reversion)[i + 1] * dt[i])
                * current_x
                + tf.transpose(exp_x_t)[i]
                + vol_x_t * normals)
      f_0_t = self._instant_forward_rate_fn(times[i + 1])

      # Update `rate_paths`
      rate_paths = utils.maybe_update_along_axis(
          tensor=rate_paths,
          do_update=keep_mask[i + 1],
          ind=written_count,
          axis=1,
          new_tensor=tf.expand_dims(next_x, axis=1) + f_0_t)
      written_count += tf.cast(keep_mask[i + 1], dtype=tf.int32)
      return (i + 1, written_count, next_x, rate_paths)

    rate_paths = tf.zeros((num_samples, num_requested_times, self._dim),
                          dtype=self._dtype)
    # Include initial state, if necessary
    f0_t = self._instant_forward_rate_fn(times[0])
    rate_paths = utils.maybe_update_along_axis(
        tensor=rate_paths,
        do_update=keep_mask[0],
        ind=0,
        axis=1,
        new_tensor=f0_t)
    written_count = tf.cast(keep_mask[0], dtype=tf.int32)
    initial_x = tf.zeros((num_samples, self._dim), dtype=self._dtype)
    # TODO(b/157232803): Use tf.cumsum instead?
    _, _, _, rate_paths = tf.while_loop(
        cond_fn, body_fn, (0, written_count, initial_x, rate_paths))

    return rate_paths
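
A minimal, self-contained sketch (hypothetical toy values, not the library API) of the write pattern used in `body_fn` above: the loop walks every point of the fine simulation grid, but results are written only where `keep_mask` is True, with `written_count` tracking the next output slot.

import tensorflow as tf

keep_mask = tf.constant([True, False, True, True])  # which grid points were requested
values = tf.constant([1.0, 2.0, 3.0, 4.0])          # value simulated at each grid point
out = tf.zeros([3])                                 # room for the requested points only

def body(i, written, out):
  def write():
    return tf.tensor_scatter_nd_update(out, [[written]], [values[i]])
  out = tf.cond(keep_mask[i], write, lambda: out)
  written += tf.cast(keep_mask[i], tf.int32)
  return i + 1, written, out

_, _, out = tf.while_loop(lambda i, *args: i < 4, body, (0, 0, out))
# out == [1., 3., 4.]
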
Example #28
0
    def __call__(self,
                 momentum_parts,
                 state_parts,
                 target=None,
                 target_grad_parts=None,
                 name=None):
        """Applies `num_steps` of the leapfrog integrator.

    Args:
      momentum_parts: Python `list` of `Tensor`s representing the momentum for
        each state part.
      state_parts: Python `list` of `Tensor`s which collectively represent
        the state.
      target: Batch of scalar `Tensor` representing the target (i.e.,
        unnormalized log prob) evaluated at `state_parts`.
      target_grad_parts: Python `list` of `Tensor`s representing the gradient of
        `target` with respect to each of `state_parts`.
      name: Python `str` used to group ops created by this function.

    Returns:
      next_momentum_parts: Python `list` of `Tensor`s representing new momentum.
      next_state_parts: Python `list` of `Tensor`s which collectively
        represent the new state.
      next_target: Batch of scalar `Tensor` representing the target (i.e.,
        unnormalized log prob) evaluated at `next_state_parts`.
      next_target_grad_parts: Python `list` of `Tensor`s representing the
        gradient of `next_target` with respect to each of `next_state_parts`.
    """
        with tf.name_scope(name or 'leapfrog_integrate'):
            [
                momentum_parts,
                state_parts,
                target,
                target_grad_parts,
            ] = process_args(self.target_fn, momentum_parts, state_parts,
                             target, target_grad_parts)

            # See Algorithm 1 of "Faster Hamiltonian Monte Carlo by Learning Leapfrog
            # Scale", https://arxiv.org/abs/1810.04449.

            half_next_momentum_parts = [
                v + tf.cast(0.5 * eps, v.dtype) * tf.cast(g, v.dtype)
                for v, eps, g in zip(momentum_parts, self.step_sizes,
                                     target_grad_parts)
            ]

            [
                _,
                next_half_next_momentum_parts,
                next_state_parts,
                next_target,
                next_target_grad_parts,
            ] = tf.while_loop(
                cond=lambda i, *_: i < self.num_steps,
                body=lambda i, *args: [i + 1] + list(
                    _one_step(  # pylint: disable=no-value-for-parameter,g-long-lambda
                        self.target_fn, self.step_sizes, *args)),
                loop_vars=[
                    tf.zeros_like(self.num_steps, name='iter'),
                    half_next_momentum_parts,
                    state_parts,
                    target,
                    target_grad_parts,
                ])

            next_momentum_parts = [
                v - tf.cast(0.5 * eps, v.dtype) * tf.cast(g, v.dtype)  # pylint: disable=g-complex-comprehension
                for v, eps, g in zip(next_half_next_momentum_parts,
                                     self.step_sizes, next_target_grad_parts)
            ]

            return (
                next_momentum_parts,
                next_state_parts,
                next_target,
                next_target_grad_parts,
            )
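
A hedged sketch (illustrative names, not the class above) of the single leapfrog update that `__call__` applies `num_steps` times for a Hamiltonian H(x, p) = -log_prob(x) + p**2 / 2: half a momentum step, a full position step, then the second half momentum step.

import tensorflow as tf

def leapfrog_step(x, p, eps, grad_log_prob):
  p_half = p + 0.5 * eps * grad_log_prob(x)          # first half momentum step
  x_new = x + eps * p_half                           # full position step
  p_new = p_half + 0.5 * eps * grad_log_prob(x_new)  # second half momentum step
  return x_new, p_new

# Standard-normal target: log_prob(x) = -0.5 * x**2, so grad log_prob(x) = -x.
x, p = tf.constant(1.0), tf.constant(0.5)
x, p = leapfrog_step(x, p, eps=0.1, grad_log_prob=lambda x: -x)
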
Example #29
0
  def solve(self,
            ode_fn,
            initial_time,
            initial_state,
            solution_times,
            jacobian_fn=None,
            jacobian_sparsity=None,
            batch_ndims=None,
            previous_solver_internal_state=None):
    """See `tfp.math.ode.Solver.solve`."""

    # The `solve` function consists of the following sequential stages:
    # (1) Make static assertions.
    # (2) Initialize variables.
    # (3) Make non-static assertions.
    # (4) Solve up to final time.
    # (5) Return `Results` object.
    #
    # The stages can be found in the code by searching for (n) where n=1..5.
    #
    # By static vs. non-static assertions (see stages 1 and 3), we mean
    # assertions that can be made before the graph is run vs. those that can
    # only be made at run time. The latter are constructed as a list of
    # tf.Assert operations by the function `assert_ops` (see below).
    #
    # If `solution_times` is specified as a `Tensor`, stage 4 consists of three
    # nested loops, which can be conceptually understood as follows:
    # ```
    # current_time, current_state = initial_time, initial_state
    # order, step_size = 1, first_step_size
    # for solution_time in solution_times:
    #   while current_time < solution_time:
    #     while True:
    #       next_time = current_time + step_size
    #       next_state, error = (
    #           solve_nonlinear_equation_to_get_approximate_state_at_next_time(
    #           current_time, current_state, next_time, order))
    #       if error < tolerance:
    #         current_time, current_state = next_time, next_state
    #         order, step_size = (
    #           maybe_update_order_and_step_size(order, step_size))
    #         break
    #       else:
    #         step_size = decrease_step_size(step_size)
    # ```
    # The outermost loop advances the solver to the next `solution_time` (see
    # `advance_to_solution_time`). The middle loop advances the solver by a
    # small timestep (see `step`). The innermost loop determines the size of
    # that timestep (see `maybe_step`).
    #
    # If `solution_times` is specified as
    # `tfp.math.ode.ChosenBySolver(final_time)`, the outermost loop is skipped
    # and `solution_time` in the middle loop is replaced by `final_time`.

    def assert_ops():
      """Creates a list of assert operations."""
      if not self._validate_args:
        return []
      assert_ops = []
      if ((not initial_state_missing) and
          (previous_solver_internal_state is not None)):
        assert_initial_state_matches_previous_solver_internal_state = (
            tf.assert_near(
                tf.norm(
                    original_initial_state -
                    previous_solver_internal_state.backward_differences[0],
                    np.inf),
                0.,
                message='`previous_solver_internal_state` does not match '
                '`initial_state`.'))
        assert_ops.append(
            assert_initial_state_matches_previous_solver_internal_state)
      if solution_times_chosen_by_solver:
        assert_ops.append(
            util.assert_positive(final_time - initial_time,
                                 'final_time - initial_time'))
      else:
        assert_ops += [
            util.assert_increasing(solution_times, 'solution_times'),
            util.assert_nonnegative(solution_times[0] - initial_time,
                                    'solution_times[0] - initial_time'),
        ]
      if max_num_steps is not None:
        assert_ops.append(util.assert_positive(max_num_steps, 'max_num_steps'))
      if max_num_newton_iters is not None:
        assert_ops.append(
            util.assert_positive(max_num_newton_iters, 'max_num_newton_iters'))
      assert_ops += [
          util.assert_positive(rtol, 'rtol'),
          util.assert_positive(atol, 'atol'),
          util.assert_positive(first_step_size, 'first_step_size'),
          util.assert_positive(safety_factor, 'safety_factor'),
          util.assert_positive(min_step_size_factor, 'min_step_size_factor'),
          util.assert_positive(max_step_size_factor, 'max_step_size_factor'),
          tf.Assert((max_order >= 1) & (max_order <= bdf_util.MAX_ORDER), [
              '`max_order` must be between 1 and {}.'.format(bdf_util.MAX_ORDER)
          ]),
          util.assert_positive(newton_tol_factor, 'newton_tol_factor'),
          util.assert_positive(newton_step_size_factor,
                               'newton_step_size_factor'),
      ]
      return assert_ops

    def advance_to_solution_time(n, diagnostics, iterand, solver_internal_state,
                                 states_array, times_array):
      """Takes multiple steps to advance time to `solution_times[n]`."""

      def step_cond(next_time, diagnostics, iterand, *_):
        return (iterand.time < next_time) & (tf.equal(diagnostics.status, 0))

      solution_times_n = solution_times_array.read(n)
      [
          _, diagnostics, iterand, solver_internal_state, states_array,
          times_array
      ] = tf.while_loop(step_cond, step, [
          solution_times_n, diagnostics, iterand, solver_internal_state,
          states_array, times_array
      ])
      states_array = states_array.write(
          n, solver_internal_state.backward_differences[0])
      times_array = times_array.write(n, solution_times_n)
      return (n + 1, diagnostics, iterand, solver_internal_state, states_array,
              times_array)

    def step(next_time, diagnostics, iterand, solver_internal_state,
             states_array, times_array):
      """Takes a single step."""
      distance_to_next_time = next_time - iterand.time
      overstepped = iterand.new_step_size > distance_to_next_time
      iterand = iterand._replace(
          new_step_size=tf1.where(overstepped, distance_to_next_time,
                                  iterand.new_step_size),
          should_update_step_size=overstepped | iterand.should_update_step_size)

      if not self._evaluate_jacobian_lazily:
        diagnostics = diagnostics._replace(
            num_jacobian_evaluations=diagnostics.num_jacobian_evaluations + 1)
        iterand = iterand._replace(
            jacobian=jacobian_fn_mat(
                iterand.time, solver_internal_state.backward_differences[0]),
            jacobian_is_up_to_date=True)

      def maybe_step_cond(accepted, diagnostics, *_):
        return tf.logical_not(accepted) & tf.equal(diagnostics.status, 0)

      _, diagnostics, iterand, solver_internal_state = tf.while_loop(
          maybe_step_cond, maybe_step,
          [False, diagnostics, iterand, solver_internal_state])

      if solution_times_chosen_by_solver:
        states_array = states_array.write(
            states_array.size(), solver_internal_state.backward_differences[0])
        times_array = times_array.write(times_array.size(), iterand.time)

      return (next_time, diagnostics, iterand, solver_internal_state,
              states_array, times_array)

    def maybe_step(accepted, diagnostics, iterand, solver_internal_state):
      """Takes a single step only if the outcome has a low enough error."""
      [
          num_jacobian_evaluations, num_matrix_factorizations,
          num_ode_fn_evaluations, status
      ] = diagnostics
      [
          jacobian, jacobian_is_up_to_date, new_step_size, num_steps,
          num_steps_same_size, should_update_jacobian, should_update_step_size,
          time, unitary, upper
      ] = iterand
      backward_differences, order, state_shape, step_size = solver_internal_state

      if max_num_steps is not None:
        status = tf1.where(tf.equal(num_steps, max_num_steps), -1, 0)

      backward_differences = tf1.where(
          should_update_step_size,
          bdf_util.interpolate_backward_differences(backward_differences, order,
                                                    new_step_size / step_size),
          backward_differences)
      step_size = tf1.where(should_update_step_size, new_step_size, step_size)
      should_update_factorization = should_update_step_size
      num_steps_same_size = tf1.where(should_update_step_size, 0,
                                      num_steps_same_size)

      def update_factorization():
        return bdf_util.newton_qr(jacobian,
                                  newton_coefficients_array.read(order),
                                  step_size)

      if self._evaluate_jacobian_lazily:

        def update_jacobian_and_factorization():
          new_jacobian = jacobian_fn_mat(time, backward_differences[0])
          new_unitary, new_upper = update_factorization()
          return [
              new_jacobian, True, num_jacobian_evaluations + 1, new_unitary,
              new_upper
          ]

        def maybe_update_factorization():
          new_unitary, new_upper = tf.cond(
              should_update_factorization,
              update_factorization, lambda: [unitary, upper])
          return [
              jacobian, jacobian_is_up_to_date, num_jacobian_evaluations,
              new_unitary, new_upper
          ]

        [
            jacobian, jacobian_is_up_to_date, num_jacobian_evaluations, unitary,
            upper
        ] = tf.cond(should_update_jacobian, update_jacobian_and_factorization,
                    maybe_update_factorization)
      else:
        unitary, upper = update_factorization()
        num_matrix_factorizations += 1

      tol = atol + rtol * tf.abs(backward_differences[0])
      newton_tol = newton_tol_factor * tf.norm(tol)

      [
          newton_converged, next_backward_difference, next_state,
          newton_num_iters
      ] = bdf_util.newton(backward_differences, max_num_newton_iters,
                          newton_coefficients_array.read(order), ode_fn_vec,
                          order, step_size, time, newton_tol, unitary, upper)
      num_steps += 1
      num_ode_fn_evaluations += newton_num_iters

      # If Newton's method failed and the Jacobian was up to date, decrease the
      # step size.
      newton_failed = tf.logical_not(newton_converged)
      should_update_step_size = newton_failed & jacobian_is_up_to_date
      new_step_size = step_size * tf1.where(should_update_step_size,
                                            newton_step_size_factor, 1.)

      # If Newton's method failed and the Jacobian was NOT up to date, update
      # the Jacobian.
      should_update_jacobian = newton_failed & tf.logical_not(
          jacobian_is_up_to_date)

      error_ratio = tf1.where(
          newton_converged,
          bdf_util.error_ratio(next_backward_difference,
                               error_coefficients_array.read(order), tol),
          np.nan)
      accepted = error_ratio < 1.
      converged_and_rejected = newton_converged & tf.logical_not(accepted)

      # If Newton's method converged but the solution was NOT accepted, decrease
      # the step size.
      new_step_size = tf1.where(
          converged_and_rejected,
          util.next_step_size(step_size, order, error_ratio, safety_factor,
                              min_step_size_factor, max_step_size_factor),
          new_step_size)
      should_update_step_size = should_update_step_size | converged_and_rejected

      # If Newton's method converged and the solution was accepted, update the
      # matrix of backward differences.
      time = tf1.where(accepted, time + step_size, time)
      backward_differences = tf1.where(
          accepted,
          bdf_util.update_backward_differences(backward_differences,
                                               next_backward_difference,
                                               next_state, order),
          backward_differences)
      jacobian_is_up_to_date = jacobian_is_up_to_date & tf.logical_not(accepted)
      num_steps_same_size = tf1.where(accepted, num_steps_same_size + 1,
                                      num_steps_same_size)

      # Order and step size are only updated if we have taken strictly more than
      # order + 1 steps of the same size. This is to prevent the order from
      # being throttled.
      should_update_order_and_step_size = accepted & (
          num_steps_same_size > order + 1)

      backward_differences_array = tf.TensorArray(
          backward_differences.dtype,
          size=bdf_util.MAX_ORDER + 3,
          clear_after_read=False,
          element_shape=next_backward_difference.get_shape()).unstack(
              backward_differences)
      new_order = order
      new_error_ratio = error_ratio
      for offset in [-1, +1]:
        proposed_order = tf.clip_by_value(order + offset, 1, max_order)
        proposed_error_ratio = bdf_util.error_ratio(
            backward_differences_array.read(proposed_order + 1),
            error_coefficients_array.read(proposed_order), tol)
        proposed_error_ratio_is_lower = proposed_error_ratio < new_error_ratio
        new_order = tf1.where(
            should_update_order_and_step_size & proposed_error_ratio_is_lower,
            proposed_order, new_order)
        new_error_ratio = tf1.where(
            should_update_order_and_step_size & proposed_error_ratio_is_lower,
            proposed_error_ratio, new_error_ratio)
      order = new_order
      error_ratio = new_error_ratio

      new_step_size = tf1.where(
          should_update_order_and_step_size,
          util.next_step_size(step_size, order, error_ratio, safety_factor,
                              min_step_size_factor, max_step_size_factor),
          new_step_size)
      should_update_step_size = (
          should_update_step_size | should_update_order_and_step_size)

      diagnostics = _BDFDiagnostics(num_jacobian_evaluations,
                                    num_matrix_factorizations,
                                    num_ode_fn_evaluations, status)
      iterand = _BDFIterand(jacobian, jacobian_is_up_to_date, new_step_size,
                            num_steps, num_steps_same_size,
                            should_update_jacobian, should_update_step_size,
                            time, unitary, upper)
      solver_internal_state = _BDFSolverInternalState(backward_differences,
                                                      order, state_shape,
                                                      step_size)
      return accepted, diagnostics, iterand, solver_internal_state

    # (1) Make static assertions.
    # TODO(parsiad): Support specifying Jacobian sparsity patterns.
    if jacobian_sparsity is not None:
      raise NotImplementedError('The BDF solver does not support specifying '
                                'Jacobian sparsity patterns.')
    if batch_ndims is not None and batch_ndims != 0:
      raise NotImplementedError('The BDF solver does not support batching.')
    solution_times_chosen_by_solver = (
        isinstance(solution_times, base.ChosenBySolver))
    initial_state_missing = initial_state is None
    if initial_state_missing and previous_solver_internal_state is None:
      raise ValueError(
          'At least one of `initial_state` or `previous_solver_internal_state` '
          'must be specified')

    with tf.name_scope(self._name):

      # (2) Initialize variables.
      original_initial_state = initial_state
      if previous_solver_internal_state is None:
        initial_state = tf.convert_to_tensor(initial_state)
        original_state_shape = tf.shape(initial_state)
      else:
        initial_state = previous_solver_internal_state.backward_differences[0]
        original_state_shape = previous_solver_internal_state.state_shape
      state_dtype = initial_state.dtype
      util.error_if_not_real_or_complex(initial_state, 'initial_state')
      # TODO(parsiad): Support complex automatic Jacobians.
      if jacobian_fn is None and state_dtype.is_complex:
        raise NotImplementedError('The BDF solver does not support automatic '
                                  'Jacobian computations for complex dtypes.')
      num_odes = tf.size(initial_state)
      original_state_tensor_shape = initial_state.get_shape()
      initial_state = tf.reshape(initial_state, [-1])
      ode_fn_vec = util.get_ode_fn_vec(ode_fn, original_state_shape)
      # `real_dtype` is the floating point `dtype` associated with
      # `initial_state.dtype` (recall that the latter can be complex).
      real_dtype = tf.abs(initial_state).dtype
      initial_time = tf.ensure_shape(
          tf.convert_to_tensor(initial_time, dtype=real_dtype), [])
      num_solution_times = 0
      if solution_times_chosen_by_solver:
        final_time = solution_times.final_time
        final_time = tf.ensure_shape(
            tf.convert_to_tensor(final_time, dtype=real_dtype), [])
      else:
        solution_times = tf.convert_to_tensor(solution_times, dtype=real_dtype)
        num_solution_times = tf.size(solution_times)
        solution_times_array = tf.TensorArray(
            solution_times.dtype, size=num_solution_times,
            element_shape=[]).unstack(solution_times)
        util.error_if_not_vector(solution_times, 'solution_times')
      jacobian_fn_mat = util.get_jacobian_fn_mat(
          jacobian_fn,
          ode_fn_vec,
          original_state_shape,
          use_pfor=self._use_pfor_to_compute_jacobian)
      rtol = tf.convert_to_tensor(self._rtol, dtype=real_dtype)
      atol = tf.convert_to_tensor(self._atol, dtype=real_dtype)
      safety_factor = tf.ensure_shape(
          tf.convert_to_tensor(self._safety_factor, dtype=real_dtype), [])
      min_step_size_factor = tf.ensure_shape(
          tf.convert_to_tensor(self._min_step_size_factor, dtype=real_dtype),
          [])
      max_step_size_factor = tf.ensure_shape(
          tf.convert_to_tensor(self._max_step_size_factor, dtype=real_dtype),
          [])
      max_num_steps = self._max_num_steps
      if max_num_steps is not None:
        max_num_steps = tf.convert_to_tensor(max_num_steps, dtype=tf.int32)
      max_order = tf.convert_to_tensor(self._max_order, dtype=tf.int32)
      max_num_newton_iters = self._max_num_newton_iters
      if max_num_newton_iters is not None:
        max_num_newton_iters = tf.convert_to_tensor(
            max_num_newton_iters, dtype=tf.int32)
      newton_tol_factor = tf.ensure_shape(
          tf.convert_to_tensor(self._newton_tol_factor, dtype=real_dtype), [])
      newton_step_size_factor = tf.ensure_shape(
          tf.convert_to_tensor(self._newton_step_size_factor, dtype=real_dtype),
          [])
      bdf_coefficients = tf.cast(
          tf.concat(
              [[0.],
               tf.convert_to_tensor(self._bdf_coefficients, dtype=real_dtype)],
              0), state_dtype)
      util.error_if_not_vector(bdf_coefficients, 'bdf_coefficients')
      newton_coefficients = 1. / (
          (1. - bdf_coefficients) * bdf_util.RECIPROCAL_SUMS)
      newton_coefficients_array = tf.TensorArray(
          newton_coefficients.dtype,
          size=bdf_util.MAX_ORDER + 1,
          clear_after_read=False,
          element_shape=[]).unstack(newton_coefficients)
      error_coefficients = bdf_coefficients * bdf_util.RECIPROCAL_SUMS + 1. / (
          bdf_util.ORDERS + 1)
      error_coefficients_array = tf.TensorArray(
          error_coefficients.dtype,
          size=bdf_util.MAX_ORDER + 1,
          clear_after_read=False,
          element_shape=[]).unstack(error_coefficients)
      first_step_size = self._first_step_size
      if first_step_size is None:
        first_step_size = bdf_util.first_step_size(
            atol, error_coefficients_array.read(1), initial_state, initial_time,
            ode_fn_vec, rtol, safety_factor)
      elif previous_solver_internal_state is not None:
        tf.logging.warn('`first_step_size` is ignored since '
                        '`previous_solver_internal_state` was specified.')
      first_step_size = tf.convert_to_tensor(first_step_size, dtype=real_dtype)
      if self._validate_args:
        if max_num_steps is not None:
          max_num_steps = tf.ensure_shape(max_num_steps, [])
        max_order = tf.ensure_shape(max_order, [])
        if max_num_newton_iters is not None:
          max_num_newton_iters = tf.ensure_shape(max_num_newton_iters, [])
        bdf_coefficients = tf.ensure_shape(bdf_coefficients, [6])
        first_step_size = tf.ensure_shape(first_step_size, [])
      solver_internal_state = previous_solver_internal_state
      if solver_internal_state is None:
        first_order_backward_difference = ode_fn_vec(
            initial_time, initial_state) * tf.cast(first_step_size, state_dtype)
        backward_differences = tf.concat([
            tf.reshape(initial_state, [1, -1]),
            first_order_backward_difference[tf.newaxis, :],
            tf.zeros(
                tf.stack([bdf_util.MAX_ORDER + 1, num_odes]),
                dtype=state_dtype),
        ], 0)
        solver_internal_state = _BDFSolverInternalState(
            backward_differences=backward_differences,
            order=1,
            state_shape=original_state_shape,
            step_size=first_step_size)
      states_array = tf.TensorArray(
          state_dtype,
          size=num_solution_times,
          dynamic_size=solution_times_chosen_by_solver,
          element_shape=initial_state.get_shape())
      times_array = tf.TensorArray(
          real_dtype,
          size=num_solution_times,
          dynamic_size=solution_times_chosen_by_solver,
          element_shape=tf.TensorShape([]))
      diagnostics = _BDFDiagnostics(
          num_jacobian_evaluations=0,
          num_matrix_factorizations=0,
          num_ode_fn_evaluations=0,
          status=0)
      iterand = _BDFIterand(
          jacobian=tf.zeros([num_odes, num_odes], dtype=state_dtype),
          jacobian_is_up_to_date=False,
          new_step_size=solver_internal_state.step_size,
          num_steps=0,
          num_steps_same_size=0,
          should_update_jacobian=True,
          should_update_step_size=False,
          time=initial_time,
          unitary=tf.zeros([num_odes, num_odes], dtype=state_dtype),
          upper=tf.zeros([num_odes, num_odes], dtype=state_dtype))

      # (3) Make non-static assertions.
      with tf.control_dependencies(assert_ops()):

        # (4) Solve up to final time.
        if solution_times_chosen_by_solver:

          def step_cond(next_time, diagnostics, iterand, *_):
            return (iterand.time < next_time) & (
                tf.equal(diagnostics.status, 0))

          [
              _, diagnostics, iterand, solver_internal_state, states_array,
              times_array
          ] = tf.while_loop(step_cond, step, [
              final_time, diagnostics, iterand, solver_internal_state,
              states_array, times_array
          ])

        else:

          def advance_to_solution_time_cond(n, diagnostics, *_):
            return (n < num_solution_times) & (tf.equal(diagnostics.status, 0))

          [
              _, diagnostics, iterand, solver_internal_state, states_array,
              times_array
          ] = tf.while_loop(advance_to_solution_time_cond,
                            advance_to_solution_time, [
                                0, diagnostics, iterand, solver_internal_state,
                                states_array, times_array
                            ])

        # (5) Return `Results` object.
        states = tf.reshape(states_array.stack(),
                            tf.concat([[-1], original_state_shape], 0))
        times = times_array.stack()
        if not solution_times_chosen_by_solver:
          times.set_shape(solution_times.get_shape())
          states.set_shape(solution_times.get_shape().concatenate(
              original_state_tensor_shape))
        return base.Results(
            times=times,
            states=states,
            diagnostics=diagnostics,
            solver_internal_state=solver_internal_state)
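
A toy, self-contained sketch (hypothetical error model, not the BDF machinery above) of the innermost accept/reject loop described in the comments of `solve`: repeatedly shrink the step until the local error estimate falls below one, then accept.

import tensorflow as tf

def maybe_step_demo(step_size):
  def cond(accepted, step_size):
    return tf.logical_not(accepted)

  def body(accepted, step_size):
    error_ratio = step_size * 10.0                   # stand-in for a real error estimate
    accepted = error_ratio < 1.0
    step_size = tf.where(accepted, step_size, step_size * 0.5)
    return accepted, step_size

  return tf.while_loop(cond, body, (tf.constant(False), step_size))

accepted, final_step = maybe_step_demo(tf.constant(1.0))
# accepted == True, final_step == 0.0625
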
Example #30
0
    def one_step(self, current_state, previous_kernel_results, seed=None):
        seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.
        start_trajectory_seed, loop_seed = samplers.split_seed(seed)

        with tf.name_scope(self.name + '.one_step'):
            unwrap_state_list = not tf.nest.is_nested(current_state)
            if unwrap_state_list:
                current_state = [current_state]

            current_target_log_prob = previous_kernel_results.target_log_prob
            [init_momentum, init_energy, log_slice_sample
             ] = self._start_trajectory_batched(current_state,
                                                current_target_log_prob,
                                                seed=start_trajectory_seed)

            def _copy(v):
                return v * ps.ones(ps.pad(
                    [2], paddings=[[0, ps.rank(v)]], constant_values=1),
                                   dtype=v.dtype)

            initial_state = TreeDoublingState(
                momentum=init_momentum,
                state=current_state,
                target=current_target_log_prob,
                target_grad_parts=previous_kernel_results.grads_target_log_prob
            )
            initial_step_state = tf.nest.map_structure(_copy, initial_state)

            if MULTINOMIAL_SAMPLE:
                init_weight = tf.zeros_like(init_energy)  # log(exp(H0 - H0))
            else:
                init_weight = tf.ones_like(init_energy, dtype=TREE_COUNT_DTYPE)

            candidate_state = TreeDoublingStateCandidate(
                state=current_state,
                target=current_target_log_prob,
                target_grad_parts=previous_kernel_results.
                grads_target_log_prob,
                energy=init_energy,
                weight=init_weight)

            initial_step_metastate = TreeDoublingMetaState(
                candidate_state=candidate_state,
                is_accepted=tf.zeros_like(init_energy, dtype=tf.bool),
                momentum_sum=init_momentum,
                energy_diff_sum=tf.zeros_like(init_energy),
                leapfrog_count=tf.zeros_like(init_energy,
                                             dtype=TREE_COUNT_DTYPE),
                continue_tree=tf.ones_like(init_energy, dtype=tf.bool),
                not_divergence=tf.ones_like(init_energy, dtype=tf.bool))

            # Convert the write/read instruction into TensorArray so that it is
            # compatible with XLA.
            write_instruction = tf.TensorArray(
                TREE_COUNT_DTYPE,
                size=len(self._write_instruction),
                clear_after_read=False).unstack(self._write_instruction)
            read_instruction = tf.TensorArray(tf.int32,
                                              size=len(self._read_instruction),
                                              clear_after_read=False).unstack(
                                                  self._read_instruction)

            current_step_meta_info = OneStepMetaInfo(
                log_slice_sample=log_slice_sample,
                init_energy=init_energy,
                write_instruction=write_instruction,
                read_instruction=read_instruction)

            _, _, _, new_step_metastate = tf.while_loop(
                cond=lambda iter_, seed, state, metastate: (  # pylint: disable=g-long-lambda
                    (iter_ < self.max_tree_depth) & tf.reduce_any(
                        metastate.continue_tree)),
                body=lambda iter_, seed, state, metastate: self.
                _loop_tree_doubling(  # pylint: disable=g-long-lambda
                    previous_kernel_results.step_size, previous_kernel_results.
                    momentum_state_memory, current_step_meta_info, iter_,
                    state, metastate, seed),
                loop_vars=(tf.zeros([], dtype=tf.int32,
                                    name='iter'), loop_seed,
                           initial_step_state, initial_step_metastate),
                parallel_iterations=self.parallel_iterations,
            )

            kernel_results = NUTSKernelResults(
                target_log_prob=new_step_metastate.candidate_state.target,
                grads_target_log_prob=(
                    new_step_metastate.candidate_state.target_grad_parts),
                momentum_state_memory=previous_kernel_results.
                momentum_state_memory,
                step_size=previous_kernel_results.step_size,
                log_accept_ratio=tf.math.log(
                    new_step_metastate.energy_diff_sum /
                    tf.cast(new_step_metastate.leapfrog_count,
                            dtype=new_step_metastate.energy_diff_sum.dtype)),
                leapfrogs_taken=(new_step_metastate.leapfrog_count *
                                 self.unrolled_leapfrog_steps),
                is_accepted=new_step_metastate.is_accepted,
                reach_max_depth=new_step_metastate.continue_tree,
                has_divergence=~new_step_metastate.not_divergence,
                energy=new_step_metastate.candidate_state.energy,
                seed=seed,
            )

            result_state = new_step_metastate.candidate_state.state
            if unwrap_state_list:
                result_state = result_state[0]

            return result_state, kernel_results
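
A small illustration (assumed toy shapes, not the kernel's own helpers) of what `_copy` in `one_step` does: multiplying by a ones tensor of shape [2, 1, ..., 1] prepends a leading axis of size two, so every state tensor is duplicated once per trajectory-growing direction.

import tensorflow as tf

v = tf.constant([[1.0, 2.0], [3.0, 4.0]])              # shape [2, 2]
ones_shape = [2] + [1] * v.shape.rank                  # [2, 1, 1]
copied = v * tf.ones(ones_shape, dtype=v.dtype)        # shape [2, 2, 2], two copies of v
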