Example #1
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Takes one step of the TransitionKernel.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s).
      previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
        `list` of `Tensor`s representing internal calculations made within the
        previous call to this function (or as returned by `bootstrap_results`).
      seed: PRNG seed; see `tfp.random.sanitize_seed` for details.

    Returns:
      next_state: `Tensor` or Python `list` of `Tensor`s representing the
        next state(s) of the Markov chain(s).
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.
        This includes replica states.
    """

        # The code below propagates one step states of shape
        #  [n_replica] + batch_shape + event_shape.
        #
        # The step is done in three parts:
        #  1) Call one_step to transition states via a tempered version of
        #     self.target_log_prob_fn (see _replica_target_log_prob).
        #  2) Permute values in states
        #  3) Update state-dependent values, such as log_probs.
        #
        # We chose to swap states, rather than temperatures, because...
        # (i)  If swapping temperatures, you *still* have to swap log_probs to
        #      determine acceptance, as well as states (for kernel results).
        #      So it's just as difficult to swap temperatures.
        # (ii) If swapping temperatures, you have to take care to swap any user-
        #      supplied temperature related things (like step size).
        #      A-priori, we don't know what else will need to be swapped!
        #  (iii) In both cases, the kernel results need to be updated in a non-trivial
        #      manner....so we either special-case, or use bootstrap.

        with tf.name_scope(mcmc_util.make_name(self.name, 'remc', 'one_step')):
            # Force a read in case the `inverse_temperatures` is a `tf.Variable`.
            inverse_temperatures = tf.convert_to_tensor(
                previous_kernel_results.inverse_temperatures,
                name='inverse_temperatures')

            target_log_prob_for_inner_kernel = _make_replica_target_log_prob_fn(
                target_log_prob_fn=self.target_log_prob_fn,
                inverse_temperatures=inverse_temperatures,
                untempered_log_prob_fn=self.untempered_log_prob_fn,
                tempered_log_prob_fn=self.tempered_log_prob_fn,
            )
            # TODO(b/159636942): Clean up the helpful error msg after 2020-11-10.
            try:
                inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
                    target_log_prob_for_inner_kernel)
            except TypeError as e:
                if 'argument' not in str(e):
                    raise
                raise TypeError(
                    '`ReplicaExchangeMC`s `make_kernel_fn` no longer receives a `seed` '
                    'argument. `TransitionKernel` instances now receive seeds via '
                    '`one_step`.')

            seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.
            inner_seed, swap_seed, logu_seed = samplers.split_seed(seed, n=3)
            # Step the inner TransitionKernel.
            [
                pre_swap_replica_states,
                pre_swap_replica_results,
            ] = inner_kernel.one_step(
                previous_kernel_results.post_swap_replica_states,
                previous_kernel_results.post_swap_replica_results,
                seed=inner_seed)

            pre_swap_replica_target_log_prob = _get_field(
                # These are tempered log probs (have been divided by temperature).
                pre_swap_replica_results,
                'target_log_prob')

            dtype = pre_swap_replica_target_log_prob.dtype
            replica_and_batch_shape = ps.shape(
                pre_swap_replica_target_log_prob)
            batch_shape = replica_and_batch_shape[1:]
            replica_and_batch_rank = ps.rank(pre_swap_replica_target_log_prob)
            num_replica = ps.size0(inverse_temperatures)

            inverse_temperatures = bu.left_justified_broadcast_to(
                inverse_temperatures, replica_and_batch_shape)

            # Now that each replica has done one_step, it is time to consider swaps.

            # swaps.shape = [n_replica], and is a "once only" permutation, meaning it
            # is achievable by a sequence of pairwise permutations, where each element
            # is moved at most once.
            # E.g. if swaps = [1, 0, 2], we will consider swapping temperatures 0 and
            # 1, keeping 2 fixed.  This exact same swap is considered for *every*
            # batch member.  Of course some batch members may accept and some reject.
            try:
                swaps = tf.cast(
                    self.swap_proposal_fn(  # pylint: disable=not-callable
                        num_replica,
                        batch_shape=batch_shape,
                        seed=swap_seed,
                        step_count=previous_kernel_results.step_count),
                    dtype=tf.int32)
            except TypeError as e:
                if 'step_count' not in str(e):
                    raise
                warnings.warn(
                    'The `swap_proposal_fn` given to ReplicaExchangeMC did not accept '
                    'the `step_count` argument. Falling back to omitting the '
                    'argument. This fallback will be removed after 24-Oct-2020.'
                )
                swaps = tf.cast(
                    self.swap_proposal_fn(  # pylint: disable=not-callable
                        num_replica,
                        batch_shape=batch_shape,
                        seed=swap_seed),
                    dtype=tf.int32)

            null_swaps = bu.left_justified_expand_dims_like(
                tf.range(num_replica, dtype=swaps.dtype), swaps)
            swaps = _maybe_embed_swaps_validation(swaps, null_swaps,
                                                  self.validate_args)

            # Un-temper the log probs for use in the swap acceptance ratio.
            if self.tempered_log_prob_fn is None:
                # Efficient way of re-evaluating target_log_prob_fn on the
                # pre_swap_replica_states.
                untempered_negative_energy_ignoring_ulp = (
                    # Since untempered_log_prob_fn is None, we may assume
                    # inverse_temperatures > 0 (else the target is improper).
                    pre_swap_replica_target_log_prob / inverse_temperatures)
            else:
                # The untempered_log_prob_fn does not factor into the acceptance ratio.
                # Proof: Suppose the tempered target is
                #   p_k(x) = f(x)^{beta_k} g(x),
                # So f(x) is tempered, and g(x) is not.  Then, the acceptance ratio for
                # a 1 <--> 2 swap is...
                #   (p_1(x_2) p_2(x_1)) / (p_1(x_1) p_2(x_2))
                # which depends only on f(x), since terms involving g(x) cancel.
                untempered_negative_energy_ignoring_ulp = self.tempered_log_prob_fn(
                    *pre_swap_replica_states)

            # Since `swaps` is its own inverse permutation we automatically know the
            # swap counterpart: range(num_replica). We use this idea to compute the
            # acceptance in a vectorized manner at the cost of wasting roughly half
            # our computation. Although we could use `unique` to solve this problem,
            # we expect the cost of `unique` to be higher than the dozens of wasted
            # arithmetic calculations. Worse, it'd mean we need dynamic sized Tensors
            # (eg, using `tf.where(bool)`) and so we wouldn't be able to XLA compile.

            # Note: diffs would normally be "proposed - current" however energy is
            # flipped since `energy == -log_prob`.
            # Note: The untempered_log_prob_fn (if provided) is not included in
            # untempered_pre_swap_replica_target_log_prob, and hence does not factor
            # into energy_diff. Why? Because, it cancels out in the acceptance ratio.
            energy_diff = (untempered_negative_energy_ignoring_ulp -
                           mcmc_util.index_remapping_gather(
                               untempered_negative_energy_ignoring_ulp,
                               swaps,
                               name='gather_swap_tlp'))
            swapped_inverse_temperatures = mcmc_util.index_remapping_gather(
                inverse_temperatures, swaps, name='gather_swap_temps')
            inverse_temp_diff = swapped_inverse_temperatures - inverse_temperatures

            # If replicas i and j are swapping, log_accept_ratio[i] and
            # log_accept_ratio[j] are equal.
            log_accept_ratio = (energy_diff * bu.left_justified_expand_dims_to(
                inverse_temp_diff, replica_and_batch_rank))

            log_accept_ratio = tf.where(tf.math.is_finite(log_accept_ratio),
                                        log_accept_ratio,
                                        tf.constant(-np.inf, dtype=dtype))

            # Produce log[Uniform] draws that are identical at swapped indices.
            log_uniform = tf.math.log(
                samplers.uniform(shape=replica_and_batch_shape,
                                 dtype=dtype,
                                 seed=logu_seed))
            anchor_swaps = tf.minimum(swaps, null_swaps)
            log_uniform = mcmc_util.index_remapping_gather(
                log_uniform, anchor_swaps)

            is_swap_accepted_mask = tf.less(log_uniform,
                                            log_accept_ratio,
                                            name='is_swap_accepted_mask')

            def _swap_tensor(x):
                return mcmc_util.choose(
                    is_swap_accepted_mask,
                    mcmc_util.index_remapping_gather(x, swaps), x)

            post_swap_replica_states = [
                _swap_tensor(s) for s in pre_swap_replica_states
            ]

            expanded_null_swaps = bu.left_justified_broadcast_to(
                null_swaps, replica_and_batch_shape)
            is_swap_proposed = _compute_swap_notmatrix(
                # Broadcast both so they have shape [num_replica] + batch_shape.
                # This (i) makes them have same shape as is_swap_accepted, and
                # (ii) keeps shape consistent if someday swaps has a batch shape.
                expanded_null_swaps,
                bu.left_justified_broadcast_to(swaps, replica_and_batch_shape))

            # To get is_swap_accepted in ordered position, we use
            # _compute_swap_notmatrix on current and next replica positions.
            post_swap_replica_position = _swap_tensor(expanded_null_swaps)

            is_swap_accepted = _compute_swap_notmatrix(
                post_swap_replica_position, expanded_null_swaps)

            if self._state_includes_replicas:
                post_swap_states = post_swap_replica_states
            else:
                post_swap_states = [s[0] for s in post_swap_replica_states]

            post_swap_replica_results = _set_swapped_fields_to_nan(
                _swap_log_prob_and_maybe_grads(pre_swap_replica_results,
                                               post_swap_replica_states,
                                               inner_kernel))

            if mcmc_util.is_list_like(current_state):
                # We *always* canonicalize the states in the kernel results.
                states = post_swap_states
            else:
                states = post_swap_states[0]

            post_swap_kernel_results = ReplicaExchangeMCKernelResults(
                post_swap_replica_states=post_swap_replica_states,
                pre_swap_replica_results=pre_swap_replica_results,
                post_swap_replica_results=post_swap_replica_results,
                is_swap_proposed=is_swap_proposed,
                is_swap_accepted=is_swap_accepted,
                is_swap_proposed_adjacent=_sub_diag(is_swap_proposed),
                is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
                # Store the original pkr.inverse_temperatures in case it's a
                # `tf.Variable`.
                inverse_temperatures=previous_kernel_results.
                inverse_temperatures,
                swaps=swaps,
                step_count=previous_kernel_results.step_count + 1,
                seed=seed,
                potential_energy=-untempered_negative_energy_ignoring_ulp,
            )

            return states, post_swap_kernel_results
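
A minimal usage sketch (not part of the snippet above) showing how a `ReplicaExchangeMC` kernel built around this `one_step` is typically driven via `tfp.mcmc.sample_chain`. The target, inverse temperatures, and inner HMC settings are illustrative assumptions, and a recent TFP release where `sample_chain` accepts a `seed` is assumed.

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Toy target; the inverse temperatures and HMC settings are arbitrary choices.
target = tfd.Normal(loc=0., scale=1.)

remc = tfp.mcmc.ReplicaExchangeMC(
    target_log_prob_fn=target.log_prob,
    inverse_temperatures=[1.0, 0.5, 0.25],
    make_kernel_fn=lambda tlp: tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=tlp, step_size=0.5, num_leapfrog_steps=3))

samples = tfp.mcmc.sample_chain(
    num_results=100,
    current_state=tf.zeros([]),
    kernel=remc,
    num_burnin_steps=50,
    trace_fn=None,
    seed=[1, 2])
```
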
Example #2
    def one_step(self, current_state, previous_kernel_results):
        with tf.name_scope(self.name + '.one_step'):
            unwrap_state_list = not tf.nest.is_nested(current_state)
            if unwrap_state_list:
                current_state = [current_state]

            current_target_log_prob = previous_kernel_results.target_log_prob
            [init_momentum, init_energy, log_slice_sample
             ] = self._start_trajectory_batched(current_state,
                                                current_target_log_prob)

            def _copy(v):
                return v * prefer_static.ones(prefer_static.pad(
                    [2],
                    paddings=[[0, prefer_static.rank(v)]],
                    constant_values=1),
                                              dtype=v.dtype)

            initial_state = TreeDoublingState(
                momentum=init_momentum,
                state=current_state,
                target=current_target_log_prob,
                target_grad_parts=previous_kernel_results.grads_target_log_prob
            )
            initial_step_state = tf.nest.map_structure(_copy, initial_state)

            if MULTINOMIAL_SAMPLE:
                init_weight = tf.zeros_like(init_energy)
            else:
                init_weight = tf.ones_like(init_energy, dtype=TREE_COUNT_DTYPE)

            candidate_state = TreeDoublingStateCandidate(
                state=current_state,
                target=current_target_log_prob,
                target_grad_parts=previous_kernel_results.
                grads_target_log_prob,
                energy=init_energy,
                weight=init_weight)

            initial_step_metastate = TreeDoublingMetaState(
                candidate_state=candidate_state,
                is_accepted=tf.zeros_like(init_energy, dtype=tf.bool),
                energy_diff_sum=tf.zeros_like(init_energy),
                leapfrog_count=tf.zeros_like(init_energy,
                                             dtype=TREE_COUNT_DTYPE),
                continue_tree=tf.ones_like(init_energy, dtype=tf.bool),
                not_divergence=tf.ones_like(init_energy, dtype=tf.bool))

            # Convert the write/read instruction into TensorArray so that it is
            # compatible with XLA.
            write_instruction = tf.TensorArray(
                TREE_COUNT_DTYPE,
                size=2**(self.max_tree_depth - 1),
                clear_after_read=False).unstack(self._write_instruction)
            read_instruction = tf.TensorArray(
                tf.int32,
                size=2**(self.max_tree_depth - 1),
                clear_after_read=False).unstack(self._read_instruction)

            current_step_meta_info = OneStepMetaInfo(
                log_slice_sample=log_slice_sample,
                init_energy=init_energy,
                write_instruction=write_instruction,
                read_instruction=read_instruction)

            _, _, new_step_metastate = tf.while_loop(
                cond=lambda iter_, state, metastate: (  # pylint: disable=g-long-lambda
                    ((iter_ < self.max_tree_depth) & tf.reduce_any(
                        metastate.continue_tree))),
                body=lambda iter_, state, metastate: self.loop_tree_doubling(  # pylint: disable=g-long-lambda
                    previous_kernel_results.step_size, previous_kernel_results.
                    momentum_state_memory, current_step_meta_info, iter_,
                    state, metastate),
                loop_vars=(tf.zeros([], dtype=tf.int32, name='iter'),
                           initial_step_state, initial_step_metastate),
                parallel_iterations=TF_WHILE_PARALLEL_ITERATIONS,
            )

            kernel_results = NUTSKernelResults(
                target_log_prob=new_step_metastate.candidate_state.target,
                grads_target_log_prob=(
                    new_step_metastate.candidate_state.target_grad_parts),
                momentum_state_memory=previous_kernel_results.
                momentum_state_memory,
                step_size=previous_kernel_results.step_size,
                log_accept_ratio=tf.math.log(
                    new_step_metastate.energy_diff_sum /
                    tf.cast(new_step_metastate.leapfrog_count,
                            dtype=new_step_metastate.energy_diff_sum.dtype)),
                # TODO(junpenglao): return non-cumulated leapfrogs_taken once
                # benchmarking is done.
                leapfrogs_taken=(previous_kernel_results.leapfrogs_taken +
                                 new_step_metastate.leapfrog_count *
                                 self.unrolled_leapfrog_steps),
                is_accepted=new_step_metastate.is_accepted,
                reach_max_depth=new_step_metastate.continue_tree,
                has_divergence=~new_step_metastate.not_divergence,
                energy=new_step_metastate.candidate_state.energy)

            result_state = new_step_metastate.candidate_state.state
            if unwrap_state_list:
                result_state = result_state[0]

            return result_state, kernel_results
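
For context, this `one_step` is normally not called directly. Below is a hedged sketch of driving the corresponding public kernel; `tfp.mcmc.NoUTurnSampler` is assumed to wrap this implementation, and the target, step size, and chain lengths are illustrative.

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Illustrative 3-dimensional target; step size and chain lengths are arbitrary.
target = tfd.MultivariateNormalDiag(loc=tf.zeros(3))

nuts = tfp.mcmc.NoUTurnSampler(
    target_log_prob_fn=target.log_prob, step_size=0.3)

states, is_accepted = tfp.mcmc.sample_chain(
    num_results=200,
    current_state=tf.zeros(3),
    kernel=nuts,
    num_burnin_steps=100,
    trace_fn=lambda _, pkr: pkr.is_accepted)
```
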
Example #3
    def _build_sub_tree(self,
                        directions,
                        integrator,
                        current_step_meta_info,
                        nsteps,
                        initial_state,
                        continue_tree,
                        not_divergence,
                        momentum_state_memory,
                        name=None):
        with tf.name_scope('build_sub_tree'):
            batch_shape = prefer_static.shape(
                current_step_meta_info.init_energy)
            # We never want to select the initial state.
            if MULTINOMIAL_SAMPLE:
                init_weight = tf.fill(
                    batch_shape,
                    tf.constant(
                        -np.inf,
                        dtype=current_step_meta_info.init_energy.dtype))
            else:
                init_weight = tf.zeros(batch_shape, dtype=TREE_COUNT_DTYPE)

            init_momentum_cumsum = [
                tf.zeros_like(x) for x in initial_state.momentum
            ]
            initial_state_candidate = TreeDoublingStateCandidate(
                state=initial_state.state,
                target=initial_state.target,
                target_grad_parts=initial_state.target_grad_parts,
                energy=initial_state.target,
                weight=init_weight)
            energy_diff_sum = tf.zeros_like(current_step_meta_info.init_energy,
                                            name='energy_diff_sum')
            [
                _,
                energy_diff_tree_sum,
                momentum_tree_cumsum,
                leapfrogs_taken,
                final_state,
                candidate_tree_state,
                final_continue_tree,
                final_not_divergence,
                momentum_state_memory,
            ] = tf.while_loop(
                cond=lambda iter_, energy_diff_sum, init_momentum_cumsum,  # pylint: disable=g-long-lambda
                leapfrogs_taken, state, state_c, continue_tree, not_divergence,
                momentum_state_memory: (
                    (iter_ < nsteps) & tf.reduce_any(continue_tree)),
                body=lambda iter_, energy_diff_sum, init_momentum_cumsum,  # pylint: disable=g-long-lambda
                leapfrogs_taken, state, state_c, continue_tree, not_divergence,
                momentum_state_memory: (self._loop_build_sub_tree(
                    directions, integrator, current_step_meta_info, iter_,
                    energy_diff_sum, init_momentum_cumsum, leapfrogs_taken,
                    state, state_c, continue_tree, not_divergence,
                    momentum_state_memory)),
                loop_vars=(
                    tf.zeros([], dtype=tf.int32, name='iter'),
                    energy_diff_sum,
                    init_momentum_cumsum,
                    tf.zeros(batch_shape, dtype=TREE_COUNT_DTYPE),
                    initial_state,
                    initial_state_candidate,
                    continue_tree,
                    not_divergence,
                    momentum_state_memory,
                ),
                parallel_iterations=TF_WHILE_PARALLEL_ITERATIONS,
            )

        return (
            candidate_tree_state,
            final_state,
            final_not_divergence,
            final_continue_tree,
            energy_diff_tree_sum,
            momentum_tree_cumsum,
            leapfrogs_taken,
        )
Example #4
    def __init__(self,
                 loc,
                 scale,
                 validate_args=False,
                 allow_nan_stats=True,
                 name='Moyal'):
        """Construct Moyal distributions with location and scale `loc` and `scale`.

    The parameters `loc` and `scale` must be shaped in a way that supports
    broadcasting (e.g. `loc + scale` is a valid operation).

    Args:
      loc: Floating point tensor, the means of the distribution(s).
      scale: Floating point tensor, the scales of the distribution(s).
        scale must contain only positive values.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
        Default value: `False`.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value `NaN` to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
        Default value: `True`.
      name: Python `str` name prefixed to Ops created by this class.
        Default value: `'Moyal'`.

    Raises:
      TypeError: if `loc` and `scale` have different dtypes.


    #### References

    [1] J.E. Moyal, "XXX. Theory of ionization fluctuations",
       The London, Edinburgh, and Dublin Philosophical Magazine
       and Journal of Science.
       https://www.tandfonline.com/doi/abs/10.1080/14786440308521076
    [2] G. Cordeiro, J. Nobre, R. Pescim, E. Ortega,
        "The beta Moyal: a useful skew distribution",
        https://www.arpapress.com/Volumes/Vol10Issue2/IJRRAS_10_2_02.pdf
    """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([loc, scale],
                                            dtype_hint=tf.float32)
            loc = tensor_util.convert_nonref_to_tensor(loc,
                                                       name='loc',
                                                       dtype=dtype)
            scale = tensor_util.convert_nonref_to_tensor(scale,
                                                         name='scale',
                                                         dtype=dtype)
            dtype_util.assert_same_float_dtype([loc, scale])
            # Positive scale is asserted by the incorporated Moyal bijector.
            self._moyal_bijector = moyal_cdf_bijector.MoyalCDF(
                loc=loc, scale=scale, validate_args=validate_args)

            # Because the uniform sampler generates samples in `[0, 1)`, samples
            # would lie in `[-inf, inf)` instead of `(-inf, inf)`. To fix this, we
            # use `np.finfo(dtype_util.as_numpy_dtype(dtype)).tiny` as the lower
            # bound because it is the smallest, positive, 'normal' number.
            super(Moyal, self).__init__(
                # TODO(b/137665504): Use batch-adding meta-distribution to set the
                # batch shape instead of tf.ones.
                distribution=uniform.Uniform(low=np.finfo(
                    dtype_util.as_numpy_dtype(dtype)).tiny,
                                             high=tf.ones([], dtype=dtype),
                                             allow_nan_stats=allow_nan_stats),
                # The Moyal bijector encodes the CDF function as the forward,
                # and hence needs to be inverted.
                bijector=invert_bijector.Invert(self._moyal_bijector,
                                                validate_args=validate_args),
                parameters=parameters,
                name=name)
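
A short usage example of the distribution constructed above, assuming it is exposed as `tfp.distributions.Moyal`; the parameter values are illustrative.

```python
import tensorflow_probability as tfp

tfd = tfp.distributions

# A broadcasting batch of two Moyal distributions.
dist = tfd.Moyal(loc=[0., 1.], scale=[1., 2.])
x = dist.sample(5, seed=42)   # shape [5, 2]
lp = dist.log_prob(x)         # shape [5, 2]
```
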
Example #5
def sample_lkj(num_samples,
               dimension,
               concentration,
               cholesky_space=False,
               seed=None,
               name=None):
    """Returns a Tensor of samples from an LKJ distribution.

  Args:
    num_samples: Python `int`. The number of samples to draw.
    dimension: Python `int`. The dimension of correlation matrices.
    concentration: `Tensor` representing the concentration of the LKJ
      distribution.
    cholesky_space: Python `bool`. Whether to take samples from LKJ or
      Chol(LKJ).
    seed: Python integer seed for RNG
    name: Python `str` name prefixed to Ops created by this function.

  Returns:
    samples: A Tensor of correlation matrices (or Cholesky factors of
      correlation matrices if `cholesky_space = True`) with shape
      `[n] + B + [D, D]`, where `B` is the shape of the `concentration`
      parameter, and `D` is the `dimension`.

  Raises:
    ValueError: If `dimension` is negative.
  """
    if dimension < 0:
        raise ValueError(
            'Cannot sample negative-dimension correlation matrices.')
    # Notation below: B is the batch shape, i.e., tf.shape(concentration)

    # We need 1 seed for beta corr12, and 2 per loop iter.
    num_seeds = 1 + 2 * max(0, dimension - 2)
    seeds = list(samplers.split_seed(seed, n=num_seeds, salt='sample_lkj'))
    with tf.name_scope(name or 'sample_lkj'):
        concentration = tf.convert_to_tensor(concentration)
        if not dtype_util.is_floating(concentration.dtype):
            raise TypeError(
                'The concentration argument should have floating type, not '
                '{}'.format(dtype_util.name(concentration.dtype)))

        concentration = _replicate(num_samples, concentration)
        concentration_shape = ps.shape(concentration)
        if dimension <= 1:
            # For any dimension <= 1, there is only one possible correlation matrix.
            shape = ps.concat([concentration_shape, [dimension, dimension]],
                              axis=0)
            return tf.ones(shape=shape, dtype=concentration.dtype)
        beta_conc = concentration + (dimension - 2.) / 2.
        beta_dist = beta.Beta(concentration1=beta_conc,
                              concentration0=beta_conc)

        # Note that the sampler below deviates from [1], by doing the sampling in
        # cholesky space. This does not change the fundamental logic of the
        # sampler, but does speed up the sampling.

        # This is the correlation coefficient between the first two dimensions.
        # This is also `r` in reference [1].
        corr12 = 2. * beta_dist.sample(seed=seeds.pop()) - 1.

        # Below we construct the Cholesky of the initial 2x2 correlation matrix,
        # which is of the form:
        # [[1, 0], [r, sqrt(1 - r**2)]], where r is the correlation between the
        # first two dimensions.
        # This is the top-left corner of the cholesky of the final sample.
        first_row = tf.concat([
            tf.ones_like(corr12)[..., tf.newaxis],
            tf.zeros_like(corr12)[..., tf.newaxis]
        ],
                              axis=-1)
        second_row = tf.concat(
            [corr12[..., tf.newaxis],
             tf.sqrt(1 - corr12**2)[..., tf.newaxis]],
            axis=-1)

        chol_result = tf.concat(
            [first_row[..., tf.newaxis, :], second_row[..., tf.newaxis, :]],
            axis=-2)

        for n in range(2, dimension):
            # Loop invariant: on entry, result has shape B + [n, n]
            beta_conc = beta_conc - 0.5
            # norm is y in reference [1].
            norm = beta.Beta(concentration1=n / 2.,
                             concentration0=beta_conc).sample(seed=seeds.pop())
            # distance shape: B + [1] for broadcast
            distance = tf.sqrt(norm)[..., tf.newaxis]
            # direction is u in reference [1].
            # direction shape: B + [n]
            direction = _uniform_unit_norm(n,
                                           concentration_shape,
                                           concentration.dtype,
                                           seed=seeds.pop())
            # raw_correlation is w in reference [1].
            raw_correlation = distance * direction  # shape: B + [n]

            # This is the next row in the cholesky of the result,
            # which differs from the construction in reference [1].
            # In the reference, the new row `z` = chol_result @ raw_correlation^T
            # = C @ raw_correlation^T (where as short hand we use C = chol_result).
            # We prove that the below equation is the right row to add to the
            # cholesky, by showing equality with reference [1].
            # Let S be the sample constructed so far, and let `z` be as in
            # reference [1]. Then at this iteration, the new sample S' will be
            # [[S z^T]
            #  [z 1]]
            # In our case we have the cholesky decomposition factor C, so
            # we want our new row x (same size as z) to satisfy:
            #  [[S z^T]  [[C 0]    [[C^T  x^T]         [[CC^T  Cx^T]
            #   [z 1]] =  [x k]]    [0     k]]  =       [xC^T   xx^T + k**2]]
            # Since C @ raw_correlation^T = z = C @ x^T, and C is invertible,
            # we have that x = raw_correlation. Also 1 = xx^T + k**2, so k
            # = sqrt(1 - xx^T) = sqrt(1 - |raw_correlation|**2) = sqrt(1 -
            # distance**2).
            new_row = tf.concat(
                [raw_correlation,
                 tf.sqrt(1. - norm[..., tf.newaxis])],
                axis=-1)

            # Finally add this new row, by growing the cholesky of the result.
            chol_result = tf.concat([
                chol_result,
                tf.zeros_like(chol_result[..., 0][..., tf.newaxis])
            ],
                                    axis=-1)

            chol_result = tf.concat([chol_result, new_row[..., tf.newaxis, :]],
                                    axis=-2)

        assert not seeds, 'Did not use all seeds: ' + str(len(seeds))
        if cholesky_space:
            return chol_result

        result = tf.matmul(chol_result, chol_result, transpose_b=True)
        # The diagonal for a correlation matrix should always be ones. Due to
        # numerical instability the matmul might not achieve that, so manually set
        # these to ones.
        result = tf.linalg.set_diag(
            result, tf.ones(shape=ps.shape(result)[:-1], dtype=result.dtype))
        # This sampling algorithm can produce near-PSD matrices on which standard
        # algorithms such as `tf.cholesky` or `tf.linalg.self_adjoint_eigvals`
        # fail. Specifically, as documented in b/116828694, around 2% of trials
        # of 900,000 5x5 matrices (distributed according to 9 different
        # concentration parameter values) contained at least one matrix on which
        # the Cholesky decomposition failed.
        return result
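
A hedged usage sketch of this sampler through the public distributions assumed to wrap it (`tfp.distributions.LKJ` and `tfp.distributions.CholeskyLKJ`); dimension and concentration values are illustrative.

```python
import tensorflow_probability as tfp

tfd = tfp.distributions

# 4x4 correlation matrices; larger concentration favors near-identity samples.
lkj = tfd.LKJ(dimension=4, concentration=2.)
corr = lkj.sample(3, seed=7)        # shape [3, 4, 4], unit diagonal

# Cholesky-space variant, which sidesteps the near-PSD issues noted above.
chol_lkj = tfd.CholeskyLKJ(dimension=4, concentration=2.)
chol = chol_lkj.sample(3, seed=7)   # lower-triangular factors
```
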
    def __init__(self,
                 distribution,
                 bijector,
                 batch_shape=None,
                 event_shape=None,
                 kwargs_split_fn=_default_kwargs_split_fn,
                 validate_args=False,
                 parameters=None,
                 name=None):
        """Construct a Transformed Distribution.

    Args:
      distribution: The base distribution instance to transform. Typically an
        instance of `Distribution`.
      bijector: The object responsible for calculating the transformation.
        Typically an instance of `Bijector`.
      batch_shape: `integer` vector `Tensor` which overrides `distribution`
        `batch_shape`; valid only if `distribution.is_scalar_batch()`.
      event_shape: `integer` vector `Tensor` which overrides `distribution`
        `event_shape`; valid only if `distribution.is_scalar_event()`.
      kwargs_split_fn: Python `callable` which takes a kwargs `dict` and returns
        a tuple of kwargs `dict`s for each of the `distribution` and `bijector`
        parameters respectively.
        Default value: `_default_kwargs_split_fn` (i.e.,
            `lambda kwargs: (kwargs.get('distribution_kwargs', {}),
                             kwargs.get('bijector_kwargs', {}))`)
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      parameters: Locals dict captured by subclass constructor, to be used for
        copy/slice re-instantiation operations.
      name: Python `str` name prefixed to Ops created by this class. Default:
        `bijector.name + distribution.name`.
    """
        parameters = dict(locals()) if parameters is None else parameters
        name = name or (("" if bijector is None else bijector.name) +
                        (distribution.name or ""))
        with tf.name_scope(name) as name:
            self._kwargs_split_fn = (_default_kwargs_split_fn
                                     if kwargs_split_fn is None else
                                     kwargs_split_fn)
            # For convenience we define some handy constants.
            self._zero = tf.constant(0, dtype=tf.int32, name="zero")
            self._empty = tf.constant([], dtype=tf.int32, name="empty")

            # We will keep track of a static and dynamic version of
            # self._is_{batch,event}_override. This way we can do more prior to graph
            # execution, including possibly raising Python exceptions.

            self._override_batch_shape = self._maybe_validate_shape_override(
                batch_shape, distribution.is_scalar_batch(), validate_args,
                "batch_shape")
            self._is_batch_override = prefer_static.logical_not(
                prefer_static.equal(
                    prefer_static.rank_from_shape(self._override_batch_shape),
                    self._zero))
            self._is_maybe_batch_override = bool(
                tf.get_static_value(self._override_batch_shape) is None
                or tf.get_static_value(self._override_batch_shape).size != 0)

            self._override_event_shape = self._maybe_validate_shape_override(
                event_shape, distribution.is_scalar_event(), validate_args,
                "event_shape")
            self._is_event_override = prefer_static.logical_not(
                prefer_static.equal(
                    prefer_static.rank_from_shape(self._override_event_shape),
                    self._zero))
            self._is_maybe_event_override = bool(
                tf.get_static_value(self._override_event_shape) is None
                or tf.get_static_value(self._override_event_shape).size != 0)

            # To convert a scalar distribution into a multivariate distribution we
            # will draw dims from the sample dims, which are otherwise iid. This is
            # easy to do except in the case that the base distribution has batch dims
            # and we're overriding event shape. When that case happens the event dims
            # will incorrectly be to the left of the batch dims. In this case we'll
            # cyclically permute left the new dims.
            self._needs_rotation = prefer_static.reduce_all([
                self._is_event_override,
                prefer_static.logical_not(self._is_batch_override),
                prefer_static.logical_not(distribution.is_scalar_batch())
            ])
            override_event_ndims = prefer_static.rank_from_shape(
                self._override_event_shape)
            self._rotate_ndims = _pick_scalar_condition(
                self._needs_rotation, override_event_ndims, 0)
            # We'll be reducing the head dims (if at all), i.e., this will be []
            # if we don't need to reduce.
            self._reduce_event_indices = tf.range(
                self._rotate_ndims - override_event_ndims, self._rotate_ndims)

        self._distribution = distribution
        self._bijector = bijector
        super(TransformedDistribution, self).__init__(
            dtype=self._distribution.dtype,
            reparameterization_type=self._distribution.reparameterization_type,
            validate_args=validate_args,
            allow_nan_stats=self._distribution.allow_nan_stats,
            parameters=parameters,
            # We let TransformedDistribution access _graph_parents since this class
            # is more like a baseclass than derived.
            graph_parents=(
                distribution._graph_parents +  # pylint: disable=protected-access
                bijector.graph_parents),
            name=name)
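
A minimal example of the constructor above in use; the log-normal-via-`Exp` construction is a standard illustration and makes no use of the batch/event shape overrides.

```python
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

# Exp-transformed standard normal, i.e. a log-normal distribution.
log_normal = tfd.TransformedDistribution(
    distribution=tfd.Normal(loc=0., scale=1.),
    bijector=tfb.Exp())

y = log_normal.sample(4, seed=1)
lp = log_normal.log_prob(y)
```
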
Example #7
  def _z(self, x):
    """Standardize input `x`."""
    with tf.name_scope("standardize"):
      return (x - self.loc) / self.scale
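
As a quick numeric check of the standardization above (with assumed, illustrative parameters): for `loc = 1.` and `scale = 2.`, `_z(5.)` evaluates to `(5. - 1.) / 2. = 2.`.
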
Example #8
def maybe_update_along_axis(*,
                            tensor,
                            new_tensor,
                            axis,
                            ind,
                            do_update,
                            dtype=None,
                            name=None):
    """Replace `tensor` entries with `new_tensor` along a given axis.

  This updates elements of `tensor` that correspond to the elements returned by
  `numpy.take(tensor, ind, axis)` with the corresponding elements of
  `new_tensor`.

  # Example
  ```python
  tensor = tf.ones([5, 4, 3, 2])
  new_tensor = tf.zeros([5, 4, 3, 2])
  updated_tensor = maybe_update_along_axis(tensor=tensor,
                                           new_tensor=new_tensor,
                                           axis=1,
                                           ind=2,
                                           do_update=True)
  # Returns a `Tensor` of ones where
  # `updated_tensor[:, 2, :, :].numpy() == 0`
  ```
  If `do_update` is `False`, the update does not happen unless the size of
  `tensor` along `axis` is statically known to be 1, in which case `new_tensor`
  is returned. This functionality is useful when, for example, aggregating
  samples of an Ito process.

  Args:
    tensor: A `Tensor` of any shape and `dtype`.
    new_tensor: A `Tensor` of the same `dtype` as `tensor` and of shape
      broadcastable with `tensor`.
    axis: A Python integer. The axis of `tensor` along which the elements have
      to be updated.
    ind: An int32 scalar `Tensor` that denotes an index on the `axis` which
      defines the updated slice of `tensor` (see example above).
    do_update: A bool scalar `Tensor`. If `False`, the output is the same as
      `tensor`, unless  the dimension of the `tensor` along the `axis` is equal
      to 1.
    dtype: The `dtype` of the input `Tensor`s.
      Default value: `None` which means that default dtypes inferred by
        TensorFlow are used.
    name: Python string. The name to give this op.
      Default value: `None` which maps to `maybe_update_along_axis`.

  Returns:
    A `Tensor` of the same shape and `dtype` as `tensor`.
  """
    name = name or 'maybe_update_along_axis'
    with tf.name_scope(name):
        tensor = tf.convert_to_tensor(tensor, dtype=dtype, name='tensor')
        dtype = tensor.dtype
        new_tensor = tf.convert_to_tensor(new_tensor,
                                          dtype=dtype,
                                          name='new_tensor')
        ind = tf.convert_to_tensor(ind, name='ind')
        do_update = tf.convert_to_tensor(do_update, name='do_update')
        size_along_axis = tensor.shape.as_list()[axis]

        def _write_update_to_result():
            size_along_axis_dynamic = tf.shape(tensor)[axis]
            one_hot = tf.one_hot(ind, depth=size_along_axis_dynamic)
            mask_size = tensor.shape.rank
            mask_shape = tf.pad([size_along_axis_dynamic],
                                paddings=[[axis, mask_size - axis - 1]],
                                constant_values=1)
            mask = tf.reshape(one_hot > 0, mask_shape)
            return tf.where(mask, new_tensor, tensor)

        # Update only if size_along_axis > 1 or if the shape is dynamic
        if size_along_axis is None or size_along_axis > 1:
            return tf.cond(do_update, _write_update_to_result, lambda: tensor)
        else:
            return new_tensor
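
A hedged sketch of the two branches described in the docstring; the shapes and values are illustrative, and `maybe_update_along_axis` refers to the function defined above.

```python
import tensorflow as tf

tensor = tf.ones([5, 4, 3, 2])
new_tensor = tf.zeros([5, 4, 3, 2])

# Size along `axis` is greater than 1, so `do_update=False` leaves the
# input unchanged.
unchanged = maybe_update_along_axis(
    tensor=tensor, new_tensor=new_tensor, axis=1, ind=2, do_update=False)

# Size along `axis` equals 1, so `new_tensor` is returned regardless of
# `do_update`.
replaced = maybe_update_along_axis(
    tensor=tf.ones([5, 1, 3, 2]), new_tensor=tf.zeros([5, 1, 3, 2]),
    axis=1, ind=0, do_update=False)
```
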
Example #9
def generate_mc_normal_draws(num_normal_draws,
                             num_time_steps,
                             num_sample_paths,
                             random_type,
                             skip=0,
                             seed=None,
                             dtype=None,
                             name=None):
    """Generates normal random samples to be consumed by a Monte Carlo algorithm.

  Many Monte Carlo (MC) algorithms can be rewritten so that all necessary
  random (or quasi-random) variables are drawn in advance as a `Tensor` of
  shape `[num_time_steps, num_sample_paths, num_normal_draws]`, where
  `num_time_steps` is the number of time steps the Monte Carlo algorithm
  performs, `num_sample_paths` is the number of sample paths of the Monte
  Carlo algorithm, and `num_normal_draws` is the number of independent normal
  draws per sample path.
  For example, in order to use quasi-random numbers in a Monte Carlo algorithm,
  the samples have to be drawn in advance.
  The function generates a `Tensor` `x` in a format such that, for a
  quasi-random `random_type`, `x[i]` corresponds to different dimensions of the
  quasi-random sequence, so that it can be used in a Monte Carlo algorithm.

  Args:
    num_normal_draws: A scalar int32 `Tensor`. The number of independent normal
      draws at each time step for each sample path. Should be a graph
      compilation constant.
    num_time_steps: A scalar int32 `Tensor`. The number of time steps at which
      to draw the independent normal samples. Should be a graph compilation
      constant.
    num_sample_paths: A scalar int32 `Tensor`. The number of trajectories (e.g.,
      Monte Carlo paths) for which to draw the independent normal samples.
      Should be a graph compilation constant.
    random_type: Enum value of `tff.math.random.RandomType`. The type of
      (quasi)-random number generator to use to generate the paths.
    skip: `int32` 0-d `Tensor`. The number of initial points of the Sobol or
      Halton sequence to skip. Used only when `random_type` is 'SOBOL',
      'HALTON', or 'HALTON_RANDOMIZED', otherwise ignored.
      Default value: `0`.
    seed: Seed for the random number generator. The seed is only relevant if
      `random_type` is one of
      `[STATELESS, PSEUDO, HALTON_RANDOMIZED, PSEUDO_ANTITHETIC,
      STATELESS_ANTITHETIC]`. For `PSEUDO`, `PSEUDO_ANTITHETIC` and
      `HALTON_RANDOMIZED` the seed should be a Python integer. For `STATELESS`
      and `STATELESS_ANTITHETIC` it must be supplied as an integer `Tensor` of
      shape `[2]`.
      Default value: `None` which means no seed is set.
    dtype: The `dtype` of the output `Tensor`.
      Default value: `None` which maps to `float32`.
    name: Python string. The name to give this op.
      Default value: `None` which maps to `generate_mc_normal_draws`.

  Returns:
   A `Tensor` of shape `[num_time_steps, num_sample_paths, num_normal_draws]`.
  """
    if name is None:
        name = 'generate_mc_normal_draws'
    if skip is None:
        skip = 0
    with tf.name_scope(name):
        if dtype is None:
            dtype = tf.float32
        # In case of quasi-random draws, the total dimension of the draws
        # should be `num_time_steps * num_normal_draws`.
        total_dimension = tf.zeros([num_time_steps * num_normal_draws],
                                   dtype=dtype,
                                   name='total_dimension')
        normal_draws = random.mv_normal_sample([num_sample_paths],
                                               mean=total_dimension,
                                               random_type=random_type,
                                               seed=seed,
                                               skip=skip)
        # Reshape and transpose
        normal_draws = tf.reshape(
            normal_draws, [num_sample_paths, num_time_steps, num_normal_draws])
        # Shape [num_time_steps, num_sample_paths, num_normal_draws]
        normal_draws = tf.transpose(normal_draws, [1, 0, 2])
        return normal_draws
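
A hedged usage sketch of the function above; the `tf_quant_finance` import path and the Sobol settings are assumptions for illustration.

```python
import tensorflow as tf
import tf_quant_finance as tff

# Quasi-random (Sobol) draws of shape [num_time_steps, num_sample_paths, dim].
draws = generate_mc_normal_draws(
    num_normal_draws=2,
    num_time_steps=10,
    num_sample_paths=1000,
    random_type=tff.math.random.RandomType.SOBOL,
    skip=16,
    dtype=tf.float64)
# draws.shape == [10, 1000, 2]
```
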
Example #10
    def __init__(self,
                 df,
                 scale_operator,
                 input_output_cholesky=False,
                 validate_args=False,
                 allow_nan_stats=True,
                 name=None):
        """Construct Wishart distributions.

    Args:
      df: `float` or `double` tensor, the degrees of freedom of the
        distribution(s). `df` must be greater than or equal to `k`.
      scale_operator: `float` or `double` instance of `LinearOperator`.
      input_output_cholesky: Python `bool`. If `True`, functions whose input or
        output have the semantics of samples assume inputs are in Cholesky form
        and return outputs in Cholesky form. In particular, if this flag is
        `True`, input to `log_prob` is presumed of Cholesky form and output from
        `sample`, `mean`, and `mode` are of Cholesky form.  Setting this
        argument to `True` is purely a computational optimization and does not
        change the underlying distribution; for instance, `mean` returns the
        Cholesky of the mean, not the mean of Cholesky factors. The `variance`
        and `stddev` methods are unaffected by this flag.
        Default value: `False` (i.e., input/output does not have Cholesky
        semantics).
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      TypeError: if `scale_operator` is not floating-type.
      TypeError: if `scale_operator.dtype != df.dtype`.
      ValueError: if `df < k`, where the `scale_operator` event shape is
        `(k, k)`.
    """
        parameters = dict(locals())
        self._input_output_cholesky = input_output_cholesky
        with tf.name_scope(name) as name:
            with tf.name_scope("init"):
                if not dtype_util.is_floating(scale_operator.dtype):
                    raise TypeError(
                        "scale_operator.dtype=%s is not a floating-point type"
                        % scale_operator.dtype)
                if not scale_operator.is_square:
                    raise ValueError("scale_operator must be square.")

                self._scale_operator = scale_operator
                self._df = tf.convert_to_tensor(df,
                                                dtype=scale_operator.dtype,
                                                name="df")
                dtype_util.assert_same_float_dtype(
                    [self._df, self._scale_operator])
                if tf.compat.dimension_value(
                        self._scale_operator.shape[-1]) is None:
                    self._dimension = tf.cast(
                        self._scale_operator.domain_dimension_tensor(),
                        dtype=self._scale_operator.dtype,
                        name="dimension")
                else:
                    self._dimension = tf.convert_to_tensor(
                        tf.compat.dimension_value(
                            self._scale_operator.shape[-1]),
                        dtype=self._scale_operator.dtype,
                        name="dimension")
                df_val = tf.get_static_value(self._df)
                dim_val = tf.get_static_value(self._dimension)
                if df_val is not None and dim_val is not None:
                    df_val = np.asarray(df_val)
                    if not df_val.shape:
                        df_val = [df_val]
                    if np.any(df_val < dim_val):
                        raise ValueError(
                            "Degrees of freedom (df = %s) cannot be less than "
                            "dimension of scale matrix (scale.dimension = %s)"
                            % (df_val, dim_val))
                elif validate_args:
                    assertions = assert_util.assert_less_equal(
                        self._dimension,
                        self._df,
                        message=("Degrees of freedom (df = %s) cannot be "
                                 "less than dimension of scale matrix "
                                 "(scale.dimension = %s)" %
                                 (self._df, self._dimension)))
                    self._df = distribution_util.with_dependencies(
                        [assertions], self._df)
        super(_WishartLinearOperator, self).__init__(
            dtype=self._scale_operator.dtype,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            reparameterization_type=reparameterization.FULLY_REPARAMETERIZED,
            parameters=parameters,
            graph_parents=([self._df, self._dimension] +
                           self._scale_operator.graph_parents),
            name=name)
Example #11
  def __init__(self,
               distributions,
               dtype_override=None,
               validate_args=False,
               allow_nan_stats=False,
               name='Blockwise'):
    """Construct the `Blockwise` distribution.

    Args:
      distributions: Python `list` of `tfp.distributions.Distribution`
        instances. All distribution instances must have the same `batch_shape`
        and all must have `event_ndims==1`, i.e., be vector-variate
        distributions.
      dtype_override: samples of `distributions` will be cast to this `dtype`.
        If unspecified, all `distributions` must have the same `dtype`.
        Default value: `None` (i.e., do not cast).
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `False`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    parameters = dict(locals())
    with tf.name_scope(name) as name:
      self._assertions = _maybe_validate_distributions(
          distributions, dtype_override, validate_args)

      if dtype_override is not None:
        dtype = dtype_override
      else:
        dtype = set(
            dtype_util.base_dtype(d.dtype)
            for d in distributions
            if d.dtype is not None)
        if len(dtype) == 0:  # pylint: disable=g-explicit-length-test
          dtype = tf.float32
        elif len(dtype) == 1:
          dtype = dtype.pop()
        else:
          # Shouldn't be here: we already threw an exception in
          # `_maybe_validate_distributions`.
          raise ValueError('Internal Error: unable to resolve `dtype`.')

      reparameterization_type = set(d.reparameterization_type
                                    for d in distributions)
      reparameterization_type = (reparameterization_type.pop()
                                 if len(reparameterization_type) == 1
                                 else reparameterization.NOT_REPARAMETERIZED)

      self._distributions = distributions
      super(Blockwise, self).__init__(
          dtype=dtype,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          reparameterization_type=reparameterization_type,
          parameters=parameters,
          graph_parents=_model_flatten(d._graph_parents for d in distributions),  # pylint: disable=protected-access
          name=name)
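
A hedged example of composing vector-variate components with the class above, assuming it is exposed as `tfp.distributions.Blockwise`; the components are illustrative.

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Concatenate a 2-dimensional normal and a 3-dimensional Dirichlet into a
# single distribution over 5-dimensional events.
block = tfd.Blockwise([
    tfd.MultivariateNormalDiag(loc=tf.zeros(2)),
    tfd.Dirichlet(concentration=tf.ones(3)),
])
x = block.sample(seed=3)   # shape [5]
lp = block.log_prob(x)     # scalar
```
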
Example #12
    def __init__(self,
                 df,
                 scale=None,
                 scale_tril=None,
                 input_output_cholesky=False,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="Wishart"):
        """Construct Wishart distributions.

    Args:
      df: `float` or `double` `Tensor`. Degrees of freedom, must be greater than
        or equal to dimension of the scale matrix.
      scale: `float` or `double` `Tensor`. The symmetric positive definite
        scale matrix of the distribution. Exactly one of `scale` and
        `scale_tril` must be passed.
      scale_tril: `float` or `double` `Tensor`. The Cholesky factorization
        of the symmetric positive definite scale matrix of the distribution.
        Exactly one of `scale` and `scale_tril` must be passed.
      input_output_cholesky: Python `bool`. If `True`, functions whose input or
        output have the semantics of samples assume inputs are in Cholesky form
        and return outputs in Cholesky form. In particular, if this flag is
        `True`, input to `log_prob` is presumed of Cholesky form and output from
        `sample`, `mean`, and `mode` are of Cholesky form.  Setting this
        argument to `True` is purely a computational optimization and does not
        change the underlying distribution; for instance, `mean` returns the
        Cholesky of the mean, not the mean of Cholesky factors. The `variance`
        and `stddev` methods are unaffected by this flag.
        Default value: `False` (i.e., input/output does not have Cholesky
        semantics).
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    Raises:
      ValueError: if neither or both of `scale` and `scale_tril` are passed in.
    """
        parameters = dict(locals())

        with tf.name_scope(name) as name:
            with tf.name_scope("init"):
                if (scale is None) == (scale_tril is None):
                    raise ValueError(
                        "Must pass scale or scale_tril, but not both.")

                dtype = dtype_util.common_dtype([df, scale, scale_tril],
                                                tf.float32)
                df = tf.convert_to_tensor(df, name="df", dtype=dtype)
                if scale is not None:
                    scale = tf.convert_to_tensor(scale,
                                                 name="scale",
                                                 dtype=dtype)
                    if validate_args:
                        scale = distribution_util.assert_symmetric(scale)
                    scale_tril = tf.linalg.cholesky(scale)
                else:  # scale_tril is not None
                    scale_tril = tf.convert_to_tensor(scale_tril,
                                                      name="scale_tril",
                                                      dtype=dtype)
                    if validate_args:
                        scale_tril = distribution_util.with_dependencies([
                            assert_util.assert_positive(
                                tf.linalg.diag_part(scale_tril),
                                message="scale_tril must be positive definite"
                            ),
                            assert_util.assert_equal(
                                tf.shape(scale_tril)[-1],
                                tf.shape(scale_tril)[-2],
                                message="scale_tril must be square")
                        ], scale_tril)

            super(Wishart, self).__init__(
                df=df,
                scale_operator=tf.linalg.LinearOperatorLowerTriangular(
                    tril=scale_tril,
                    is_non_singular=True,
                    is_positive_definite=True,
                    is_square=True),
                input_output_cholesky=input_output_cholesky,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                name=name)
        self._parameters = parameters
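A usage sketch (not from the original source); recent TFP releases expose the same Cholesky-parameterized construction as `tfd.WishartTriL`.

```python
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

# A 2x2 scale matrix supplied via its Cholesky factor; df must be >= 2 here.
scale_tril = tf.linalg.cholesky([[2., 0.5], [0.5, 1.]])
wishart = tfd.WishartTriL(df=3., scale_tril=scale_tril)
samples = wishart.sample(5)            # shape [5, 2, 2], symmetric positive definite
log_probs = wishart.log_prob(samples)  # shape [5]
```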
Example #13
    def __init__(self,
                 shift=None,
                 scale_identity_multiplier=None,
                 scale_diag=None,
                 scale_tril=None,
                 scale_perturb_factor=None,
                 scale_perturb_diag=None,
                 adjoint=False,
                 validate_args=False,
                 name="affine",
                 dtype=None):
        """Instantiates the `Affine` bijector.

    This `Bijector` is initialized with `shift` `Tensor` and `scale` arguments,
    giving the forward operation:

    ```none
    Y = g(X) = scale @ X + shift
    ```

    where the `scale` term is logically equivalent to:

    ```python
    scale = (
      scale_identity_multiplier * tf.diag(tf.ones(d)) +
      tf.diag(scale_diag) +
      scale_tril +
      scale_perturb_factor @ diag(scale_perturb_diag) @
        tf.transpose([scale_perturb_factor])
    )
    ```

    If none of `scale_identity_multiplier`, `scale_diag`, or `scale_tril` are
    specified then `scale += IdentityMatrix`. Otherwise specifying a
    `scale` argument has the semantics of `scale += Expand(arg)`, i.e.,
    `scale_diag != None` means `scale += tf.diag(scale_diag)`.

    Args:
      shift: Floating-point `Tensor`. If this is set to `None`, no shift is
        applied.
      scale_identity_multiplier: floating point rank 0 `Tensor` representing a
        scaling done to the identity matrix.
        When `scale_identity_multiplier = scale_diag = scale_tril = None` then
        `scale += IdentityMatrix`. Otherwise no scaled-identity-matrix is added
        to `scale`.
      scale_diag: Floating-point `Tensor` representing the diagonal matrix.
        `scale_diag` has shape `[N1, N2, ...  k]`, which represents a k x k
        diagonal matrix.
        When `None` no diagonal term is added to `scale`.
      scale_tril: Floating-point `Tensor` representing the lower triangular
        matrix. `scale_tril` has shape `[N1, N2, ...  k, k]`, which represents a
        k x k lower triangular matrix.
        When `None` no `scale_tril` term is added to `scale`.
        The upper triangular elements above the diagonal are ignored.
      scale_perturb_factor: Floating-point `Tensor` representing factor matrix
        with last two dimensions of shape `(k, r)`. When `None`, no rank-r
        update is added to `scale`.
      scale_perturb_diag: Floating-point `Tensor` representing the diagonal
        matrix. `scale_perturb_diag` has shape `[N1, N2, ...  r]`, which
        represents an `r x r` diagonal matrix. When `None` low rank updates will
        take the form `scale_perturb_factor * scale_perturb_factor.T`.
      adjoint: Python `bool` indicating whether to use the `scale` matrix as
        specified or its adjoint.
        Default value: `False`.
      validate_args: Python `bool` indicating whether arguments should be
        checked for correctness.
      name: Python `str` name given to ops managed by this object.
      dtype: `tf.DType` to prefer when converting args to `Tensor`s. Else, we
        fall back to a common dtype inferred from the args, finally falling back
        to float32.

    Raises:
      ValueError: if `perturb_diag` is specified but not `perturb_factor`.
      TypeError: if `shift` has different `dtype` from `scale` arguments.
    """
        # Ambiguous definition of low rank update.
        if scale_perturb_diag is not None and scale_perturb_factor is None:
            raise ValueError("When scale_perturb_diag is specified, "
                             "scale_perturb_factor must be specified.")

        # Special case, only handling a scaled identity matrix. We don't know its
        # dimensions, so this is special cased.
        # We don't check identity_multiplier, since below we set it to 1. if all
        # other scale args are None.
        self._is_only_identity_multiplier = (scale_tril is None
                                             and scale_diag is None
                                             and scale_perturb_factor is None)

        with tf.name_scope(name) as name:
            self._name = name
            self._validate_args = validate_args

            if dtype is None:
                dtype = dtype_util.common_dtype([
                    shift, scale_identity_multiplier, scale_diag, scale_tril,
                    scale_perturb_diag, scale_perturb_factor
                ], tf.float32)

            if shift is not None:
                shift = tf.convert_to_tensor(shift, name="shift", dtype=dtype)
            self._shift = shift

            # When no args are specified, pretend the scale matrix is the identity
            # matrix.
            if (self._is_only_identity_multiplier
                    and scale_identity_multiplier is None):
                scale_identity_multiplier = tf.convert_to_tensor(1.,
                                                                 dtype=dtype)

            # self._create_scale_operator returns a LinearOperator in all cases
            # except if self._is_only_identity_multiplier; in which case it
            # returns a scalar Tensor.
            scale = self._create_scale_operator(
                identity_multiplier=scale_identity_multiplier,
                diag=scale_diag,
                tril=scale_tril,
                perturb_diag=scale_perturb_diag,
                perturb_factor=scale_perturb_factor,
                shift=shift,
                validate_args=validate_args,
                dtype=dtype)

            if scale is not None and not self._is_only_identity_multiplier:
                if (shift is not None and
                        not dtype_util.base_equal(shift.dtype, scale.dtype)):
                    raise TypeError(
                        "shift.dtype({}) is incompatible with scale.dtype({})."
                        .format(shift.dtype, scale.dtype))

            self._scale = scale
            self._adjoint = adjoint
            super(Affine, self).__init__(forward_min_event_ndims=1,
                                         is_constant_jacobian=True,
                                         dtype=dtype,
                                         validate_args=validate_args,
                                         name=name)
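The `Affine` bijector has since been split into smaller pieces in current TFP releases; a hedged sketch of the equivalent `y = scale_tril @ x + shift` transform built from `tfb.Shift` and `tfb.ScaleMatvecTriL`:

```python
import tensorflow_probability as tfp
tfb = tfp.bijectors

shift = [1., -1.]
scale_tril = [[2., 0.], [0.5, 1.]]
# Chain applies the rightmost bijector first: y = scale_tril @ x + shift.
affine = tfb.Chain([tfb.Shift(shift), tfb.ScaleMatvecTriL(scale_tril)])
y = affine.forward([1., 1.])   # [3., 0.5]
x = affine.inverse(y)          # recovers [1., 1.]
```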
Example #14
    def bootstrap_results(self, init_state):
        """Returns an object with the same type as returned by `one_step`.

    Args:
      init_state: `Tensor` or Python `list` of `Tensor`s representing the
        initial state(s) of the Markov chain(s).

    Returns:
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.
        This includes replica states.
    """
        with tf.name_scope(
                mcmc_util.make_name(self.name, 'remc', 'bootstrap_results')):
            init_state, unused_is_multipart_state = mcmc_util.prepare_state_parts(
                init_state)

            inverse_temperatures = tf.convert_to_tensor(
                self.inverse_temperatures, name='inverse_temperatures')

            if self._state_includes_replicas:
                it_n_replica = inverse_temperatures.shape[0]
                state_n_replica = init_state[0].shape[0]
                if ((it_n_replica is not None)
                        and (state_n_replica is not None)
                        and (it_n_replica != state_n_replica)):
                    raise ValueError(
                        'Number of replicas implied by initial state ({}) must equal '
                        'number of replicas implied by inverse_temperatures ({}), but '
                        'did not'.format(state_n_replica, it_n_replica))

            # We will now replicate each of a possible batch of initial states, one for
            # each inverse_temperature. So if init_state=[x, y] of shapes [Sx, Sy]
            # then the new shape is [(T, Sx), (T, Sy)] where (a, b) means
            # concatenation and T=shape(inverse_temperature).
            num_replica = ps.size0(inverse_temperatures)
            replica_shape = ps.convert_to_shape_tensor([num_replica])

            if self._state_includes_replicas:
                replica_states = init_state
            else:
                replica_states = [
                    tf.broadcast_to(  # pylint: disable=g-complex-comprehension
                        x,
                        ps.concat([replica_shape, ps.shape(x)], axis=0),
                        name='replica_states') for x in init_state
                ]

            target_log_prob_for_inner_kernel = _make_replica_target_log_prob_fn(
                target_log_prob_fn=self.target_log_prob_fn,
                inverse_temperatures=inverse_temperatures,
                untempered_log_prob_fn=self.untempered_log_prob_fn,
                tempered_log_prob_fn=self.tempered_log_prob_fn,
            )
            # TODO(b/159636942): Clean up the helpful error msg after 2020-11-10.
            try:
                inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
                    target_log_prob_for_inner_kernel)
            except TypeError as e:
                if 'argument' not in str(e):
                    raise
                raise TypeError(
                    '`ReplicaExchangeMC`s `make_kernel_fn` no longer receives a second '
                    '(`seed`) argument. `TransitionKernel` instances now receive seeds '
                    'via `one_step`.')

            replica_results = inner_kernel.bootstrap_results(replica_states)

            pre_swap_replica_target_log_prob = _get_field(
                replica_results, 'target_log_prob')

            replica_and_batch_shape = ps.shape(
                pre_swap_replica_target_log_prob)
            batch_shape = replica_and_batch_shape[1:]

            inverse_temperatures = bu.left_justified_broadcast_to(
                inverse_temperatures, replica_and_batch_shape)

            # Pretend we did a "null swap", which will always be accepted.
            swaps = bu.left_justified_broadcast_to(tf.range(num_replica),
                                                   replica_and_batch_shape)
            # is_swap_accepted.shape = [n_replica, n_replica] + batch_shape.
            is_swap_accepted = distribution_util.rotate_transpose(tf.eye(
                num_replica, batch_shape=batch_shape, dtype=tf.bool),
                                                                  shift=2)

            return ReplicaExchangeMCKernelResults(
                post_swap_replica_states=replica_states,
                pre_swap_replica_results=replica_results,
                post_swap_replica_results=_set_swapped_fields_to_nan(
                    replica_results),
                is_swap_proposed=is_swap_accepted,
                is_swap_accepted=is_swap_accepted,
                is_swap_proposed_adjacent=_sub_diag(is_swap_accepted),
                is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
                inverse_temperatures=self.inverse_temperatures,
                swaps=swaps,
                step_count=tf.zeros(shape=(), dtype=tf.int32),
                seed=samplers.zeros_seed(),
                potential_energy=tf.zeros_like(
                    pre_swap_replica_target_log_prob),
            )
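For context (not in the original snippet), a minimal sketch of driving `ReplicaExchangeMC`, where `bootstrap_results` is normally called implicitly by `sample_chain`; the target, temperatures, and step sizes are made-up values.

```python
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

target = tfd.Normal(loc=0., scale=1.)
remc = tfp.mcmc.ReplicaExchangeMC(
    target_log_prob_fn=target.log_prob,
    inverse_temperatures=[1., 0.5, 0.25],  # three replicas
    make_kernel_fn=lambda tlp: tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=tlp, step_size=0.5, num_leapfrog_steps=3))

# Explicit call, as documented above; normally done inside sample_chain.
kernel_results = remc.bootstrap_results(tf.constant(1.))

samples = tfp.mcmc.sample_chain(
    num_results=200, num_burnin_steps=100, current_state=1.,
    kernel=remc, trace_fn=None, seed=[0, 1])
```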
Example #15
    def state_y(self,
                t: types.RealTensor,
                name: str = None) -> types.RealTensor:
        """Computes the state variable `y(t)` for tha Gaussian HJM Model.

    For Gaussian HJM model, the state parameter y(t), can be analytically
    computed as follows:

    y_ij(t) = exp(-k_i * t) * exp(-k_j * t) * (
              int_0^t rho_ij * sigma_i(u) * sigma_j(u) * du)

    Args:
      t: A rank 1 real `Tensor` of shape `[num_times]` specifying the time `t`.
      name: Python string. The name to give to the ops created by this function.
        Default value: `None` which maps to the default name `state_y`.

    Returns:
      A real `Tensor` of shape [self._factors, self._factors, num_times]
      containing the computed y_ij(t).
    """
        name = name or 'state_y'
        with tf.name_scope(name):
            t = tf.convert_to_tensor(t, dtype=self._dtype)
            t_shape = tf.shape(t)
            t = tf.broadcast_to(t, tf.concat([[self._dim], t_shape], axis=0))
            time_index = tf.searchsorted(self._jump_locations, t)
            # create a matrix k2(i,j) = k(i) + k(j)
            mr2 = tf.expand_dims(self._mean_reversion, axis=-1)
            # Add a dimension corresponding to `num_times`
            mr2 = tf.expand_dims(mr2 + tf.transpose(mr2), axis=-1)

            def _integrate_volatility_squared(vol, l_limit, u_limit):
                # create sigma2_ij = sigma_i * sigma_j
                vol = tf.expand_dims(vol, axis=-2)
                vol_squared = tf.expand_dims(self._rho, axis=-1) * (
                    vol * tf.transpose(vol, perm=[1, 0, 2]))
                return vol_squared / mr2 * (tf.math.exp(mr2 * u_limit) -
                                            tf.math.exp(mr2 * l_limit))

            is_constant_vol = tf.math.equal(
                tf.shape(self._jump_values_vol)[-1], 0)
            v_squared_between_vol_knots = tf.cond(
                is_constant_vol,
                lambda: tf.zeros(shape=(self._dim, self._dim, 0),
                                 dtype=self._dtype),
                lambda: _integrate_volatility_squared(  # pylint: disable=g-long-lambda
                    self._jump_values_vol, self._padded_knots, self.
                    _jump_locations))
            v_squared_at_vol_knots = tf.concat([
                tf.zeros((self._dim, self._dim, 1), dtype=self._dtype),
                utils.cumsum_using_matvec(v_squared_between_vol_knots)
            ],
                                               axis=-1)

            vn = tf.concat([self._zero_padding, self._jump_locations], axis=1)

            v_squared_t = _integrate_volatility_squared(
                self._volatility(t), tf.gather(vn, time_index, batch_dims=1),
                t)
            v_squared_t += tf.gather(v_squared_at_vol_knots,
                                     time_index,
                                     batch_dims=-1)

            return tf.math.exp(-mr2 * t) * v_squared_t
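As an illustration of the closed form above (a sketch, not the library class itself): for constant volatilities the integral can be evaluated directly. The two-factor parameters below are made up.

```python
import tensorflow as tf

k = tf.constant([0.03, 0.06])        # mean-reversion rates, one per factor
sigma = tf.constant([0.01, 0.015])   # constant instantaneous volatilities
rho = tf.constant([[1., 0.5], [0.5, 1.]])
t = tf.constant([1., 5., 10.])       # times, shape [num_times]

k2 = k[:, None] + k[None, :]                 # k_i + k_j, shape [dim, dim]
s2 = rho * sigma[:, None] * sigma[None, :]   # rho_ij * sigma_i * sigma_j

# y_ij(t) = rho_ij * sigma_i * sigma_j * (1 - exp(-(k_i + k_j) t)) / (k_i + k_j)
y = s2[..., None] * (1. - tf.math.exp(-k2[..., None] * t)) / k2[..., None]
# y.shape == [dim, dim, num_times]
```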
Example #16
    def __init__(self,
                 df,
                 kernel,
                 index_points=None,
                 mean_fn=None,
                 observation_noise_variance=0.,
                 marginal_fn=None,
                 cholesky_fn=None,
                 jitter=1e-6,
                 validate_args=False,
                 allow_nan_stats=False,
                 name='StudentTProcess'):
        """Instantiate a StudentTProcess Distribution.

    Args:
      df: Positive Floating-point `Tensor` representing the degrees of freedom.
        Must be greater than 2.
      kernel: `PositiveSemidefiniteKernel`-like instance representing the
        TP's covariance function.
      index_points: `float` `Tensor` representing finite (batch of) vector(s) of
        points in the index set over which the TP is defined. Shape has the form
        `[b1, ..., bB, e, f1, ..., fF]` where `F` is the number of feature
        dimensions and must equal `kernel.feature_ndims` and `e` is the number
        (size) of index points in each batch. Ultimately this distribution
        corresponds to an `e`-dimensional multivariate Student's T. The batch
        shape must be broadcastable with `kernel.batch_shape` and any batch dims
        yielded by `mean_fn`.
      mean_fn: Python `callable` that acts on `index_points` to produce a (batch
        of) vector(s) of mean values at `index_points`. Takes a `Tensor` of
        shape `[b1, ..., bB, f1, ..., fF]` and returns a `Tensor` whose shape is
        broadcastable with `[b1, ..., bB]`. Default value: `None` implies
        constant zero function.
      observation_noise_variance: `float` `Tensor` representing (batch of)
        scalar variance(s) of the noise in the Normal likelihood
        distribution of the model. If batched, the batch shape must be
        broadcastable with the shapes of all other batched parameters
        (`kernel.batch_shape`, `index_points`, etc.).
        Default value: `0.`
      marginal_fn: A Python callable that takes a location, covariance matrix,
        optional `validate_args`, `allow_nan_stats` and `name` arguments, and
        returns a multivariate normal subclass of `tfd.Distribution`.
        Default value: `None`, in which case a Cholesky-factorizing function
        is created using `make_cholesky_factored_marginal_fn` and the
        `jitter` argument.
      cholesky_fn: Callable which takes a single (batch) matrix argument and
        returns a Cholesky-like lower triangular factor.  Default value: `None`,
        in which case `make_cholesky_with_jitter_fn` is used with the `jitter`
        parameter. At most one of `cholesky_fn` and `marginal_fn` should be set.
      jitter: `float` scalar `Tensor` added to the diagonal of the covariance
        matrix to ensure positive definiteness of the covariance matrix.
        This argument is ignored if `cholesky_fn` is set.
        Default value: `1e-6`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
        Default value: `False`.
      allow_nan_stats: Python `bool`, default `False`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
        Default value: `False`.
      name: Python `str` name prefixed to Ops created by this class.
        Default value: "StudentTProcess".

    Raises:
      ValueError: if `mean_fn` is not `None` and is not callable.
    """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype(
                [df, kernel, index_points, observation_noise_variance, jitter],
                tf.float32)
            df = tensor_util.convert_nonref_to_tensor(df,
                                                      dtype=dtype,
                                                      name='df')
            observation_noise_variance = tensor_util.convert_nonref_to_tensor(
                observation_noise_variance,
                dtype=dtype,
                name='observation_noise_variance')
            index_points = tensor_util.convert_nonref_to_tensor(
                index_points, dtype=dtype, name='index_points')
            jitter = tensor_util.convert_nonref_to_tensor(jitter,
                                                          dtype=dtype,
                                                          name='jitter')

            self._kernel = kernel
            self._index_points = index_points
            # Default to a constant zero function, borrowing the dtype from
            # index_points to ensure consistency.
            if mean_fn is None:
                mean_fn = lambda x: tf.zeros([1], dtype=dtype)
            else:
                if not callable(mean_fn):
                    raise ValueError('`mean_fn` must be a Python callable')
            self._df = df
            self._observation_noise_variance = observation_noise_variance
            self._mean_fn = mean_fn
            self._jitter = jitter
            self._cholesky_fn = cholesky_fn
            if marginal_fn is not None and cholesky_fn is not None:
                raise ValueError(
                    'At most one of `marginal_fn` and `cholesky_fn` should be set.'
                )
            if marginal_fn is None:
                if self._cholesky_fn is None:
                    self._cholesky_fn = cholesky_util.make_cholesky_with_jitter_fn(
                        jitter)
                self._marginal_fn = make_cholesky_factored_marginal_fn(
                    self._cholesky_fn)
            else:
                self._marginal_fn = marginal_fn

            with tf.name_scope('init'):
                super(StudentTProcess, self).__init__(
                    dtype=dtype,
                    reparameterization_type=reparameterization.
                    FULLY_REPARAMETERIZED,
                    validate_args=validate_args,
                    allow_nan_stats=allow_nan_stats,
                    parameters=parameters,
                    name=name)
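A short usage sketch (not from the original source), assuming standard TFP imports and an arbitrary squared-exponential kernel:

```python
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
psd_kernels = tfp.math.psd_kernels

index_points = tf.linspace(-1., 1., 10)[..., tf.newaxis]   # shape [10, 1]
tp = tfd.StudentTProcess(
    df=5.,
    kernel=psd_kernels.ExponentiatedQuadratic(amplitude=1., length_scale=0.5),
    index_points=index_points,
    observation_noise_variance=1e-4)
draws = tp.sample(3)        # shape [3, 10]
lp = tp.log_prob(draws)     # shape [3]
```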
Example #17
  def __init__(self,
               short_position: types.BoolTensor,
               currency: Union[types.CurrencyProtoType,
                               List[types.CurrencyProtoType]],
               expiry_date: types.DateTensor,
               equity: List[str],
               contract_amount: types.FloatTensor,
               strike: types.FloatTensor,
               is_call_option: List[bool],
               business_day_convention: types.BusinessDayConventionProtoType,
               calendar: types.BankHolidaysProtoType,
               settlement_days: Optional[types.IntTensor] = 0,
               discount_curve_type: curve_types_lib.CurveType = None,
               discount_curve_mask: types.IntTensor = None,
               equity_mask: types.IntTensor = None,
               config: Union[AmericanOptionConfig, Dict[str, Any]] = None,
               batch_names: Optional[types.StringTensor] = None,
               dtype: Optional[types.Dtype] = None,
               name: Optional[str] = None):
    """Initializes the batch of American Equity Options.

    Args:
      short_position: Whether the price is computed for the contract holder.
        Default value: `True` which means that the price is for the contract
        holder.
      currency: The denominated currency.
      expiry_date: A `DateTensor` specifying the dates on which the options
        expire.
      equity: A string name of the underlyings.
      contract_amount: A `Tensor` of real dtype and shape compatible with
        `short_position`.
      strike: A `Tensor` of real dtype and shape compatible with
        `short_position`. Option strikes.
      is_call_option: A bool `Tensor` of shape compatible with
        `short_position`. Indicates which options are of call type.
      business_day_convention: A business day convention.
      calendar: A calendar to specify the weekend mask and bank holidays.
      settlement_days: An integer `Tensor` of the shape broadcastable with the
        shape of `expiry_date`.
      discount_curve_type: An optional instance of `CurveType` or a list of
        those. If supplied as a list and `discount_curve_mask` is not supplied,
        the size of the list should be the same as the number of priced
        instruments. Defines discount curves for the instruments.
        Default value: `None`, meaning that discount curves are inferred
        from `currency` and `config`.
      discount_curve_mask: An optional integer `Tensor` of values ranging from
        `0` to `len(discount_curve_type) - 1` and of shape `batch_shape`.
        Identifies a mapping between `discount_curve_type` list and the
        underlying instruments.
        Default value: `None`.
      equity_mask: An optional integer `Tensor` of values ranging from
        `0` to `len(equity) - 1` and of shape `batch_shape`. Identifies
        a mapping between `equity` list and the underlying instruments.
        Default value: `None`.
      config: Optional `AmericanOptionConfig` or a dictionary. If dictionary,
        then the keys should be the same as the field names of
        `AmericanOptionConfig`.
      batch_names: A string `Tensor` of instrument names. Should be of shape
        `batch_shape + [2]` specifying name and instrument type. This is useful
        when the `from_protos` method is used and the user needs to identify
        which instruments got batched together.
      dtype: `tf.Dtype` of the input and output real `Tensor`s.
        Default value: `None` which maps to `float64`.
      name: Python str. The name to give to the ops created by this class.
        Default value: `None` which maps to 'AmericanOption'.
    """
    self._name = name or "AmericanOption"
    with tf.name_scope(self._name):
      if batch_names is not None:
        self._names = tf.convert_to_tensor(batch_names,
                                           name="batch_names")
      else:
        self._names = None
      self._dtype = dtype or tf.float64
      ones = tf.constant(1, dtype=self._dtype)
      self._short_position = tf.where(
          short_position, ones, -ones, name="short_position")
      self._contract_amount = tf.convert_to_tensor(
          contract_amount, dtype=self._dtype, name="contract_amount")
      self._strike = tf.convert_to_tensor(strike, dtype=self._dtype,
                                          name="strike")
      self._is_call_option = tf.convert_to_tensor(
          is_call_option, dtype=tf.bool, name="is_call_option")
      settlement_days = tf.convert_to_tensor(settlement_days)
      # Business day roll convention and the end of month flag
      roll_convention, eom = market_data_utils.get_business_day_convention(
          business_day_convention)
      # TODO(b/160446193): Calendar is ignored at the moment
      calendar = dateslib.create_holiday_calendar(
          weekend_mask=dateslib.WeekendMask.SATURDAY_SUNDAY)
      if isinstance(expiry_date, types.IntTensor):
        self._expiry_date = dateslib.dates_from_tensor(expiry_date)
      else:
        self._expiry_date = dateslib.convert_to_date_tensor(expiry_date)
      self._settlement_days = settlement_days
      self._roll_convention = roll_convention
      # Get discount and reference curves
      self._currency = cashflow_streams.to_list(currency)
      self._equity = cashflow_streams.to_list(equity)
      if len(self._currency) != len(self._equity):
        if len(self._currency) > 1 and len(self._equity) > 1:
          raise ValueError(
              "Number of currencies and equities should be the same "
              "but it is {0} and {1}".format(len(self._currency),
                                             len(self._equity)))

      config = _process_config(config)
      [
          self._model,
          self._num_samples,
          self._seed,
          self._num_exercise_times,
          self._num_calibration_samples
      ] = _get_config_values(config)

      if discount_curve_type is None:
        discount_curve_type = []
        for currency in self._currency:
          if currency in config.discounting_curve:
            curve_type = config.discounting_curve[currency]
          else:
            # Default discounting curve
            curve_type = curve_types_lib.RiskFreeCurve(
                currency=currency)
          discount_curve_type.append(curve_type)

      # Get masks for discount curves and vol surfaces
      [
          self._discount_curve_type,
          self._discount_curve_mask
      ] = cashflow_streams.process_curve_types(discount_curve_type,
                                               discount_curve_mask)
      [
          self._equity,
          self._equity_mask,
      ] = equity_utils.process_equities(self._equity, equity_mask)
      # Get batch shape
      self._batch_shape = tf.shape(strike)
Example #18
    def __init__(
        self,
        ndims=2,
        curvature=0.03,
        name='banana',
        pretty_name='Banana',
    ):
        """Construct the banana model.

    Args:
      ndims: Python integer. Dimensionality of the distribution. Must be at
        least 2.
      curvature: Python float. Controls the strength of the curvature of
        the distribution.
      name: Python `str` name prefixed to Ops created by this class.
      pretty_name: A Python `str`. The pretty name of this model.

    Raises:
      ValueError: If ndims < 2.
    """
        if ndims < 2:
            raise ValueError('ndims must be at least 2, saw: {}'.format(ndims))

        with tf.name_scope(name):

            def bijector_fn(x):
                """Banana transform."""
                batch_shape = tf.shape(x)[:-1]
                shift = tf.concat(
                    [
                        tf.zeros(tf.concat([batch_shape, [1]], axis=0)),
                        curvature * (tf.square(x[..., :1]) - 100),
                        tf.zeros(tf.concat([batch_shape, [ndims - 2]],
                                           axis=0)),
                    ],
                    axis=-1,
                )
                return tfb.Shift(shift)

            mg = tfd.MultivariateNormalDiag(loc=tf.zeros(ndims),
                                            scale_diag=[10.] + [1.] *
                                            (ndims - 1))
            banana = tfd.TransformedDistribution(
                mg,
                bijector=tfb.MaskedAutoregressiveFlow(bijector_fn=bijector_fn))

            sample_transformations = {
                'identity':
                model.Model.SampleTransformation(
                    fn=lambda params: params,
                    pretty_name='Identity',
                    # The second dimension is a sum of scaled Chi2 and normal
                    # distribution.
                    # Mean of Chi2 with one degree of freedom is 1, but since the
                    # first element has variance of 100, it cancels with the shift
                    # (which is why the shift is there).
                    ground_truth_mean=onp.zeros(ndims),
                    # Variance of Chi2 with one degree of freedom is 2.
                    ground_truth_standard_deviation=onp.array(
                        [10.] + [onp.sqrt(1. + 2 * curvature**2 * 10.**4)] +
                        [1.] * (ndims - 2)),
                )
            }

        self._banana = banana

        super(Banana, self).__init__(
            default_event_space_bijector=tfb.Identity(),
            event_shape=banana.event_shape,
            dtype=banana.dtype,
            name=name,
            pretty_name=pretty_name,
            sample_transformations=sample_transformations,
        )
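A standalone re-creation of the 2-D case (a hedged sketch, not the `Banana` class itself), showing how the shift bijector bends the base Gaussian into the familiar banana shape:

```python
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

curvature = 0.03

def bijector_fn(x):
  # Shift the second coordinate by curvature * (x0**2 - 100); leave x0 alone.
  shift = tf.concat(
      [tf.zeros_like(x[..., :1]),
       curvature * (tf.square(x[..., :1]) - 100.)], axis=-1)
  return tfb.Shift(shift)

base = tfd.MultivariateNormalDiag(loc=tf.zeros(2), scale_diag=[10., 1.])
banana = tfd.TransformedDistribution(
    base, bijector=tfb.MaskedAutoregressiveFlow(bijector_fn=bijector_fn))
samples = banana.sample(1000)   # a curved, banana-shaped point cloud
```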
Example #19
 def _z(self, x, loc=None, scale=None):
   """Standardize input `x`."""
   loc = tf.convert_to_tensor(self.loc if loc is None else loc)
   scale = tf.convert_to_tensor(self.scale if scale is None else scale)
   with tf.name_scope('standardize'):
     return (x - loc) / scale
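For completeness, a tiny sketch of the standardization performed by `_z` (its inverse appears in Example #21 below):

```python
import tensorflow as tf

loc, scale = 2., 3.
x = tf.constant([2., 5., 8.])
z = (x - loc) / scale       # [0., 1., 2.]
x_back = z * scale + loc    # recovers x
```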
Example #20
def _convolution_batch_nhwbc(x, kernel, rank, strides, padding, dilations,
                             name):
    """Specialization of batch conv to NHWBC data format."""
    with tf.name_scope(name or 'conv2d_nhwbc'):
        # Prepare arguments.
        [
            rank,
            _,  # strides
            padding,
            dilations,
            _,  # data_format
        ] = prepare_conv_args(rank, strides, padding, dilations)
        strides = prepare_strides(strides, rank + 2, arg_name='strides')

        dtype = dtype_util.common_dtype([x, kernel], dtype_hint=tf.float32)
        x = tf.convert_to_tensor(x, dtype=dtype, name='x')
        kernel = tf.convert_to_tensor(kernel, dtype=dtype, name='kernel')

        # Step 1: Transpose and double flatten kernel.
        # kernel.shape = B + F + [c, c']. Eg: [b, fh, fw, c, c']
        kernel_shape = prefer_static.shape(kernel)
        kernel_batch_shape, kernel_event_shape = prefer_static.split(
            kernel_shape, num_or_size_splits=[-1, rank + 2])
        kernel_batch_size = prefer_static.reduce_prod(kernel_batch_shape)
        kernel_ndims = prefer_static.rank(kernel)
        kernel_batch_ndims = kernel_ndims - rank - 2
        perm = prefer_static.concat([
            prefer_static.range(kernel_batch_ndims, kernel_batch_ndims + rank),
            prefer_static.range(0, kernel_batch_ndims),
            prefer_static.range(kernel_batch_ndims + rank, kernel_ndims),
        ],
                                    axis=0)  # Eg, [1, 2, 0, 3, 4]
        kernel = tf.transpose(kernel, perm=perm)  # F + B + [c, c']
        kernel = tf.reshape(kernel,
                            shape=prefer_static.concat([
                                kernel_event_shape[:rank],
                                [
                                    kernel_batch_size * kernel_event_shape[-2],
                                    kernel_event_shape[-1]
                                ],
                            ],
                                                       axis=0))  # F + [bc, c']

        # Step 2: Double flatten x.
        # x.shape = N + D + B + [c]
        x_shape = prefer_static.shape(x)
        [
            x_sample_shape,
            x_rank_shape,
            x_batch_shape,
            x_channel_shape,
        ] = prefer_static.split(
            x_shape, num_or_size_splits=[-1, rank, kernel_batch_ndims, 1])
        x = tf.reshape(
            x,  # N + D + B + [c]
            shape=prefer_static.concat([
                [prefer_static.reduce_prod(x_sample_shape)],
                x_rank_shape,
                [
                    prefer_static.reduce_prod(x_batch_shape) *
                    prefer_static.reduce_prod(x_channel_shape)
                ],
            ],
                                       axis=0))  # [n] + D + [bc]

        # Step 3: Apply convolution.
        y = tf.nn.depthwise_conv2d(x,
                                   kernel,
                                   strides=strides,
                                   padding=padding,
                                   data_format='NHWC',
                                   dilations=dilations)
        #  SAME: y.shape = [n, h,      w,      bcc']
        # VALID: y.shape = [n, h-fh+1, w-fw+1, bcc']

        # Step 4: Reshape/reduce for output.
        y_shape = prefer_static.shape(y)
        y = tf.reshape(y,
                       shape=prefer_static.concat(
                           [
                               x_sample_shape,
                               y_shape[1:-1],
                               kernel_batch_shape,
                               kernel_event_shape[-2:],
                           ],
                           axis=0))  # N + D' + B + [c, c']
        y = tf.reduce_sum(y, axis=-2)  # N + D' + B + [c']

        return y
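A self-contained sketch (with made-up shapes) of the same trick: double-flatten the kernel batch into the channel dimension, run `tf.nn.depthwise_conv2d`, then un-flatten and sum out the input channels. It is checked against a per-batch-member `tf.nn.conv2d` loop.

```python
import tensorflow as tf

n, h, w, b, c, cp, fh, fw = 1, 8, 8, 2, 3, 4, 3, 3   # hypothetical sizes
x = tf.random.normal([n, h, w, b, c])                 # NHWBC input
kernel = tf.random.normal([b, fh, fw, c, cp])         # per-batch-member kernels

# Kernel: [b, fh, fw, c, c'] -> [fh, fw, b*c, c']; x: [n, h, w, b, c] -> [n, h, w, b*c].
k = tf.reshape(tf.transpose(kernel, [1, 2, 0, 3, 4]), [fh, fw, b * c, cp])
xf = tf.reshape(x, [n, h, w, b * c])

y = tf.nn.depthwise_conv2d(xf, k, strides=[1, 1, 1, 1], padding='SAME')
# [n, h, w, b*c*c'] -> [n, h, w, b, c, c'], then sum out the input channels c.
y = tf.reduce_sum(tf.reshape(y, [n, h, w, b, c, cp]), axis=-2)

# Reference: one ordinary conv2d per kernel-batch member.
ref = tf.stack(
    [tf.nn.conv2d(x[..., i, :], kernel[i], strides=[1, 1, 1, 1], padding='SAME')
     for i in range(b)], axis=-2)
print(tf.reduce_max(tf.abs(y - ref)).numpy())   # ~0 up to float error
```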
Example #21
 def _inv_z(self, z):
   """Reconstruct input `x` from a its normalized version."""
   with tf.name_scope("reconstruct"):
     return z * self.scale + self.loc
Example #22
def _interp_regular_1d_grid_impl(x,
                                 x_ref_min,
                                 x_ref_max,
                                 y_ref,
                                 axis=-1,
                                 batch_y_ref=False,
                                 fill_value='constant_extension',
                                 fill_value_below=None,
                                 fill_value_above=None,
                                 grid_regularizing_transform=None,
                                 name=None):
    """1-D interpolation that works with/without batching."""
    # Note: we do *not* make the no-batch version a special case of the batch
    # version, because that would be an inefficient use of batch_gather with
    # unnecessarily broadcast args.
    with tf.name_scope(name or 'interp_regular_1d_grid_impl'):

        # Arg checking.
        allowed_fv_st = ('constant_extension', 'extrapolate')
        for fv in (fill_value, fill_value_below, fill_value_above):
            if isinstance(fv, str) and fv not in allowed_fv_st:
                raise ValueError(
                    'A fill value ({}) was not an allowed string ({})'.format(
                        fv, allowed_fv_st))

        # Separate value fills for below/above incur extra cost, so keep track of
        # whether this is needed.
        need_separate_fills = (
            fill_value_above is not None or fill_value_below is not None or
            fill_value == 'extrapolate'  # always requires separate below/above
        )
        if need_separate_fills and fill_value_above is None:
            fill_value_above = fill_value
        if need_separate_fills and fill_value_below is None:
            fill_value_below = fill_value

        dtype = dtype_util.common_dtype([x, x_ref_min, x_ref_max, y_ref],
                                        dtype_hint=tf.float32)
        x = tf.convert_to_tensor(x, name='x', dtype=dtype)

        x_ref_min = tf.convert_to_tensor(x_ref_min,
                                         name='x_ref_min',
                                         dtype=dtype)
        x_ref_max = tf.convert_to_tensor(x_ref_max,
                                         name='x_ref_max',
                                         dtype=dtype)
        if not batch_y_ref:
            _assert_ndims_statically(x_ref_min, expect_ndims=0)
            _assert_ndims_statically(x_ref_max, expect_ndims=0)

        y_ref = tf.convert_to_tensor(y_ref, name='y_ref', dtype=dtype)

        if batch_y_ref:
            # If we're batching,
            #   x.shape ~ [A1,...,AN, D],  x_ref_min/max.shape ~ [A1,...,AN]
            # So to add together we'll append a singleton.
            # If not batching, x_ref_min/max are scalar, so this isn't an issue,
            # moreover, if not batching, x can be scalar, and expanding x_ref_min/max
            # would cause a bad expansion of x when added to x (confused yet?).
            x_ref_min = x_ref_min[..., tf.newaxis]
            x_ref_max = x_ref_max[..., tf.newaxis]

        axis = tf.convert_to_tensor(axis, name='axis', dtype=tf.int32)
        axis = prefer_static.non_negative_axis(axis, tf.rank(y_ref))
        _assert_ndims_statically(axis, expect_ndims=0)

        ny = tf.cast(tf.shape(y_ref)[axis], dtype)

        # Map [x_ref_min, x_ref_max] to [0, ny - 1].
        # This is the (fractional) index of x.
        if grid_regularizing_transform is None:
            g = lambda x: x
        else:
            g = grid_regularizing_transform
        fractional_idx = ((g(x) - g(x_ref_min)) /
                          (g(x_ref_max) - g(x_ref_min)))
        x_idx_unclipped = fractional_idx * (ny - 1)

        # Wherever x is NaN, x_idx_unclipped will be NaN as well.
        # Keep track of the nan indices here (so we can impute NaN later).
        # Also eliminate any NaN indices, since NaN cannot be cast to an int32 index.
        nan_idx = tf.math.is_nan(x_idx_unclipped)
        zero = tf.zeros((), dtype=dtype)
        x_idx_unclipped = tf.where(nan_idx, zero, x_idx_unclipped)
        x_idx = tf.clip_by_value(x_idx_unclipped, zero, ny - 1)

        # Get the index above and below x_idx.
        # Naively we could set idx_below = floor(x_idx), idx_above = ceil(x_idx),
        # however, this results in idx_below == idx_above whenever x is on a grid.
        # This in turn results in y_ref_below == y_ref_above, and then the gradient
        # at this point is zero.  So here we 'jitter' one of idx_below, idx_above,
        # so that they are at different values.  This jittering does not affect the
        # interpolated value, but does make the gradient nonzero (unless of course
        # the y_ref values are the same).
        idx_below = tf.floor(x_idx)
        idx_above = tf.minimum(idx_below + 1, ny - 1)
        idx_below = tf.maximum(idx_above - 1, 0)

        # These are the values of y_ref corresponding to above/below indices.
        idx_below_int32 = tf.cast(idx_below, dtype=tf.int32)
        idx_above_int32 = tf.cast(idx_above, dtype=tf.int32)
        if batch_y_ref:
            # If y_ref.shape ~ [A1,...,AN, C, B1,...,BN],
            # and x.shape, x_ref_min/max.shape ~ [A1,...,AN, D]
            # Then y_ref_below.shape ~ [A1,...,AN, D, B1,...,BN]
            y_ref_below = _batch_gather_with_broadcast(y_ref, idx_below_int32,
                                                       axis)
            y_ref_above = _batch_gather_with_broadcast(y_ref, idx_above_int32,
                                                       axis)
        else:
            # Here, y_ref_below.shape =
            #   y_ref.shape[:axis] + x.shape + y_ref.shape[axis + 1:]
            y_ref_below = tf.gather(y_ref, idx_below_int32, axis=axis)
            y_ref_above = tf.gather(y_ref, idx_above_int32, axis=axis)

        # Use t to get a convex combination of the below/above values.
        t = x_idx - idx_below

        # x, and tensors shaped like x, need to be added to, and selected with
        # (using tf.where) the output y.  This requires appending singletons.
        # Make functions appropriate for batch/no-batch.
        if batch_y_ref:
            # In the batch case, the output shape is going to be
            #   Broadcast(y_ref.shape[:axis], x.shape[:-1]) +
            #   x.shape[-1:] + y_ref.shape[axis+1:]
            expand_x_fn = _make_expand_x_fn_for_batch_interpolation(
                y_ref, axis)
        else:
            # In the non-batch case, the output shape is going to be
            #   y_ref.shape[:axis] + x.shape + y_ref.shape[axis+1:]
            expand_x_fn = _make_expand_x_fn_for_non_batch_interpolation(
                y_ref, axis)

        t = expand_x_fn(t)
        nan_idx = expand_x_fn(nan_idx, broadcast=True)
        x_idx_unclipped = expand_x_fn(x_idx_unclipped, broadcast=True)

        y = t * y_ref_above + (1 - t) * y_ref_below

        # Now begins a long excursion to fill values outside [x_min, x_max].

        # Re-insert NaN wherever x was NaN.
        y = tf.where(nan_idx, tf.constant(np.nan, y.dtype), y)

        if not need_separate_fills:
            if fill_value == 'constant_extension':
                pass  # Already handled by clipping x_idx_unclipped.
            else:
                y = tf.where(
                    (x_idx_unclipped < 0) | (x_idx_unclipped > ny - 1),
                    fill_value, y)
        else:
            # Fill values below x_ref_min <==> x_idx_unclipped < 0.
            if fill_value_below == 'constant_extension':
                pass  # Already handled by the clipping that created x_idx_unclipped.
            elif fill_value_below == 'extrapolate':
                if batch_y_ref:
                    # For every batch member, gather the first two elements of y across
                    # `axis`.
                    y_0 = tf.gather(y_ref, [0], axis=axis)
                    y_1 = tf.gather(y_ref, [1], axis=axis)
                else:
                    # If not batching, we want to gather the first two elements, just like
                    # above.  However, these results need to be replicated for every
                    # member of x.  An easy way to do that is to gather using
                    # indices = zeros/ones(x.shape).
                    y_0 = tf.gather(y_ref,
                                    tf.zeros(tf.shape(x), dtype=tf.int32),
                                    axis=axis)
                    y_1 = tf.gather(y_ref,
                                    tf.ones(tf.shape(x), dtype=tf.int32),
                                    axis=axis)
                x_delta = (x_ref_max - x_ref_min) / (ny - 1)
                x_factor = expand_x_fn((x - x_ref_min) / x_delta,
                                       broadcast=True)
                y = tf.where(x_idx_unclipped < 0, y_0 + x_factor * (y_1 - y_0),
                             y)
            else:
                y = tf.where(x_idx_unclipped < 0, fill_value_below, y)
            # Fill values above x_ref_max <==> x_idx_unclipped > ny - 1.
            if fill_value_above == 'constant_extension':
                pass  # Already handled by the clipping that created x_idx_unclipped.
            elif fill_value_above == 'extrapolate':
                ny_int32 = tf.shape(y_ref)[axis]
                if batch_y_ref:
                    y_n1 = tf.gather(y_ref, [tf.shape(y_ref)[axis] - 1],
                                     axis=axis)
                    y_n2 = tf.gather(y_ref, [tf.shape(y_ref)[axis] - 2],
                                     axis=axis)
                else:
                    y_n1 = tf.gather(y_ref,
                                     tf.fill(tf.shape(x), ny_int32 - 1),
                                     axis=axis)
                    y_n2 = tf.gather(y_ref,
                                     tf.fill(tf.shape(x), ny_int32 - 2),
                                     axis=axis)
                x_delta = (x_ref_max - x_ref_min) / (ny - 1)
                x_factor = expand_x_fn((x - x_ref_max) / x_delta,
                                       broadcast=True)
                y = tf.where(x_idx_unclipped > ny - 1,
                             y_n1 + x_factor * (y_n1 - y_n2), y)
            else:
                y = tf.where(x_idx_unclipped > ny - 1, fill_value_above, y)

        return y
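This helper implements the public `tfp.math.interp_regular_1d_grid`; a minimal usage sketch with arbitrary reference values:

```python
import tensorflow as tf
import tensorflow_probability as tfp

y_ref = tf.exp(tf.linspace(0., 10., 20))   # reference values of exp on [0, 10]
x = tf.constant([0.5, 3.3, 9.9])
y = tfp.math.interp_regular_1d_grid(
    x, x_ref_min=0., x_ref_max=10., y_ref=y_ref, fill_value='extrapolate')
# y ~ [exp(0.5), exp(3.3), exp(9.9)] up to linear-interpolation error
```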
Example #23
 def __init__(self, validate_args=False, name="tanh"):
     with tf.name_scope(name) as name:
         super(Tanh, self).__init__(forward_min_event_ndims=0,
                                    validate_args=validate_args,
                                    name=name)
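A short usage sketch (not in the original snippet) of the resulting bijector:

```python
import tensorflow as tf
import tensorflow_probability as tfp
tfb = tfp.bijectors

tanh = tfb.Tanh()
x = tf.constant([-2., 0., 2.])
y = tanh.forward(x)                                     # tf.math.tanh(x)
x_back = tanh.inverse(y)                                # atanh(y), recovers x
ildj = tanh.inverse_log_det_jacobian(y, event_ndims=0)  # per-element log|dx/dy|
```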
Example #24
def batch_interp_regular_nd_grid(x,
                                 x_ref_min,
                                 x_ref_max,
                                 y_ref,
                                 axis,
                                 fill_value='constant_extension',
                                 name=None):
    """Multi-linear interpolation on a regular (constant spacing) grid.

  Given [a batch of] reference values, this function computes a multi-linear
  interpolant and evaluates it on [a batch of] new `x` values.

  The interpolant is built from reference values indexed by `nd` dimensions
  of `y_ref`, starting at `axis`.

  For example, take the case of a `2-D` scalar valued function and no leading
  batch dimensions.  In this case, `y_ref.shape = [C1, C2]` and `y_ref[i, j]`
  is the reference value corresponding to grid point

  ```
  [x_ref_min[0] + i * (x_ref_max[0] - x_ref_min[0]) / (C1 - 1),
   x_ref_min[1] + j * (x_ref_max[1] - x_ref_min[1]) / (C2 - 1)]
  ```

  In the general case, dimensions to the left of `axis` in `y_ref` are broadcast
  with leading dimensions in `x`, `x_ref_min`, `x_ref_max`.

  Args:
    x: Numeric `Tensor` The x-coordinates of the interpolated output values for
      each batch.  Shape `[..., D, nd]`, designating [a batch of] `D`
      coordinates in `nd` space.  `D` must be `>= 1` and is not a batch dim.
    x_ref_min:  `Tensor` of same `dtype` as `x`.  The minimum values of the
      (implicitly defined) reference `x_ref`.  Shape `[..., nd]`.
    x_ref_max:  `Tensor` of same `dtype` as `x`.  The maximum values of the
      (implicitly defined) reference `x_ref`.  Shape `[..., nd]`.
    y_ref:  `Tensor` of same `dtype` as `x`.  The reference output values. Shape
      `[..., C1, ..., Cnd, B1,...,BM]`, designating [a batch of] reference
      values indexed by `nd` dimensions, of a shape `[B1,...,BM]` valued
      function (for `M >= 0`).
    axis:  Scalar integer `Tensor`.  Dimensions `[axis, axis + nd)` of `y_ref`
      index the interpolation table.  E.g. `3-D` interpolation of a scalar
      valued function requires `axis=-3` and a `3-D` matrix valued function
      requires `axis=-5`.
    fill_value:  Determines what values output should take for `x` values that
      are below `x_ref_min` or above `x_ref_max`. Scalar `Tensor` or
      'constant_extension' ==> Extend as constant function.
      Default value: `'constant_extension'`
    name:  A name to prepend to created ops.
      Default value: `'batch_interp_regular_nd_grid'`.

  Returns:
    y_interp:  Interpolation between members of `y_ref`, at points `x`.
      `Tensor` of same `dtype` as `x`, and shape `[..., D, B1, ..., BM].`

  Raises:
    ValueError:  If `rank(x) < 2` is determined statically.
    ValueError:  If `axis` is determined statically and is not a scalar.
    ValueError:  If `axis + nd > rank(y_ref)` is determined statically.

  #### Examples

  Interpolate a function of one variable.

  ```python
  y_ref = tf.exp(tf.linspace(start=0., stop=10., num=20))

  tfp.math.batch_interp_regular_nd_grid(
      # x.shape = [3, 1], x_ref_min/max.shape = [1].  Trailing `1` for `1-D`.
      x=[[6.0], [0.5], [3.3]], x_ref_min=[0.], x_ref_max=[10.], y_ref=y_ref,
      axis=0)
  ==> approx [exp(6.0), exp(0.5), exp(3.3)]
  ```

  Interpolate a scalar function of two variables.

  ```python
  x_ref_min = [0., 0.]
  x_ref_max = [2 * np.pi, 2 * np.pi]

  # Build y_ref.
  x0s, x1s = tf.meshgrid(
      tf.linspace(x_ref_min[0], x_ref_max[0], num=100),
      tf.linspace(x_ref_min[1], x_ref_max[1], num=100),
      indexing='ij')

  def func(x0, x1):
    return tf.sin(x0) * tf.cos(x1)

  y_ref = func(x0s, x1s)

  x = np.pi * tf.random.uniform(shape=(10, 2))

  tfp.math.batch_interp_regular_nd_grid(x, x_ref_min, x_ref_max, y_ref, axis=-2)
  ==> tf.sin(x[:, 0]) * tf.cos(x[:, 1])
  ```

  """
    with tf.name_scope(name or 'interp_regular_nd_grid'):
        dtype = dtype_util.common_dtype([x, x_ref_min, x_ref_max, y_ref],
                                        dtype_hint=tf.float32)

        # Arg checking.
        if isinstance(fill_value, str):
            if fill_value != 'constant_extension':
                raise ValueError(
                    'A fill value ({}) was not an allowed string ({})'.format(
                        fill_value, 'constant_extension'))
        else:
            fill_value = tf.convert_to_tensor(fill_value,
                                              name='fill_value',
                                              dtype=dtype)
            _assert_ndims_statically(fill_value, expect_ndims=0)

        # x.shape = [..., nd].
        x = tf.convert_to_tensor(x, name='x', dtype=dtype)
        _assert_ndims_statically(x, expect_ndims_at_least=2)

        # y_ref.shape = [..., C1,...,Cnd, B1,...,BM]
        y_ref = tf.convert_to_tensor(y_ref, name='y_ref', dtype=dtype)

        # x_ref_min.shape = [nd]
        x_ref_min = tf.convert_to_tensor(x_ref_min,
                                         name='x_ref_min',
                                         dtype=dtype)
        x_ref_max = tf.convert_to_tensor(x_ref_max,
                                         name='x_ref_max',
                                         dtype=dtype)
        _assert_ndims_statically(x_ref_min,
                                 expect_ndims_at_least=1,
                                 expect_static=True)
        _assert_ndims_statically(x_ref_max,
                                 expect_ndims_at_least=1,
                                 expect_static=True)

        # nd is the number of dimensions indexing the interpolation table, it's the
        # 'nd' in the function name.
        nd = tf.compat.dimension_value(x_ref_min.shape[-1])
        if nd is None:
            raise ValueError('`x_ref_min.shape[-1]` must be known statically.')
        tensorshape_util.assert_is_compatible_with(x_ref_max.shape[-1:],
                                                   x_ref_min.shape[-1:])

        # Convert axis and check it statically.
        axis = tf.convert_to_tensor(axis, dtype=tf.int32, name='axis')
        axis = prefer_static.non_negative_axis(axis, tf.rank(y_ref))
        tensorshape_util.assert_has_rank(axis.shape, 0)
        axis_ = tf.get_static_value(axis)
        y_ref_rank_ = tf.get_static_value(tf.rank(y_ref))
        if axis_ is not None and y_ref_rank_ is not None:
            if axis_ + nd > y_ref_rank_:
                raise ValueError(
                    'Since dims `[axis, axis + nd)` index the interpolation table, we '
                    'must have `axis + nd <= rank(y_ref)`.  Found: '
                    '`axis`: {},  rank(y_ref): {}, and inferred `nd` from trailing '
                    'dimensions of `x_ref_min` to be {}.'.format(
                        axis_, y_ref_rank_, nd))

        x_batch_shape = tf.shape(x)[:-2]
        x_ref_min_batch_shape = tf.shape(x_ref_min)[:-1]
        x_ref_max_batch_shape = tf.shape(x_ref_max)[:-1]
        y_ref_batch_shape = tf.shape(y_ref)[:axis]

        # Do a brute-force broadcast of batch dims (add zeros).
        batch_shape = y_ref_batch_shape
        for tensor in [
                x_batch_shape, x_ref_min_batch_shape, x_ref_max_batch_shape
        ]:
            batch_shape = tf.broadcast_dynamic_shape(batch_shape, tensor)

        def _batch_of_zeros_with_rightmost_singletons(n_singletons):
            """Return Tensor of zeros with some singletons on the rightmost dims."""
            ones = tf.ones(shape=[n_singletons], dtype=tf.int32)
            return tf.zeros(shape=tf.concat([batch_shape, ones], axis=0),
                            dtype=dtype)

        x += _batch_of_zeros_with_rightmost_singletons(n_singletons=2)
        x_ref_min += _batch_of_zeros_with_rightmost_singletons(n_singletons=1)
        x_ref_max += _batch_of_zeros_with_rightmost_singletons(n_singletons=1)
        y_ref += _batch_of_zeros_with_rightmost_singletons(
            n_singletons=tf.rank(y_ref) - axis)

        return _batch_interp_with_gather_nd(
            x=x,
            x_ref_min=x_ref_min,
            x_ref_max=x_ref_max,
            y_ref=y_ref,
            nd=nd,
            fill_value=fill_value,
            batch_dims=tf.get_static_value(tf.rank(x)) - 2)
Example #25
    def __init__(self,
                 target_log_prob_fn,
                 step_size,
                 max_tree_depth=10,
                 max_energy_diff=1000.,
                 unrolled_leapfrog_steps=1,
                 seed=None,
                 name=None):
        """Initializes this transition kernel.

    Args:
      target_log_prob_fn: Python callable which takes an argument like
        `current_state` (or `*current_state` if it's a list) and returns its
        (possibly unnormalized) log-density under the target distribution.
      step_size: `Tensor` or Python `list` of `Tensor`s representing the step
        size for the leapfrog integrator. Must broadcast with the shape of
        `current_state`. Larger step sizes lead to faster progress, but
        too-large step sizes make rejection exponentially more likely. When
        possible, it's often helpful to match per-variable step sizes to the
        standard deviations of the target distribution in each variable.
      max_tree_depth: Maximum depth of the tree implicitly built by NUTS. The
        maximum number of leapfrog steps is bounded by `2**max_tree_depth` i.e.
        the number of nodes in a binary tree `max_tree_depth` nodes deep. The
        default setting of 10 takes up to 1024 leapfrog steps.
      max_energy_diff: Scalar threshold of energy differences at each leapfrog;
        divergence samples are defined as leapfrog steps that exceed this
        threshold. Defaults to `1000.`.
      unrolled_leapfrog_steps: The number of leapfrogs to unroll per tree
        expansion step. Applies a direct linear multiplier to the maximum
        trajectory length implied by max_tree_depth. Defaults to 1.
      seed: Python integer to seed the random number generator.
      name: Python `str` name prefixed to Ops created by this function.
        Default value: `None` (i.e., 'nuts_kernel').
    """
        with tf.name_scope(name or 'NoUTurnSampler') as name:
            # Process `max_tree_depth` argument.
            max_tree_depth = tf.get_static_value(max_tree_depth)
            if max_tree_depth is None or max_tree_depth < 1:
                raise ValueError(
                    'max_tree_depth must be known statically and >= 1 but was '
                    '{}'.format(max_tree_depth))
            self._max_tree_depth = max_tree_depth

            # Compute parameters derived from `max_tree_depth`.
            instruction_array = build_tree_uturn_instruction(max_tree_depth,
                                                             init_memory=-1)
            [write_instruction_numpy, read_instruction_numpy
             ] = generate_efficient_write_read_instruction(instruction_array)

            # TensorArray version of the read/write instruction need to be created
            # within the function call to be compatible with XLA. Here we store the
            # numpy version of the instruction and convert it to TensorArray later.
            self._write_instruction = write_instruction_numpy
            self._read_instruction = read_instruction_numpy

            # Process all other arguments.
            self._target_log_prob_fn = target_log_prob_fn
            if not tf.nest.is_nested(step_size):
                step_size = [step_size]
            step_size = [
                tf.convert_to_tensor(s, dtype_hint=tf.float32)
                for s in step_size
            ]
            self._step_size = step_size

            self._parameters = dict(
                target_log_prob_fn=target_log_prob_fn,
                step_size=step_size,
                max_tree_depth=max_tree_depth,
                max_energy_diff=max_energy_diff,
                unrolled_leapfrog_steps=unrolled_leapfrog_steps,
                seed=seed,
                name=name,
            )
            self._seed_stream = SeedStream(seed, salt='nuts_one_step')
            self._unrolled_leapfrog_steps = unrolled_leapfrog_steps
            self._name = name
            self._max_energy_diff = max_energy_diff
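
A minimal usage sketch for a kernel with this constructor, assuming it is TensorFlow Probability's `NoUTurnSampler` and that `tfp.mcmc.sample_chain` is available in the installed version; the target distribution and numbers are illustrative:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Illustrative target: a standard normal log-density.
target = tfd.Normal(loc=0., scale=1.)

kernel = tfp.mcmc.NoUTurnSampler(
    target_log_prob_fn=target.log_prob,
    step_size=0.5,
    max_tree_depth=8)

# Draw 200 samples after 100 burn-in steps; trace_fn=None returns states only.
samples = tfp.mcmc.sample_chain(
    num_results=200,
    num_burnin_steps=100,
    current_state=tf.zeros([]),
    kernel=kernel,
    trace_fn=None)
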
Example #26
    def __init__(self,
                 dim: int,
                 mean_reversion: types.RealTensor,
                 volatility: Union[types.RealTensor,
                                   Callable[..., types.RealTensor]],
                 initial_discount_rate_fn,
                 corr_matrix: types.RealTensor = None,
                 dtype: tf.DType = None,
                 name: str = None):
        """Initializes the HJM model.

    Args:
      dim: A Python scalar which corresponds to the number of factors comprising
        the model.
      mean_reversion: A real positive `Tensor` of shape `[dim]`. Corresponds to
        the mean reversion rate of each factor.
      volatility: A real positive `Tensor` of the same `dtype` and shape as
        `mean_reversion`, or a callable with the following properties: (a) The
          callable accepts a scalar `Tensor` `t` and returns a 1-D `Tensor` of
          shape `[dim]` giving the instantaneous volatility `sigma(t)`. When
          `volatility` is specified as a real `Tensor`, each factor is assumed
          to have a constant instantaneous volatility. Corresponds to the
          instantaneous volatility of each factor.
      initial_discount_rate_fn: A Python callable that accepts expiry time as a
        real `Tensor` of the same `dtype` as `mean_reversion` and returns a
        `Tensor` of shape `input_shape`. Corresponds to the zero coupon bond
        yield at the present time for the input expiry time.
      corr_matrix: A `Tensor` of shape `[dim, dim]` and the same `dtype` as
        `mean_reversion`. Corresponds to the correlation matrix `Rho`.
      dtype: The default dtype to use when converting values to `Tensor`s.
        Default value: `None` which maps to `tf.float32`.
      name: Python string. The name to give to the ops created by this class.
        Default value: `None` which maps to the default name
          `gaussian_hjm_model`.
    """
        self._name = name or 'gaussian_hjm_model'
        with tf.name_scope(self._name):
            self._dtype = dtype or tf.float32
            self._dim = dim
            self._factors = dim

            def _instant_forward_rate_fn(t):
                t = tf.convert_to_tensor(t, dtype=self._dtype)

                def _log_zero_coupon_bond(x):
                    r = tf.convert_to_tensor(initial_discount_rate_fn(x),
                                             dtype=self._dtype)
                    return -r * x

                rate = -gradient.fwd_gradient(
                    _log_zero_coupon_bond,
                    t,
                    use_gradient_tape=True,
                    unconnected_gradients=tf.UnconnectedGradients.ZERO)
                return rate

            def _initial_discount_rate_fn(t):
                return tf.convert_to_tensor(initial_discount_rate_fn(t),
                                            dtype=self._dtype)

            self._instant_forward_rate_fn = _instant_forward_rate_fn
            self._initial_discount_rate_fn = _initial_discount_rate_fn
            self._mean_reversion = tf.convert_to_tensor(mean_reversion,
                                                        dtype=dtype,
                                                        name='mean_reversion')

            self._batch_shape = []
            self._batch_rank = 0

            # Setup volatility
            if callable(volatility):
                self._volatility = volatility
            else:
                volatility = tf.convert_to_tensor(volatility, dtype=dtype)
                jump_locations = [[]] * dim
                volatility = tf.expand_dims(volatility, axis=-1)
                self._volatility = piecewise.PiecewiseConstantFunc(
                    jump_locations=jump_locations,
                    values=volatility,
                    dtype=dtype)

            if corr_matrix is None:
                corr_matrix = tf.eye(dim, dim, dtype=self._dtype)
            self._rho = tf.convert_to_tensor(corr_matrix,
                                             dtype=dtype,
                                             name='rho')
            self._sqrt_rho = tf.linalg.cholesky(self._rho)

            # Volatility function
            def _vol_fn(t, state):
                """Volatility function of Gaussian-HJM."""
                del state
                volatility = self._volatility(tf.expand_dims(
                    t, -1))  # shape=(dim, 1)

                return self._sqrt_rho * volatility

            # Drift function
            def _drift_fn(t, state):
                """Drift function of Gaussian-HJM."""
                x = state
                # shape = [self._factors, self._factors]
                y = self.state_y(tf.expand_dims(t, axis=-1))[..., 0]
                drift = tf.math.reduce_sum(y,
                                           axis=-1) - self._mean_reversion * x
                return drift

            self._exact_discretization_setup(dim)
            super(quasi_gaussian_hjm.QuasiGaussianHJM,
                  self).__init__(dim, _drift_fn, _vol_fn, self._dtype,
                                 self._name)
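
A construction sketch for a model with this signature, assuming it is the Gaussian HJM model from `tf_quant_finance`; the module path and the flat discount curve below are assumptions for illustration only:

import tensorflow as tf
import tf_quant_finance as tff

dtype = tf.float64

def discount_rate_fn(t):
  # Flat 1% zero-coupon yield curve, purely illustrative.
  return 0.01 * tf.ones_like(t, dtype=dtype)

# Two-factor model with constant per-factor volatilities.
model = tff.models.hjm.GaussianHJM(
    dim=2,
    mean_reversion=[0.03, 0.06],
    volatility=[0.01, 0.015],
    initial_discount_rate_fn=discount_rate_fn,
    dtype=dtype)
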
Example #27
    def loop_tree_doubling(self, step_size, momentum_state_memory,
                           current_step_meta_info, iter_, initial_step_state,
                           initial_step_metastate):
        """Main loop for tree doubling."""
        with tf.name_scope('loop_tree_doubling'):
            batch_shape = prefer_static.shape(
                current_step_meta_info.init_energy)
            direction = tf.cast(tf.random.uniform(shape=batch_shape,
                                                  minval=0,
                                                  maxval=2,
                                                  dtype=tf.int32,
                                                  seed=self._seed_stream()),
                                dtype=tf.bool)

            tree_start_states = tf.nest.map_structure(
                lambda v: tf.where(  # pylint: disable=g-long-lambda
                    _rightmost_expand_to_rank(
                        direction, prefer_static.rank(v[1])), v[1], v[0]),
                initial_step_state)

            directions_expanded = [
                _rightmost_expand_to_rank(direction, prefer_static.rank(state))
                for state in tree_start_states.state
            ]

            integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
                self.target_log_prob_fn,
                step_sizes=[
                    tf.where(d, ss, -ss)
                    for d, ss in zip(directions_expanded, step_size)
                ],
                num_steps=self.unrolled_leapfrog_steps)

            [
                candidate_tree_state, tree_final_states, final_not_divergence,
                continue_tree_final, energy_diff_tree_sum,
                momentum_tree_cumsum, leapfrogs_taken
            ] = self._build_sub_tree(
                directions_expanded,
                integrator,
                current_step_meta_info,
                # num_steps_at_this_depth = 2**iter_ = 1 << iter_
                tf.bitwise.left_shift(1, iter_),
                tree_start_states,
                initial_step_metastate.continue_tree,
                initial_step_metastate.not_divergence,
                momentum_state_memory)

            last_candidate_state = initial_step_metastate.candidate_state
            tree_weight = candidate_tree_state.weight
            if MULTINOMIAL_SAMPLE:
                weight_sum = log_add_exp(tree_weight,
                                         last_candidate_state.weight)
                log_accept_thresh = tree_weight - last_candidate_state.weight
            else:
                weight_sum = tree_weight + last_candidate_state.weight
                log_accept_thresh = tf.math.log(
                    tf.cast(tree_weight, tf.float32) /
                    tf.cast(last_candidate_state.weight, tf.float32))
            log_accept_thresh = tf.where(tf.math.is_nan(log_accept_thresh),
                                         tf.zeros([], log_accept_thresh.dtype),
                                         log_accept_thresh)
            u = tf.math.log1p(-tf.random.uniform(shape=batch_shape,
                                                 dtype=log_accept_thresh.dtype,
                                                 seed=self._seed_stream()))
            is_sample_accepted = u <= log_accept_thresh

            choose_new_state = is_sample_accepted & continue_tree_final

            new_candidate_state = TreeDoublingStateCandidate(
                state=[
                    tf.where(  # pylint: disable=g-complex-comprehension
                        _rightmost_expand_to_rank(choose_new_state,
                                                  prefer_static.rank(s0)), s0,
                        s1) for s0, s1 in zip(candidate_tree_state.state,
                                              last_candidate_state.state)
                ],
                target=tf.where(
                    _rightmost_expand_to_rank(
                        choose_new_state,
                        prefer_static.rank(candidate_tree_state.target)),
                    candidate_tree_state.target, last_candidate_state.target),
                target_grad_parts=[
                    tf.where(  # pylint: disable=g-complex-comprehension
                        _rightmost_expand_to_rank(choose_new_state,
                                                  prefer_static.rank(grad0)),
                        grad0, grad1) for grad0, grad1 in zip(
                            candidate_tree_state.target_grad_parts,
                            last_candidate_state.target_grad_parts)
                ],
                energy=tf.where(
                    _rightmost_expand_to_rank(
                        choose_new_state,
                        prefer_static.rank(candidate_tree_state.target)),
                    candidate_tree_state.energy, last_candidate_state.energy),
                weight=weight_sum)

            # Update the left/right endpoint information of the trajectory and
            # check for a trajectory-level U-turn.
            tree_otherend_states = tf.nest.map_structure(
                lambda v: tf.where(  # pylint: disable=g-long-lambda
                    _rightmost_expand_to_rank(
                        direction, prefer_static.rank(v[1])), v[0], v[1]),
                initial_step_state)

            new_step_state = tf.nest.pack_sequence_as(
                initial_step_state,
                [
                    tf.stack(
                        [  # pylint: disable=g-complex-comprehension
                            tf.where(
                                _rightmost_expand_to_rank(
                                    direction, prefer_static.rank(l)), r, l),
                            tf.where(
                                _rightmost_expand_to_rank(
                                    direction, prefer_static.rank(l)), l, r),
                        ],
                        axis=0)
                    for l, r in zip(tf.nest.flatten(tree_final_states),
                                    tf.nest.flatten(tree_otherend_states))
                ])

            if GENERALIZED_UTURN:
                state_diff = momentum_tree_cumsum
            else:
                state_diff = [s[1] - s[0] for s in new_step_state.state]

            no_u_turns_trajectory = has_not_u_turn(
                state_diff, [m[0] for m in new_step_state.momentum],
                [m[1] for m in new_step_state.momentum],
                log_prob_rank=len(batch_shape))

            new_step_metastate = TreeDoublingMetaState(
                candidate_state=new_candidate_state,
                is_accepted=choose_new_state
                | initial_step_metastate.is_accepted,
                energy_diff_sum=(energy_diff_tree_sum +
                                 initial_step_metastate.energy_diff_sum),
                continue_tree=continue_tree_final & no_u_turns_trajectory,
                not_divergence=final_not_divergence,
                leapfrog_count=(initial_step_metastate.leapfrog_count +
                                leapfrogs_taken))

            return iter_ + 1, new_step_state, new_step_metastate
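
The `tf.where(_rightmost_expand_to_rank(...), ...)` pattern used throughout this function selects, per chain, between two candidate tensors by padding a batch-shaped boolean with singleton dims so it broadcasts over the event dims. A standalone sketch of the same idea; the helper below is a re-implementation for illustration, not the library's:

import tensorflow as tf

def rightmost_expand_to_rank(mask, rank):
  # Append singleton dims until `mask` can broadcast against a rank-`rank` tensor.
  pad = tf.ones([rank - tf.rank(mask)], dtype=tf.int32)
  return tf.reshape(mask, tf.concat([tf.shape(mask), pad], axis=0))

choose_new = tf.constant([True, False])  # per-chain decision, shape [num_chains]
new_state = tf.ones([2, 3])              # [num_chains, event_dim]
old_state = tf.zeros([2, 3])

selected = tf.where(rightmost_expand_to_rank(choose_new, tf.rank(new_state)),
                    new_state, old_state)
# selected == [[1., 1., 1.], [0., 0., 0.]]
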
Example #28
    def sample_paths(self,
                     times: types.RealTensor,
                     num_samples: types.IntTensor,
                     time_step: types.RealTensor = None,
                     num_time_steps: types.IntTensor = None,
                     random_type: random.RandomType = None,
                     seed: types.IntTensor = None,
                     skip: types.IntTensor = 0,
                     name: str = None) -> types.RealTensor:
        """Returns a sample of short rate paths from the HJM process.

    Uses Euler sampling for simulating the short rate paths.

    Args:
      times: A real positive `Tensor` of shape `(num_times,)`. The times at
        which the path points are to be evaluated.
      num_samples: Positive scalar `int32` `Tensor`. The number of paths to
        draw.
      time_step: Scalar real `Tensor`. Maximal distance between time grid points
        in Euler scheme. Used only when Euler scheme is applied.
        Default value: `None`.
      num_time_steps: An optional Scalar integer `Tensor` - a total number of
        time steps performed by the algorithm. The maximal distance between
        points in grid is bounded by
        `times[-1] / (num_time_steps - times.shape[0])`.
        Either this or `time_step` should be supplied.
        Default value: `None`.
      random_type: Enum value of `RandomType`. The type of (quasi)-random
        number generator to use to generate the paths.
        Default value: `None` which maps to the standard pseudo-random numbers.
      seed: Seed for the random number generator. The seed is
        only relevant if `random_type` is one of
        `[STATELESS, PSEUDO, HALTON_RANDOMIZED, PSEUDO_ANTITHETIC,
          STATELESS_ANTITHETIC]`. For `PSEUDO`, `PSEUDO_ANTITHETIC` and
        `HALTON_RANDOMIZED` the seed should be a Python integer. For
        `STATELESS` and `STATELESS_ANTITHETIC` it must be supplied as an
        integer `Tensor` of shape `[2]`.
        Default value: `None` which means no seed is set.
      skip: `int32` 0-d `Tensor`. The number of initial points of the Sobol or
        Halton sequence to skip. Used only when `random_type` is 'SOBOL',
        'HALTON', or 'HALTON_RANDOMIZED', otherwise ignored.
        Default value: `0`.
      name: Python string. The name to give this op.
        Default value: `sample_paths`.

    Returns:
      A tuple containing four elements.

      * The first element is a `Tensor` of
      shape `[num_samples, num_times]` containing the simulated short rate
      paths.
      * The second element is a `Tensor` of shape
      `[num_samples, num_times]` containing the simulated discount factor
      paths.
      * The third element is a `Tensor` of shape
      `[num_samples, num_times, dim]` containing the simulated values of the
      state variable `x`.
      * The fourth element is a `Tensor` of shape
      `[num_samples, num_times, dim^2]` containing the simulated values of the
      state variable `y`.

    Raises:
      ValueError:
        (a) If `times` has rank different from `1`.
        (b) If the Euler scheme is used but `time_step` is not supplied.
    """
        name = name or self._name + '_sample_path'
        with tf.name_scope(name):
            times = tf.convert_to_tensor(times, self._dtype)
            if times.shape.rank != 1:
                raise ValueError('`times` should be a rank 1 Tensor. '
                                 'Rank is {} instead.'.format(
                                     times.shape.rank))
            return self._sample_paths(times, time_step, num_time_steps,
                                      num_samples, random_type, skip, seed)
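
A usage sketch for this method on an HJM model like the one sketched after Example #26 (module path, curve, and parameter values remain illustrative assumptions):

import tensorflow as tf
import tf_quant_finance as tff

dtype = tf.float64
model = tff.models.hjm.GaussianHJM(
    dim=2,
    mean_reversion=[0.03, 0.06],
    volatility=[0.01, 0.015],
    initial_discount_rate_fn=lambda t: 0.01 * tf.ones_like(t, dtype=dtype),
    dtype=dtype)

times = tf.constant([0.5, 1.0, 2.0], dtype=dtype)
rate_paths, discount_paths, x_paths, y_paths = model.sample_paths(
    times=times,
    num_samples=10000,
    time_step=0.1,
    seed=42)
# rate_paths: [10000, 3]; discount_paths: [10000, 3];
# x_paths: [10000, 3, 2]; y_paths: [10000, 3, 4]
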
Example #29
    def _loop_build_sub_tree(self, directions, integrator,
                             current_step_meta_info, iter_,
                             energy_diff_sum_previous,
                             momentum_cumsum_previous, leapfrogs_taken,
                             prev_tree_state, candidate_tree_state,
                             continue_tree_previous, not_divergent_previous,
                             momentum_state_memory):
        """Base case in tree doubling."""
        with tf.name_scope('loop_build_sub_tree'):
            # Take one leapfrog step in the direction v and check divergence
            [
                next_momentum_parts, next_state_parts, next_target,
                next_target_grad_parts
            ] = integrator(prev_tree_state.momentum, prev_tree_state.state,
                           prev_tree_state.target,
                           prev_tree_state.target_grad_parts)

            next_tree_state = TreeDoublingState(
                momentum=next_momentum_parts,
                state=next_state_parts,
                target=next_target,
                target_grad_parts=next_target_grad_parts)
            momentum_cumsum = [
                p0 + p1 for p0, p1 in zip(momentum_cumsum_previous,
                                          next_momentum_parts)
            ]
            # If the tree has not yet terminated previously, we count this leapfrog.
            leapfrogs_taken = tf.where(continue_tree_previous,
                                       leapfrogs_taken + 1, leapfrogs_taken)

            write_instruction = current_step_meta_info.write_instruction
            read_instruction = current_step_meta_info.read_instruction
            init_energy = current_step_meta_info.init_energy

            # Save state and momentum at odd steps; check for a U-turn at even steps.
            # Note that at even steps we also write to a placeholder index.
            write_index = tf.where(tf.equal(iter_ % 2, 0),
                                   write_instruction.gather([iter_ // 2]),
                                   self.max_tree_depth)

            if GENERALIZED_UTURN:
                state_to_write = momentum_cumsum
            else:
                state_to_write = next_state_parts

            momentum_state_memory = MomentumStateSwap(
                momentum_swap=[
                    tf.tensor_scatter_nd_update(old, [write_index], [new])
                    for old, new in zip(momentum_state_memory.momentum_swap,
                                        next_momentum_parts)
                ],
                state_swap=[
                    tf.tensor_scatter_nd_update(old, [write_index], [new])
                    for old, new in zip(momentum_state_memory.state_swap,
                                        state_to_write)
                ])
            batch_shape = prefer_static.shape(next_target)
            has_not_u_turn_at_even_step = tf.ones(batch_shape, dtype=tf.bool)

            read_index = read_instruction.gather([iter_ // 2])[0]
            no_u_turns_within_tree = tf.cond(
                tf.equal(iter_ % 2, 0),
                lambda: has_not_u_turn_at_even_step,
                lambda: has_not_u_turn_at_odd_step(  # pylint: disable=g-long-lambda
                    read_index,
                    directions,
                    momentum_state_memory,
                    next_momentum_parts,
                    state_to_write,
                    has_not_u_turn_at_even_step,
                    log_prob_rank=prefer_static.rank(next_target)))

            energy = compute_hamiltonian(next_target, next_momentum_parts)
            current_energy = tf.where(tf.math.is_nan(energy),
                                      tf.constant(-np.inf, dtype=energy.dtype),
                                      energy)
            energy_diff = current_energy - init_energy

            if MULTINOMIAL_SAMPLE:
                not_divergent = -energy_diff < self.max_energy_diff
                weight_sum = log_add_exp(candidate_tree_state.weight,
                                         energy_diff)
                log_accept_thresh = energy_diff - weight_sum
            else:
                log_slice_sample = current_step_meta_info.log_slice_sample
                not_divergent = log_slice_sample - energy_diff < self.max_energy_diff
                # Uniform sampling on the trajectory within the subtree across valid
                # samples.
                is_valid = log_slice_sample <= energy_diff
                weight_sum = tf.where(is_valid,
                                      candidate_tree_state.weight + 1,
                                      candidate_tree_state.weight)
                log_accept_thresh = tf.where(
                    is_valid,
                    -tf.math.log(tf.cast(weight_sum, dtype=tf.float32)),
                    tf.constant(-np.inf, dtype=tf.float32))
            u = tf.math.log1p(-tf.random.uniform(shape=batch_shape,
                                                 dtype=log_accept_thresh.dtype,
                                                 seed=self._seed_stream()))
            is_sample_accepted = u <= log_accept_thresh

            next_candidate_tree_state = TreeDoublingStateCandidate(
                state=[
                    tf.where(  # pylint: disable=g-complex-comprehension
                        _rightmost_expand_to_rank(is_sample_accepted,
                                                  prefer_static.rank(s0)), s0,
                        s1) for s0, s1 in zip(next_state_parts,
                                              candidate_tree_state.state)
                ],
                target=tf.where(
                    _rightmost_expand_to_rank(is_sample_accepted,
                                              prefer_static.rank(next_target)),
                    next_target, candidate_tree_state.target),
                target_grad_parts=[
                    tf.where(  # pylint: disable=g-complex-comprehension
                        _rightmost_expand_to_rank(is_sample_accepted,
                                                  prefer_static.rank(grad0)),
                        grad0, grad1) for grad0, grad1 in zip(
                            next_target_grad_parts,
                            candidate_tree_state.target_grad_parts)
                ],
                energy=tf.where(
                    _rightmost_expand_to_rank(is_sample_accepted,
                                              prefer_static.rank(next_target)),
                    current_energy, init_energy),
                weight=weight_sum)

            continue_tree = not_divergent & continue_tree_previous
            continue_tree_next = no_u_turns_within_tree & continue_tree

            not_divergent_tokeep = tf.where(
                continue_tree_previous, not_divergent,
                tf.ones(batch_shape, dtype=tf.bool))

            # min(1., exp(energy_diff)).
            exp_energy_diff = tf.clip_by_value(tf.exp(energy_diff), 0., 1.)
            energy_diff_sum = tf.where(
                continue_tree, energy_diff_sum_previous + exp_energy_diff,
                energy_diff_sum_previous)

            return (
                iter_ + 1,
                energy_diff_sum,
                momentum_cumsum,
                leapfrogs_taken,
                next_tree_state,
                next_candidate_tree_state,
                continue_tree_next,
                not_divergent_previous & not_divergent_tokeep,
                momentum_state_memory,
            )
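
Both branches above accept a candidate with probability `min(1, exp(log_accept_thresh))` by drawing `u = log1p(-uniform())` and comparing in log space, and the multinomial branch accumulates weights with `log_add_exp`. A standalone sketch of those two pieces; the `log_add_exp` here is a generic re-derivation, not necessarily the library's implementation:

import tensorflow as tf

def log_add_exp(a, b):
  # Numerically stable log(exp(a) + exp(b)).
  larger = tf.maximum(a, b)
  return larger + tf.math.log1p(tf.exp(-tf.abs(a - b)))

log_accept_thresh = tf.constant([-0.1, -3.0])
# log1p(-U) is the log of a Uniform(0, 1) draw, so the comparison below accepts
# each element with probability min(1, exp(log_accept_thresh)).
u = tf.math.log1p(-tf.random.uniform(shape=[2], dtype=log_accept_thresh.dtype, seed=7))
is_sample_accepted = u <= log_accept_thresh
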
Example #30
  def __init__(self,
               df,
               kernel,
               index_points,
               mean_fn=None,
               jitter=1e-6,
               validate_args=False,
               allow_nan_stats=False,
               name='StudentTProcess'):
    """Instantiate a StudentTProcess Distribution.

    Args:
      df: Positive Floating-point `Tensor` representing the degrees of freedom.
        Must be greater than 2.
      kernel: `PositiveSemidefiniteKernel`-like instance representing the
        TP's covariance function.
      index_points: `float` `Tensor` representing finite (batch of) vector(s) of
        points in the index set over which the TP is defined. Shape has the form
        `[b1, ..., bB, e, f1, ..., fF]` where `F` is the number of feature
        dimensions and must equal `kernel.feature_ndims` and `e` is the number
        (size) of index points in each batch. Ultimately this distribution
        corresponds to a `e`-dimensional multivariate Student's T. The batch
        shape must be broadcastable with `kernel.batch_shape` and any batch dims
        yielded by `mean_fn`.
      mean_fn: Python `callable` that acts on `index_points` to produce a (batch
        of) vector(s) of mean values at `index_points`. Takes a `Tensor` of
        shape `[b1, ..., bB, f1, ..., fF]` and returns a `Tensor` whose shape is
        broadcastable with `[b1, ..., bB]`. Default value: `None` implies
        constant zero function.
      jitter: `float` scalar `Tensor` added to the diagonal of the covariance
        matrix to ensure positive definiteness of the covariance matrix.
        Default value: `1e-6`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
        Default value: `False`.
      allow_nan_stats: Python `bool`, default `False`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
        Default value: `False`.
      name: Python `str` name prefixed to Ops created by this class.
        Default value: "StudentTProcess".

    Raises:
      ValueError: if `mean_fn` is not `None` and is not callable.
    """
    parameters = dict(locals())
    with tf.name_scope(name) as name:
      dtype = dtype_util.common_dtype(
          [df, index_points, jitter], tf.float32)
      df = tf.convert_to_tensor(df, dtype=dtype, name='df')
      index_points = tf.convert_to_tensor(
          index_points, dtype=dtype, name='index_points')
      jitter = tf.convert_to_tensor(jitter, dtype=dtype, name='jitter')

      with tf.control_dependencies([
          assert_util.assert_greater(
              df, tf.cast(2., df.dtype), message='`df` must be greater than 2.')
      ] if validate_args else []):
        self._df = tf.identity(df)

      self._kernel = kernel
      self._index_points = index_points
      # Default to a constant zero function, borrowing the dtype from
      # index_points to ensure consistency.
      if mean_fn is None:
        mean_fn = lambda x: tf.zeros([1], dtype=dtype)
      else:
        if not callable(mean_fn):
          raise ValueError('`mean_fn` must be a Python callable')
      self._mean_fn = mean_fn
      self._jitter = jitter

      with tf.name_scope('init'):
        kernel_matrix = _add_diagonal_shift(
            kernel.matrix(self.index_points, self.index_points),
            jitter)
        self._covariance_matrix = kernel_matrix

        scale = tf.linalg.LinearOperatorLowerTriangular(
            tf.linalg.cholesky(
                ((self.df - 2) / self.df)[..., tf.newaxis, tf.newaxis] *
                kernel_matrix),
            is_non_singular=True,
            name='StudentTProcessScaleLinearOperator')

        super(StudentTProcess, self).__init__(
            df=df,
            loc=mean_fn(index_points),
            scale=scale,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name=name)
        self._parameters = parameters
        self._graph_parents = [index_points, jitter]
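
A construction sketch for this distribution, assuming it is TensorFlow Probability's `StudentTProcess` paired with an `ExponentiatedQuadratic` kernel; the index points and parameter values are illustrative:

import numpy as np
import tensorflow_probability as tfp

tfd = tfp.distributions
psd_kernels = tfp.math.psd_kernels

# 20 one-dimensional index points in [-1, 1] (float64 throughout).
index_points = np.linspace(-1., 1., 20, dtype=np.float64)[..., np.newaxis]
kernel = psd_kernels.ExponentiatedQuadratic(
    amplitude=np.float64(1.), length_scale=np.float64(0.5))

tp = tfd.StudentTProcess(
    df=np.float64(5.),  # must be greater than 2
    kernel=kernel,
    index_points=index_points,
    jitter=1e-6)

samples = tp.sample(3)  # shape [3, 20]
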