Exemple #1
0
def _piecewise_constant_function(x,
                                 jump_locations,
                                 values,
                                 batch_rank,
                                 side='left'):
    """Evaluates a piecewise constant function at the points `x`.

    Args:
      x: Points at which to evaluate the function. Its leading `batch_rank`
        dimensions must equal the batch shape of `jump_locations`.
      jump_locations: Locations where the function changes value, presumably
        sorted along the last axis (required by `tf.searchsorted`).
      values: Function values on the segments delimited by `jump_locations`.
      batch_rank: Python int, the number of leading batch dimensions of `x`.
      side: `'left'` or `'right'`; forwarded to `tf.searchsorted`.

    Returns:
      The function values gathered at `x`.

    Raises:
      ValueError: If the batch shape of `x` differs from the batch shape of
        `jump_locations`.
    """
    # The initializer has already verified that `jump_locations` and `values`
    # share a shape, so the batch shape can be read from either of them.
    expected_batch_shape = jump_locations.shape.as_list()[:-1]
    actual_batch_shape = x.shape.as_list()[:batch_rank]
    if actual_batch_shape != expected_batch_shape:
        raise ValueError('Batch shape of `x` is {1} but should be {0}'.format(
            expected_batch_shape, actual_batch_shape))

    # Unbatched inputs get a unit batch dimension so that the gather below
    # always runs on batched tensors; it is squeezed away again at the end.
    add_unit_batch = not actual_batch_shape
    if add_unit_batch:
        x = tf.expand_dims(x, 0)
    if not expected_batch_shape:
        jump_locations = tf.expand_dims(jump_locations, 0)
        values = tf.expand_dims(values, 0)

    positions = tf.searchsorted(jump_locations, x, side=side)
    positions_shape = positions.shape.as_list()
    index_matrix = _prepare_index_matrix(positions_shape[:-1],
                                         positions_shape[-1],
                                         positions.dtype)
    gather_indices = tf.concat(
        [index_matrix, tf.expand_dims(positions, -1)], -1)
    result = tf.gather_nd(values, gather_indices)
    return tf.squeeze(result, 0) if add_unit_batch else result
 def _inverse(self, y):
   """Maps each element of `y` back to an index into `map_values`.

   `map_values` is strictly increasing. For each (flattened) element of `y`,
   only two indices can be closest: the first index whose value is strictly
   greater than the element, and the index just before it. The candidate
   whose `_forward` image is nearer to the element is returned.

   Args:
     y: `Tensor` of values to invert.

   Returns:
     An integer `Tensor` with the shape of `y` holding the selected indices.
   """
   map_values = tf.convert_to_tensor(self.map_values)
   flat_y = tf.reshape(y, shape=[-1])
   last_index = tf.size(map_values) - 1
   upper = tf.minimum(
       last_index, tf.searchsorted(map_values, values=flat_y, side='right'))
   lower = tf.maximum(0, upper - 1)
   candidates = tf.stack([lower, upper], axis=-1)
   lower_gap = tf.abs(flat_y - self._forward(lower))
   upper_gap = tf.abs(flat_y - self._forward(upper))
   if self.validate_args:
     # When validating, require that one of the candidates maps back onto `y`.
     nearest_gap = tf.minimum(lower_gap, upper_gap)
     with tf.control_dependencies([
         assert_util.assert_near(
             nearest_gap, 0, message='inverse value not found')
     ]):
       candidates = tf.identity(candidates)
   # Pair each flat position with column 0 (lower) or 1 (upper), whichever
   # candidate is nearer.
   selector = tf.stack([
       tf.range(tf.size(flat_y), dtype=tf.int32),
       tf.argmin([lower_gap, upper_gap], output_type=tf.int32),
   ], axis=-1)
   return tf.reshape(tf.gather_nd(candidates, selector), shape=y.shape)
  def _conditional_variance_x(self, t, mr_t, sigma_t):
    """Computes the variance of x(t), see [1], Eq. 10.41.

    Args:
      t: Times; broadcast to shape `[dim] + t.shape` below.
      mr_t: Mean-reversion values at `t` (broadcastable to the broadcast `t`).
      sigma_t: Volatility values at `t`.

    Returns:
      A `Tensor` with one fewer entry than `t` along axis 1. It holds the
      differences of the accumulated variance integral between consecutive
      times, scaled by `exp(-2 * mr_t * t)` evaluated at the later time.
    """
    # Shape [dim, num_times]
    t = tf.broadcast_to(t, tf.concat([[self._dim], tf.shape(t)], axis=-1))

    # Variance integral over each interval between volatility knots,
    # accumulated so that entry k holds the integral up to the k-th knot.
    per_interval_variance = self._variance_int(self._padded_knots,
                                               self._jump_locations,
                                               self._jump_values_vol,
                                               self._jump_values_mr)
    variance_at_knots = tf.concat(
        [self._zero_padding,
         utils.cumsum_using_matvec(per_interval_variance)],
        axis=1)

    # For every time, pick its preceding knot from
    # `[zero_padding, jump_locations]` and integrate from there up to the time.
    knot_index = tf.searchsorted(self._jump_locations, t)
    knots_with_origin = tf.concat(
        [self._zero_padding, self._jump_locations], axis=1)
    preceding_knot = tf.gather(knots_with_origin, knot_index, batch_dims=1)
    accumulated = (
        self._variance_int(preceding_knot, t, sigma_t, mr_t) +
        tf.gather(variance_at_knots, knot_index, batch_dims=1))

    decay = tf.math.exp(
        -2 * tf.broadcast_to(mr_t, tf.shape(t))[:, 1:] * t[:, 1:])
    return (accumulated[:, 1:] - accumulated[:, :-1]) * decay
  def _conditional_mean_x(self, t, mr_t, sigma_t):
    """Computes the drift term in [1], Eq. 10.39.

    Args:
      t: Times; broadcast to shape `[dim] + t.shape` below.
      mr_t: Mean-reversion values at `t` (broadcastable to the broadcast `t`).
      sigma_t: Volatility values at `t`.

    Returns:
      A `Tensor` with one fewer entry than `t` along axis 1. It holds the
      differences of the accumulated drift integral between consecutive
      times, scaled by `exp(-mr_t * t)` evaluated at the later time.
    """
    # Shape [dim, num_times]
    t = tf.broadcast_to(t, tf.concat([[self._dim], tf.shape(t)], axis=-1))
    knot_index = tf.searchsorted(self._jump_locations, t)
    knots_with_origin = tf.concat(
        [self._zero_padding, self._jump_locations], axis=1)

    # Cumulative y-integral at each volatility knot.
    y_per_interval = self._y_integral(self._padded_knots,
                                      self._jump_locations,
                                      self._jump_values_vol,
                                      self._jump_values_mr)
    y_at_knots = tf.concat(
        [self._zero_padding, utils.cumsum_using_matvec(y_per_interval)],
        axis=1)

    # Cumulative drift integral at each volatility knot. Each interval also
    # receives the y-integral accumulated up to its left knot.
    drift_per_interval = self._ex_integral(self._padded_knots,
                                           self._jump_locations,
                                           self._jump_values_vol,
                                           self._jump_values_mr,
                                           y_at_knots[:, :-1])
    drift_at_knots = tf.concat(
        [self._zero_padding, utils.cumsum_using_matvec(drift_per_interval)],
        axis=1)

    # Integrate from each time's preceding knot up to the time itself, then
    # add the drift already accumulated at that knot.
    preceding_knot = tf.gather(knots_with_origin, knot_index, batch_dims=1)
    y_at_preceding_knot = tf.gather(y_at_knots, knot_index, batch_dims=1)
    accumulated = (
        self._ex_integral(preceding_knot, t, sigma_t, mr_t,
                          y_at_preceding_knot) +
        tf.gather(drift_at_knots, knot_index, batch_dims=1))
    decay = tf.math.exp(
        -tf.broadcast_to(mr_t, tf.shape(t))[:, 1:] * t[:, 1:])
    return (accumulated[:, 1:] - accumulated[:, :-1]) * decay
        def _get_swap_payoff(payoff_time):
            """Computes swap payoffs on the state grid if exercised at `payoff_time`.

            This is a closure: it reads `model`, `state_x`, `dim`,
            `num_grid_points`, `num_maturities`, `unique_maturities`,
            `maturities`, `maturities_shape`, `broadcasted_maturities`,
            `shape_to_broadcast`, `fixed_leg_coupon`,
            `fixed_leg_daycount_fractions`, `is_payer_swaption`,
            `batch_shape` and `meshgrid_shape` from the enclosing scope.

            Args:
              payoff_time: Exercise time, broadcast to
                `shape_to_broadcast[1:]`.

            Returns:
              The swap payoff (float leg minus fixed leg, negated where
              `is_payer_swaption` is False), transposed and reshaped to
              `batch_shape + meshgrid_shape[1:]`.
            """
            broadcasted_exercise_times = tf.broadcast_to(
                payoff_time, shape_to_broadcast[1:])

            # Zero-coupon bond curve
            # Discount bond prices from the exercise time to every unique
            # maturity, evaluated at each grid state; reshaped to
            # [num_grid_points, num_maturities].
            zcb_curve = model.discount_bond_price(
                tf.transpose(
                    tf.reshape(state_x,
                               [dim, num_grid_points * num_maturities])),
                tf.reshape(broadcasted_exercise_times, [-1]),
                tf.reshape(broadcasted_maturities, [-1]))
            zcb_curve = tf.reshape(zcb_curve,
                                   [num_grid_points, num_maturities])

            # Column of the unique-maturity curve for each requested maturity.
            maturities_index = tf.searchsorted(unique_maturities,
                                               tf.reshape(maturities, [-1]))

            zcb_curve = tf.gather(zcb_curve, maturities_index, axis=-1)
            # zcb_curve.shape = [num_grid_points] + [maturities_shape]
            zcb_curve = tf.reshape(
                zcb_curve,
                tf.concat([[num_grid_points], maturities_shape], axis=0))

            # Fixed leg: coupon * daycount fraction * bond price, summed over
            # the last (payment) axis.
            # Shape after reduce_sum =
            # (num_grid_points, batch_shape)
            fixed_leg = tf.math.reduce_sum(
                fixed_leg_coupon * fixed_leg_daycount_fractions * zcb_curve,
                axis=-1)
            # Floating leg taken as 1 minus the bond price to the last
            # maturity (presumably a unit-notional par floater - confirm).
            float_leg = 1.0 - zcb_curve[..., -1]
            payoff_swap = float_leg - fixed_leg
            # Payer: float - fixed; otherwise the sign is flipped.
            payoff_swap = tf.where(is_payer_swaption, payoff_swap,
                                   -payoff_swap)
            return tf.reshape(
                tf.transpose(payoff_swap),
                tf.concat([batch_shape, meshgrid_shape[1:]], axis=0))
Exemple #6
0
 def _get_index(t, tensor_to_search):
     """Returns the index of `t` in `tensor_to_search`, or -1 if absent.

     An element matches when it lies strictly within `_PDE_TIME_GRID_TOL`
     of `t`.

     Args:
       t: Time to look up. A trailing axis is added before searching;
         presumably a scalar, in which case the result is a scalar index.
       tensor_to_search: 1-D `Tensor` of candidate times, sorted as required
         by `tf.searchsorted`. Assumed non-empty.

     Returns:
       The first entry of the matched index tensor, or -1 where no element
       of `tensor_to_search` is within tolerance of `t`.
     """
     t = tf.expand_dims(t, axis=-1)
     # First position whose value is strictly greater than `t - tol`. This
     # is the only element that can lie within tolerance of `t`.
     index = tf.searchsorted(tensor_to_search, t - _PDE_TIME_GRID_TOL,
                             'right')
     # `index` equals the length of `tensor_to_search` when every element is
     # <= `t - tol`. Clamp it before gathering so the lookup stays in bounds
     # (an out-of-range `tf.gather` fails on CPU). That case never matches.
     num_points = tf.size(tensor_to_search, out_type=index.dtype)
     y = tf.gather(tensor_to_search, tf.minimum(index, num_points - 1))
     is_match = tf.math.logical_and(
         index < num_points,
         tf.math.abs(t - y) < _PDE_TIME_GRID_TOL)
     return tf.where(is_match, index, -1)[0]
    def state_y(self, t):
        """Computes the state variable `y(t)` for the Gaussian HJM Model.

        For the Gaussian HJM model, the state parameter y(t) can be computed
        analytically as follows:

        y_ij(t) = exp(-(k_i + k_j) * t) * (
                  int_0^t rho_ij * sigma_i(u) * sigma_j(u)
                  * exp((k_i + k_j) * u) * du)

        where `k` is the mean reversion. The integral is evaluated in closed
        form on each interval between volatility jump locations, treating the
        volatility as constant on that interval.

        Args:
          t: A rank 1 real `Tensor` of shape `[num_times]` specifying the time `t`.

        Returns:
          A real `Tensor` of shape `[dim, dim, num_times]` (`dim` is
          `self._dim`) containing the computed y_ij(t).
        """
        t = tf.convert_to_tensor(t, dtype=self._dtype)
        t_shape = tf.shape(t)
        # Shape [dim, num_times]: one copy of `t` per dimension.
        t = tf.broadcast_to(t, tf.concat([[self._dim], t_shape], axis=0))
        # Index of the volatility interval each time falls into.
        time_index = tf.searchsorted(self._jump_locations, t)
        # create a matrix k2(i,j) = k(i) + k(j)
        mr2 = tf.expand_dims(self._mean_reversion, axis=-1)
        # Add a dimension corresponding to `num_times`
        mr2 = tf.expand_dims(mr2 + tf.transpose(mr2), axis=-1)

        def _integrate_volatility_squared(vol, l_limit, u_limit):
            # Closed form of
            #   int_l^u rho_ij * vol_i * vol_j * exp(k2_ij * s) ds
            # with the volatility held constant on [l_limit, u_limit].
            # create sigma2_ij = sigma_i * sigma_j
            vol = tf.expand_dims(vol, axis=-2)
            vol_squared = tf.expand_dims(
                self._rho, axis=-1) * (vol * tf.transpose(vol, perm=[1, 0, 2]))
            return vol_squared / mr2 * (tf.math.exp(mr2 * u_limit) -
                                        tf.math.exp(mr2 * l_limit))

        # With no volatility jumps there are no knot-to-knot intervals, so the
        # per-interval contributions are an empty [dim, dim, 0] tensor.
        is_constant_vol = tf.math.equal(tf.shape(self._jump_values_vol)[-1], 0)
        v_squared_between_vol_knots = tf.cond(
            is_constant_vol,
            lambda: tf.zeros(shape=(self._dim, self._dim, 0),
                             dtype=self._dtype),
            lambda: _integrate_volatility_squared(  # pylint: disable=g-long-lambda
                self._jump_values_vol, self._padded_knots, self._jump_locations
            ))
        # Integral accumulated up to each volatility knot, with a leading zero.
        v_squared_at_vol_knots = tf.concat([
            tf.zeros((self._dim, self._dim, 1), dtype=self._dtype),
            utils.cumsum_using_matvec(v_squared_between_vol_knots)
        ],
                                           axis=-1)

        vn = tf.concat([self._zero_padding, self._jump_locations], axis=1)

        # Integral from each time's preceding knot up to the time, plus the
        # integral accumulated up to that knot.
        v_squared_t = _integrate_volatility_squared(
            self._volatility(t), tf.gather(vn, time_index, batch_dims=1), t)
        # NOTE(review): `batch_dims=-1` resolves to `rank(time_index) - 1`
        # (= 1 here), and `axis` then defaults to that value. Confirm that the
        # gathered shape broadcasts against `v_squared_t` ([dim, dim, num_times])
        # for multi-dimensional models.
        v_squared_t += tf.gather(v_squared_at_vol_knots,
                                 time_index,
                                 batch_dims=-1)

        return tf.math.exp(-mr2 * t) * v_squared_t
def _bond_option_variance(model, option_expiry, bond_maturity, dim):
    """Computes black equivalent variance for bond options.

  Black equivalent variance is defined as the variance to use in the Black
  formula to obtain the model implied price of European bond options.

  Args:
    model: An instance of `VectorHullWhiteModel`.
    option_expiry: A rank 1 `Tensor` of real dtype specifying the time to
      expiry of each option.
    bond_maturity: A rank 1 `Tensor` of real dtype specifying the time to
      maturity of underlying zero coupon bonds.
    dim: Dimensionality of the Hull-White process.

  Returns:
    A real `Tensor` of the same dtype as the inputs with the computed
    Black-equivalent variance for the underlying options. Both time inputs
    are tiled `dim` times along a new leading axis, so the result presumably
    has shape `[dim, num_options]` (TODO: confirm against callers).

  Raises:
    ValueError: If `model._sample_with_generic` is set, i.e. the
      parameterization does not support the analytic computation.
  """
    # pylint: disable=protected-access
    if model._sample_with_generic:
        raise ValueError('The paramerization of `mean_reversion` and/or '
                         '`volatility` does not support analytic computation '
                         'of bond option variance.')
    mean_reversion = model.mean_reversion(option_expiry)
    volatility = model.volatility(option_expiry)

    # One row per process dimension.
    option_expiry = tf.repeat(tf.expand_dims(option_expiry, axis=0),
                              dim,
                              axis=0)
    bond_maturity = tf.repeat(tf.expand_dims(bond_maturity, axis=0),
                              dim,
                              axis=0)

    # Variance integral accumulated up to each volatility knot.
    var_between_vol_knots = model._variance_int(model._padded_knots,
                                                model._jump_locations,
                                                model._jump_values_vol,
                                                model._jump_values_mr)
    varx_at_vol_knots = tf.concat([
        model._zero_padding,
        vector_hull_white._cumsum_using_matvec(var_between_vol_knots)
    ],
                                  axis=1)

    # Integrate from the knot preceding each expiry up to the expiry itself.
    time_index = tf.searchsorted(model._jump_locations, option_expiry)
    vn = tf.concat([model._zero_padding, model._jump_locations], axis=1)

    var_expiry = model._variance_int(tf.gather(vn, time_index, batch_dims=1),
                                     option_expiry, volatility, mean_reversion)
    var_expiry = var_expiry + tf.gather(
        varx_at_vol_knots, time_index, batch_dims=1)
    # Scale by ((exp(-a * T_expiry) - exp(-a * T_maturity)) / a) ** 2.
    var_expiry = var_expiry * (
        tf.math.exp(-mean_reversion * option_expiry) -
        tf.math.exp(-mean_reversion * bond_maturity))**2 / mean_reversion**2
    # pylint: enable=protected-access
    return var_expiry
Exemple #9
0
def _grid_from_num_times(*, times, time_step, num_time_steps):
    """Creates a time grid for the requested number of time steps.

    Args:
      times: 1-D `Tensor` of times that must appear in the grid;
        `times[-1]` is used as the upper end of the uniform part.
      time_step: Scalar offset of the uniform part from 0 and from
        `times[-1]`.
      num_time_steps: The uniform part has
        `max(0, num_time_steps - len(times))` points spanning
        `[time_step, times[-1] - time_step]`.

    Returns:
      A tuple `(all_times, time_indices)`. `all_times` is the sorted union
      of the uniform points and `times`, prefixed with 0. `time_indices`
      holds the `int32` positions of `times` within `all_times`.
    """
    num_uniform = tf.nn.relu(num_time_steps - tf.shape(times)[0])
    uniform_grid = tf.linspace(time_step, times[-1] - time_step, num_uniform)
    merged = tf.sort(tf.concat([uniform_grid, times], 0))
    all_times = tf.concat([[0], merged], 0)
    return all_times, tf.searchsorted(all_times, times, out_type=tf.int32)
Exemple #10
0
 def _cdf(self, x):
   """Cumulative distribution function evaluated at `x`.

   Each element of `x` is mapped to an outcome index. It is the index of the
   next outcome up when `x` is equal or close to that outcome (per
   `_is_equal_or_close`); otherwise it is the index of the largest outcome
   `<= x`. The CDF of the underlying categorical is then evaluated at that
   index. Points below the smallest outcome map to index -1 and get CDF 0.
   """
   x = tf.convert_to_tensor(x, name='x')
   flat_x = tf.reshape(x, shape=[-1])
   upper_bound = tf.searchsorted(self.outcomes, values=flat_x, side='right')
   values_at_ub = tf.gather(
       self.outcomes,
       indices=tf.minimum(upper_bound,
                          dist_util.prefer_static_shape(self.outcomes)[-1] -
                          1))
   should_use_upper_bound = self._is_equal_or_close(flat_x, values_at_ub)
   indices = tf.where(should_use_upper_bound, upper_bound, upper_bound - 1)
   indices = tf.reshape(indices, shape=dist_util.prefer_static_shape(x))
   # An index of -1 means `x` lies below every outcome. Do not hand a negative
   # index to the categorical: query it at 0, then overwrite those entries
   # with zero probability mass.
   below_support = tf.equal(indices, -1)
   safe_indices = tf.where(below_support, tf.zeros([], indices.dtype), indices)
   cdf = self._categorical.cdf(safe_indices)
   return tf.where(below_support, tf.zeros([], cdf.dtype), cdf)
Exemple #11
0
def prepare_grid(*, times, time_step, dtype):
    """Prepares grid of times for path generation.

  Args:
    times:  Rank 1 `Tensor` of increasing positive real values. The times at
      which the path points are to be evaluated. Consecutive values are
      expected to be further apart than the de-duplication tolerance.
    time_step: Rank 0 real `Tensor`. Maximal distance between points in
      resulting grid.
    dtype: `tf.Dtype` of the input and output `Tensor`s.

  Returns:
    Tuple `(all_times, mask, time_indices)`.
    `all_times` is a 1-D real `Tensor` containing all points from `times` and
    the uniform grid of points between `[0, times[-1]]` with grid size equal to
    `time_step`. The `Tensor` is sorted in ascending order. A point that lies
    within a small tolerance (1e-10 for `tf.float64`, 1e-6 otherwise) of its
    predecessor is dropped.
    `mask` is a boolean 1-D `Tensor` of the same shape as `all_times`, showing
    which elements of `all_times` correspond to the values from `times`.
    `all_times[0]` is 0, and `mask[0]` is False when every element of `times`
    exceeds the tolerance.
    `time_indices` is an `int32` `Tensor` of the same shape as `times`
    indicating the positions of `times` in `all_times`. A requested time that
    was dropped as a near-duplicate points at the retained point just before
    it.
  """
    grid = tf.range(0.0, times[-1], time_step, dtype=dtype)
    all_times = tf.concat([times, grid], axis=0)
    # Sort sequence. Identify the time indices of interest
    # TODO(b/169400743): use tf.sort instead of argsort and casting when XLA
    # float64 support is extended for tf.sort
    args = tf.argsort(tf.cast(all_times, dtype=tf.float32))
    all_times = tf.gather(all_times, args)
    # Remove duplicate points
    duplicate_tol = 1e-10 if dtype == tf.float64 else 1e-6
    dt = all_times[1:] - all_times[:-1]
    dt = tf.concat([[1.0], dt], axis=-1)
    duplicate_mask = tf.math.greater(dt, duplicate_tol)
    all_times = tf.boolean_mask(all_times, duplicate_mask)

    time_indices = tf.searchsorted(all_times, times, out_type=tf.int32)
    # A requested time dropped as a near-duplicate has no exact match.
    # `searchsorted` then returns the next retained point, or
    # `len(all_times)` if there is none (which would be an out-of-range
    # sparse index below). Move such indices one step left, onto the retained
    # point the time was merged into. The gather index is clamped to stay in
    # bounds.
    num_points = tf.size(all_times, out_type=tf.int32)
    found_times = tf.gather(all_times, tf.minimum(time_indices, num_points - 1))
    was_dropped = tf.math.logical_or(
        time_indices >= num_points, found_times - times > duplicate_tol)
    time_indices = tf.where(was_dropped, time_indices - 1, time_indices)
    # Create a boolean mask to identify the iterations that have to be recorded.
    mask_sparse = tf.sparse.SparseTensor(
        indices=tf.expand_dims(tf.cast(time_indices, dtype=tf.int64), axis=1),
        values=tf.fill(tf.shape(times), True),
        dense_shape=tf.shape(all_times, out_type=tf.int64))
    mask = tf.sparse.to_dense(mask_sparse)
    return all_times, mask, time_indices
Exemple #12
0
def _get_indices_and_values(x, index_matrix, jump_locations, values, side,
                            batch_rank):
  """Computes values and jump locations of the piecewise constant function.

  Given `jump_locations` and the `values` on the corresponding segments of the
  piecewise constant function, the function identifies the nearest jump to `x`
  and the corresponding value of the piecewise constant function at `x`. The
  `side` argument decides which jump is used: for `side='left'` it is the
  first jump `>= x` (clamped to the last jump); for `side='right'` it is the
  last jump `<= x` (clamped to the first jump).

  Args:
    x: A real `Tensor` of shape `batch_shape + [num_points]`. Points at which
      the function has to be evaluated.
    index_matrix: An `int32` `Tensor` of shape
      `batch_shape + [num_points] + [len(batch_shape)]` such that if
      `batch_shape = [i1, .., in]`, then for all `j1, ..., jn, l`,
      `index_matrix[j1,..,jn, l] = [j1, ..., jn]`.
    jump_locations: A `Tensor` of the same `dtype` as `x` and shape
      `batch_shape + [num_jump_points]`. The locations where the function
      changes its values. Note that the values are expected to be ordered
      along the last dimension.
    values: A `Tensor` of the same `dtype` as `x` and shape
      `batch_shape + [num_jump_points + 1]`. Defines `values[..., i]` on
      `jump_locations[..., i - 1], jump_locations[..., i]`.
    side: A Python string. Whether the function is left- or right- continuous.
      The corresponding values for side should be `left` and `right`.
    batch_rank: A Python scalar stating the batch rank of `x`.

  Returns:
    A tuple `(value, jump_location, indices_jump_nd)`:
    `value` has the dtype of `values` and shape
    `batch_shape + [num_points] + event_shape`, where `event_shape` is any
    trailing shape of `values` beyond the segment axis (empty for the shape
    documented above).
    `jump_location` has the dtype of `jump_locations` and shape
    `batch_shape + [num_points]`.
    `indices_jump_nd` is an integer `Tensor` of shape
    `batch_shape + [num_points] + [len(batch_shape) + 1]` holding the indices
    used to obtain the jump locations via `tf.gather_nd`.
  """
  # Insertion position of each point of `x` among the jump locations.
  indices = tf.searchsorted(jump_locations, x, side=side)
  # `values` has `num_jump_points + 1` entries along axis `batch_rank`, so this
  # is `num_jump_points - 1`, the last valid index into `jump_locations`.
  num_data_points = tf.shape(values)[batch_rank] - 2
  if side == 'right':
    # Last jump <= x, clamped at the first jump.
    indices_jump = indices - 1
    indices_jump = tf.maximum(indices_jump, 0)
  else:
    # First jump >= x, clamped at the last jump.
    indices_jump = tf.minimum(indices, num_data_points)
  indices_nd = tf.concat(
      [index_matrix, tf.expand_dims(indices, -1)], -1)
  indices_jump_nd = tf.concat(
      [index_matrix, tf.expand_dims(indices_jump, -1)], -1)
  value = tf.gather_nd(values, indices_nd)
  jump_location = tf.gather_nd(jump_locations, indices_jump_nd)
  return value, jump_location, indices_jump_nd
Exemple #13
0
  def _compute_yt(self, t, mr_t, sigma_t):
    """Computes y(t) as described in [1], section 10.1.6.1.

    Args:
      t: Times; copied `self._dim` times along a new leading axis.
      mr_t: Mean-reversion values at `t`.
      sigma_t: Volatility values at `t`.

    Returns:
      `exp(-2 * mr_t * t)` times the y-integral accumulated up to `t`.
    """
    t = tf.repeat(tf.expand_dims(t, axis=0), self._dim, axis=0)
    knot_index = tf.searchsorted(self._jump_locations, t)
    knots_with_origin = tf.concat(
        [self._zero_padding, self._jump_locations], axis=1)

    # y-integral over each interval between volatility knots, accumulated so
    # that entry k is the integral up to knot k.
    y_per_interval = self._y_integral(
        self._padded_knots, self._jump_locations, self._jump_values_vol,
        self._jump_values_mr)
    y_at_knots = tf.concat(
        [self._zero_padding, _cumsum_using_matvec(y_per_interval)], axis=1)

    # Integral from each time's preceding knot up to the time, plus the
    # accumulated integral at that knot.
    preceding_knot = tf.gather(knots_with_origin, knot_index, batch_dims=1)
    y_t = (self._y_integral(preceding_knot, t, sigma_t, mr_t) +
           tf.gather(y_at_knots, knot_index, batch_dims=1))
    return tf.math.exp(-2 * mr_t * t) * y_t
 def _cdf(self, x):
     """Cumulative distribution function evaluated at `x`.

     Each element of `x` is mapped to an outcome index. It is the index of
     the next outcome up when `x` is equal or close to it (per
     `_is_equal_or_close`); otherwise it is the index of the largest outcome
     `<= x`. Elements below every outcome (index -1) get CDF 0.
     """
     x = tf.convert_to_tensor(x, name='x')
     flat_x = tf.reshape(x, shape=[-1])
     num_outcomes = ps.shape(self.outcomes)[-1]
     # Number of outcomes <= x, i.e. the position of the first outcome > x.
     first_greater = tf.searchsorted(self.outcomes,
                                     values=flat_x,
                                     side='right')
     nearest_above = tf.gather(self.outcomes,
                               indices=tf.minimum(first_greater,
                                                  num_outcomes - 1))
     index = tf.where(self._is_equal_or_close(flat_x, nearest_above),
                      first_greater, first_greater - 1)
     index = tf.reshape(index, shape=dist_util.prefer_static_shape(x))
     below_support = tf.equal(index, -1)
     cdf = self._categorical.cdf(
         tf.where(below_support, tf.zeros([], index.dtype), index))
     return tf.where(below_support, tf.zeros([], cdf.dtype), cdf)
  def _compute_yt(self, t, mr_t, sigma_t):
    """Computes y(t) as described in [1], section 10.1.6.1.

    Args:
      t: Times; broadcast to shape `[dim] + t.shape`.
      mr_t: Mean-reversion values at `t`.
      sigma_t: Volatility values at `t`.

    Returns:
      `exp(-2 * mr_t * t)` times the y-integral accumulated up to `t`.
    """
    # Shape [dim, num_times]
    t = tf.broadcast_to(t, tf.concat([[self._dim], tf.shape(t)], axis=-1))
    knot_index = tf.searchsorted(self._jump_locations, t)
    knots_with_origin = tf.concat(
        [self._zero_padding, self._jump_locations], axis=1)

    # y-integral over each interval between volatility knots, accumulated so
    # that entry k is the integral up to knot k.
    y_per_interval = self._y_integral(
        self._padded_knots, self._jump_locations, self._jump_values_vol,
        self._jump_values_mr)
    y_at_knots = tf.concat(
        [self._zero_padding, utils.cumsum_using_matvec(y_per_interval)],
        axis=1)

    # Integral from each time's preceding knot up to the time, plus the
    # accumulated integral at that knot.
    preceding_knot = tf.gather(knots_with_origin, knot_index, batch_dims=1)
    y_t = (self._y_integral(preceding_knot, t, sigma_t, mr_t) +
           tf.gather(y_at_knots, knot_index, batch_dims=1))
    return tf.math.exp(-2 * mr_t * t) * y_t
 def _log_prob(self, x):
     """Log-probability of `x` under the finite discrete distribution.

     Each element of `x` is compared (via `_is_equal_or_close`) with the
     outcomes on either side of it. A matched element takes the categorical
     log-probability of the matching outcome; the left neighbour wins if both
     match. Unmatched elements get `-inf`.
     """
     x = tf.convert_to_tensor(x, name='x')
     max_index = tf.size(self.outcomes) - 1
     flat_positions = tf.searchsorted(self.outcomes,
                                      values=tf.reshape(x, shape=[-1]),
                                      side='right')
     right = tf.minimum(max_index, tf.reshape(flat_positions, ps.shape(x)))
     left = tf.maximum(0, right - 1)
     matches_right = self._is_equal_or_close(
         x, tf.gather(self.outcomes, indices=right))
     matches_left = self._is_equal_or_close(
         x, tf.gather(self.outcomes, indices=left))
     log_probs = self._categorical.log_prob(
         tf.where(matches_left, left, right))
     no_match = tf.logical_not(matches_left | matches_right)
     neg_inf = dtype_util.as_numpy_dtype(log_probs.dtype)(-np.inf)
     return tf.where(no_match, neg_inf, log_probs)
Exemple #17
0
def _grid_from_time_step(*, times, time_step, dtype):
    """Creates a time grid from an input time step.

    Args:
      times: 1-D `Tensor` of times that must appear in the grid;
        `times[-1]` is the end of the uniform grid.
      time_step: Scalar spacing of the uniform grid on `[0, times[-1])`.
      dtype: `tf.DType` of the grid. It also selects the de-duplication
        tolerance: 1e-10 for `tf.float64`, 1e-6 otherwise.

    Returns:
      A tuple `(all_times, time_indices)`. `all_times` is the sorted union of
      `times` and the uniform grid, with every point closer than the tolerance
      to its predecessor removed. `time_indices` gives the `int32` position in
      `all_times` of each element of `times`. For a time removed as a
      near-duplicate, it is the position of the retained point just before it.
    """
    grid = tf.range(0.0, times[-1], time_step, dtype=dtype)
    all_times = tf.concat([times, grid], axis=0)
    all_times = tf.sort(all_times)
    # Remove duplicate points
    duplicate_tol = 1e-10 if dtype == tf.float64 else 1e-6
    dt = all_times[1:] - all_times[:-1]
    dt = tf.concat([[1.0], dt], axis=-1)
    duplicate_mask = tf.math.greater(dt, duplicate_tol)
    all_times = tf.boolean_mask(all_times, duplicate_mask)
    time_indices = tf.searchsorted(all_times, times, out_type=tf.int32)
    # Move `time_indices` to the left if the requested `times` were removed
    # from `all_times` during deduplication. If the removed time was the
    # largest point, `searchsorted` returns `len(all_times)`. Clamp the gather
    # index so the lookup stays in bounds, and treat that case as removed too.
    num_points = tf.size(all_times, out_type=tf.int32)
    found_times = tf.gather(all_times, tf.minimum(time_indices, num_points - 1))
    was_removed = tf.math.logical_or(
        time_indices >= num_points, found_times - times > duplicate_tol)
    time_indices = tf.where(was_removed, time_indices - 1, time_indices)
    return all_times, time_indices
Exemple #18
0
def _prepare_grid(*, times, time_step, dtype):
    """Prepares grid of times for path generation.

  Args:
    times:  Rank 1 `Tensor` of increasing positive real values. The times at
      which the path points are to be evaluated.
    time_step: Rank 0 real `Tensor`. Maximal distance between points in
      resulting grid.
    dtype: `tf.Dtype` of the input and output `Tensor`s.

  Returns:
    Tuple `(all_times, mask, time_indices)`.
    `all_times` is a 1-D real `Tensor` containing all points from `times` and
    the uniform grid of points between `[0, times[-1]]` with grid size equal to
    `time_step`. The `Tensor` is sorted in ascending order. Exactly equal
    points are merged; points that are merely close are all kept.
    `mask` is a boolean 1-D `Tensor` of the same shape as `all_times`, showing
    which elements of `all_times` correspond to the values from `times`.
    For positive `times`, `all_times[0] = 0` and `mask[0] = False`.
    `time_indices` is an `int32` `Tensor` of the same shape as `times`
    indicating the positions of `times` in `all_times`.
  """
    grid = tf.range(0.0, times[-1], time_step, dtype=dtype)
    all_times = tf.concat([times, grid], axis=0)
    # Remove duplicate points
    all_times = tf.unique(all_times).y
    # Sort sequence. Identify the time indices of interest
    all_times = tf.sort(all_times)
    time_indices = tf.searchsorted(all_times, times, out_type=tf.int32)
    # Create a boolean mask to identify the iterations that have to be recorded.
    # Use dynamic shapes: inside a `tf.function` the static length of
    # `all_times` (output of `tf.unique`) is unknown, and a partially known
    # `TensorShape` cannot be converted into `dense_shape` or a fill shape.
    mask_sparse = tf.sparse.SparseTensor(
        indices=tf.expand_dims(tf.cast(time_indices, dtype=tf.int64), axis=1),
        values=tf.fill(tf.shape(times), True),
        dense_shape=tf.shape(all_times, out_type=tf.int64))
    mask = tf.sparse.to_dense(mask_sparse)
    return all_times, mask, time_indices
Exemple #19
0
 def _log_prob(self, x):
     """Log-probability of `x` under the finite discrete distribution.

     Each element of `x` is compared (via `_is_equal_or_close`) with the
     outcomes on either side of it. A matched element takes the categorical
     log-probability of the matching outcome; the left neighbour wins if both
     match. Unmatched elements get `-inf`.
     """
     x = tf.convert_to_tensor(value=x, name='x')
     x_shape = dist_util.prefer_static_shape(x)
     last_index = tf.size(input=self.outcomes) - 1
     positions = tf.searchsorted(self.outcomes,
                                 values=tf.reshape(x, shape=[-1]),
                                 side='right')
     right = tf.minimum(last_index, tf.reshape(positions, x_shape))
     left = tf.maximum(0, right - 1)
     hits_right = self._is_equal_or_close(
         x, tf.gather(self.outcomes, indices=right))
     hits_left = self._is_equal_or_close(
         x, tf.gather(self.outcomes, indices=left))
     log_probs = self._categorical.log_prob(tf1.where(hits_left, left, right))
     # `tf1.where` does not broadcast, so expand the mask to the shape of
     # `log_probs` and materialize a full tensor of -inf.
     should_be_neg_inf = tf.broadcast_to(
         tf.logical_not(hits_left | hits_right),
         shape=dist_util.prefer_static_shape(log_probs))
     neg_inf = tf.fill(dist_util.prefer_static_shape(should_be_neg_inf),
                       dtype_util.as_numpy_dtype(log_probs.dtype)(-np.inf))
     return tf1.where(should_be_neg_inf, neg_inf, log_probs)
Exemple #20
0
def _prepare_grid(*, times, time_step, dtype):
    """Prepares grid of times for path generation.

  Args:
    times:  Rank 1 `Tensor` of increasing positive real values. The times at
      which the path points are to be evaluated.
    time_step: Rank 0 real `Tensor`. Maximal distance between points in
      resulting grid.
    dtype: `tf.Dtype` of the input and output `Tensor`s.

  Returns:
    Tuple `(all_times, mask, time_indices)`.
    `all_times` is a 1-D real `Tensor` containing all points from `times` and
    the uniform grid of points between `[0, times[-1]]` with grid size equal to
    `time_step`. The `Tensor` is sorted in ascending order, with exactly equal
    points merged.
    `mask` is a boolean 1-D `Tensor` of the same shape as `all_times`, showing
    which elements of `all_times` correspond to the values from `times`.
    For positive `times`, `all_times[0] = 0` and `mask[0] = False`.
    `time_indices` is an `int32` `Tensor` of the same shape as `times`
    indicating the positions of `times` in `all_times`.
  """
    grid = tf.range(0.0, times[-1], time_step, dtype=dtype)
    all_times = tf.concat([grid, times], axis=0)
    # Sort, then remove exact duplicates. `tf.unique` keeps first occurrences
    # in order, so the result stays sorted.
    all_times = tf.unique(tf.sort(all_times)).y
    time_indices = tf.searchsorted(all_times, times)
    # Build the mask from `time_indices` instead of permuting a mask of the
    # concatenated inputs. De-duplication shortens `all_times`, so a permuted
    # mask would have the wrong length and would no longer line up with it.
    mask_sparse = tf.sparse.SparseTensor(
        indices=tf.expand_dims(tf.cast(time_indices, dtype=tf.int64), axis=1),
        values=tf.fill(tf.shape(times), True),
        dense_shape=tf.shape(all_times, out_type=tf.int64))
    mask = tf.sparse.to_dense(mask_sparse)
    return all_times, mask, time_indices
Exemple #21
0
def interpolate(x,
                x_data,
                y_data,
                left_slope=None,
                right_slope=None,
                validate_args=False,
                optimize_for_tpu=False,
                dtype=None,
                name=None):
    """Performs linear interpolation for supplied points.

  Given a set of knots whose x- and y- coordinates are in `x_data` and `y_data`,
  this function returns y-values for x-coordinates in `x` via piecewise
  linear interpolation.

  `x_data` must be non decreasing, but `y_data` don't need to be because we do
  not require the function approximated by these knots to be monotonic.

  #### Examples

  ```python
  x = [-10, -1, 1, 3, 6, 7, 8, 15, 18, 25, 30, 35]
  x_data = [-1, 2, 6, 8, 18, 30.0]
  y_data = [10, -1, -5, 7, 9, 20]

  result = linear_interpolation(x, x_data, y_data)
  # [ 10, 10, 2.66666667, -2, -5, 1, 7, 8.4, 9, 15.41666667, 20, 20]
  ```

  Args:
    x: x-coordinates for which we need to get interpolation. A N-D `Tensor` of
      real dtype. First N-1 dimensions represent batching dimensions.
    x_data: x coordinates. A N-D `Tensor` of real dtype. Should be sorted
      in non decreasing order. First N-1 dimensions represent batching
      dimensions.
    y_data: y coordinates. A N-D `Tensor` of real dtype. Should have the
      compatible shape as `x_data`. First N-1 dimensions represent batching
      dimensions.
    left_slope: The slope to use for extrapolation with x-coordinate smaller
      than the min `x_data`. It's a 0-D or N-D `Tensor`.
      Default value: `None`, which maps to `0.0` meaning constant extrapolation,
      i.e. extrapolated value will be the leftmost `y_data`.
    right_slope: The slope to use for extrapolation with x-coordinate greater
      than the max `x_data`. It's a 0-D or N-D `Tensor`.
      Default value: `None` which maps to `0.0` meaning constant extrapolation,
      i.e. extrapolated value will be the rightmost `y_data`.
    validate_args: Python `bool` that indicates whether the function performs
      the check if the shapes of `x_data` and `y_data` are equal and that the
      elements in `x_data` are non decreasing. If this value is set to `False`
      and the elements in `x_data` are not increasing, the result of linear
      interpolation may be wrong.
      Default value: `False`.
    optimize_for_tpu: A Python bool. If `True`, the algorithm uses one-hot
      encoding to lookup indices of `x_values` in `x_data`. This significantly
      improves performance of the algorithm on a TPU device but may slow down
      performance on the CPU.
      Default value: `False`.
    dtype: Optional tf.dtype for `x`, x_data`, `y_data`, `left_slope` and
      `right_slope`.
      Default value: `None` which means that the `dtype` inferred by TensorFlow
      is used.
    name: Python str. The name prefixed to the ops created by this function.
      Default value: `None` which maps to 'linear_interpolation'.

  Returns:
    A N-D `Tensor` of real dtype corresponding to the x-values in `x`.
  """
    name = name or "linear_interpolation"
    with tf.name_scope(name):
        x = tf.convert_to_tensor(x, dtype=dtype, name="x")
        dtype = dtype or x.dtype
        x_data = tf.convert_to_tensor(x_data, dtype=dtype, name="x_data")
        y_data = tf.convert_to_tensor(y_data, dtype=dtype, name="y_data")
        # Try broadcast batch_shapes
        x, x_data = utils.broadcast_common_batch_shape(x, x_data)
        x, y_data = utils.broadcast_common_batch_shape(x, y_data)
        x_data, y_data = utils.broadcast_common_batch_shape(x_data, y_data)

        batch_shape = x.shape.as_list()[:-1]
        if not batch_shape:
            x = tf.expand_dims(x, 0)
            x_data = tf.expand_dims(x_data, 0)
            y_data = tf.expand_dims(y_data, 0)

        if left_slope is None:
            left_slope = tf.constant(0.0, dtype=x.dtype, name="left_slope")
        else:
            left_slope = tf.convert_to_tensor(left_slope,
                                              dtype=dtype,
                                              name="left_slope")
        if right_slope is None:
            right_slope = tf.constant(0.0, dtype=x.dtype, name="right_slope")
        else:
            right_slope = tf.convert_to_tensor(right_slope,
                                               dtype=dtype,
                                               name="right_slope")
        control_deps = []
        if validate_args:
            # Check that `x_data` elements is non-decreasing
            diffs = x_data[..., 1:] - x_data[..., :-1]
            assertion = tf.compat.v1.debugging.assert_greater_equal(
                diffs,
                tf.zeros_like(diffs),
                message="x_data is not sorted in non-decreasing order.")
            control_deps.append(assertion)
            # Check that the shapes of `x_data` and `y_data` are equal
            control_deps.append(
                tf.compat.v1.assert_equal(tf.shape(x_data), tf.shape(y_data)))

        with tf.control_dependencies(control_deps):
            # Get upper bound indices for `x`.
            upper_indices = tf.searchsorted(x_data,
                                            x,
                                            side="left",
                                            out_type=tf.int32)
            x_data_size = x_data.shape.as_list()[-1]
            at_min = tf.equal(upper_indices, 0)
            at_max = tf.equal(upper_indices, x_data_size)
            # Create tensors in order to be used by `tf.where`.
            # `values_min` are extrapolated values for x-coordinates less than or
            # equal to `x_data[..., 0]`.
            # `values_max` are extrapolated values for x-coordinates greater than
            # `x_data[..., -1]`.

            values_min = tf.expand_dims(
                y_data[..., 0], -1) + left_slope * (x - tf.broadcast_to(
                    tf.expand_dims(x_data[..., 0], -1), shape=tf.shape(x)))
            values_max = tf.expand_dims(
                y_data[..., -1], -1) + right_slope * (x - tf.broadcast_to(
                    tf.expand_dims(x_data[..., -1], -1), shape=tf.shape(x)))

            # `tf.where` evaluates all branches, need to cap indices to ensure it
            # won't go out of bounds.
            lower_encoding = tf.math.maximum(upper_indices - 1, 0)
            upper_encoding = tf.math.minimum(upper_indices, x_data_size - 1)
            # Prepare indices for `tf.gather_nd` or `tf.one_hot`
            # TODO(b/156720909): Extract get_slice logic into a common utilities
            # module for cubic and linear interpolation
            if optimize_for_tpu:
                lower_encoding = tf.one_hot(lower_encoding,
                                            x_data_size,
                                            dtype=dtype)
                upper_encoding = tf.one_hot(upper_encoding,
                                            x_data_size,
                                            dtype=dtype)

            def get_slice(x, encoding):
                if optimize_for_tpu:
                    return tf.math.reduce_sum(tf.expand_dims(x, axis=-2) *
                                              encoding,
                                              axis=-1)
                else:
                    return tf.gather(x,
                                     encoding,
                                     axis=-1,
                                     batch_dims=x.shape.rank - 1)

            x_data_lower = get_slice(x_data, lower_encoding)
            x_data_upper = get_slice(x_data, upper_encoding)
            y_data_lower = get_slice(y_data, lower_encoding)
            y_data_upper = get_slice(y_data, upper_encoding)

            # Nan in unselected branches could propagate through gradient calculation,
            # hence we need to clip the values to ensure no nan would occur. In this
            # case we need to ensure there is no division by zero.
            x_data_diff = x_data_upper - x_data_lower
            floor_x_diff = tf.where(at_min | at_max, x_data_diff + 1,
                                    x_data_diff)
            interpolated = y_data_lower + (x - x_data_lower) * (
                y_data_upper - y_data_lower) / floor_x_diff

            interpolated = tf.where(at_min, values_min, interpolated)
            interpolated = tf.where(at_max, values_max, interpolated)
            if batch_shape:
                return interpolated
            else:
                return tf.squeeze(interpolated, 0)
def interpolate(x_values,
                spline_data,
                optimize_for_tpu=False,
                dtype=None,
                name=None):
    """Interpolates spline values for the given `x_values` and the `spline_data`.

    Constant extrapolation is performed for the values outside the domain
    `spline_data.x_data`. This means that for `x > max(spline_data.x_data)`,
    `interpolate(x, spline_data) = spline_data.y_data[-1]`
    and for `x < min(spline_data.x_data)`,
    `interpolate(x, spline_data) = spline_data.y_data[0]`.

    For the interpolation formula refer to p.548 of [1].

    Where the selected interval has zero width (at `x == max(x_data)`, in the
    extrapolation regions, or at duplicated knots) the normalized coordinate
    `t` is computed with a safe denominator before being replaced by zero, so
    the masked-out `0 / 0` branch of `tf.where` cannot inject `NaN` into
    gradients.

    #### References:
    [1]: R. Sedgewick, Algorithms in C, 1990, p. 545-550.
      Link: http://index-of.co.uk/Algorithms/Algorithms%20in%20C.pdf

    Args:
      x_values: A real `Tensor` of shape `batch_shape + [num_points]`.
      spline_data: An instance of `SplineParameters`. `spline_data.x_data` should
        have the same batch shape as `x_values`.
      optimize_for_tpu: A Python bool. If `True`, the algorithm uses one-hot
        encoding to lookup indices of `x_values` in `spline_data.x_data`. This
        significantly improves performance of the algorithm on a TPU device but
        may slow down performance on the CPU.
        Default value: `False`.
      dtype: Optional dtype for `x_values`.
        Default value: `None` which maps to the default dtype inferred by
        TensorFlow.
      name: Python `str` name prefixed to ops created by this function.
        Default value: `None` which is mapped to the default name
        `cubic_spline_interpolate`.

    Returns:
      A `Tensor` of the same shape and `dtype` as `x_values`. Represents
      the interpolated values.

    Raises:
      ValueError:
        If `x_values` batch shape is different from `spline_data.x_data` batch
        shape.
    """
    name = name or "cubic_spline_interpolate"
    with tf.name_scope(name):
        x_values = tf.convert_to_tensor(x_values, dtype=dtype, name="x_values")
        dtype = x_values.dtype
        # Unpack the spline data
        x_data = spline_data.x_data
        y_data = spline_data.y_data
        spline_coeffs = spline_data.spline_coeffs
        # Try broadcast batch_shapes
        x_values, x_data = utils.broadcast_common_batch_shape(x_values, x_data)
        x_values, y_data = utils.broadcast_common_batch_shape(x_values, y_data)
        x_values, spline_coeffs = utils.broadcast_common_batch_shape(
            x_values, spline_coeffs)
        x_data_size = x_data.shape.as_list()[-1]
        # Determine the splines to use: `indices` is the position of the last
        # knot `<= x_values` (-1 when `x_values` lies below every knot).
        indices = tf.searchsorted(x_data, x_values, side="right") - 1
        # Start of the spline interval, clamped into the permissible range.
        lower_encoding = tf.maximum(indices, 0)
        # End of the spline interval, clamped into the permissible range.
        upper_encoding = tf.minimum(indices + 1, x_data_size - 1)
        # TODO(b/156720909): Extract get_slice logic into a common utilities module
        # for cubic and linear interpolation
        if optimize_for_tpu:
            lower_encoding = tf.one_hot(lower_encoding,
                                        x_data_size,
                                        dtype=dtype)
            upper_encoding = tf.one_hot(upper_encoding,
                                        x_data_size,
                                        dtype=dtype)

        # Calculate dx and dy. Simplified logic:
        #   dx = x_data[indices + 1] - x_data[indices]
        #   dy = y_data[indices + 1] - y_data[indices]
        # `indices` differs per row/spline, so the lookup is done with
        # `tf.gather(..., batch_dims=...)`, or with a one-hot contraction when
        # optimizing for TPU.
        def get_slice(x, encoding):
            if optimize_for_tpu:
                return tf.math.reduce_sum(tf.expand_dims(x, axis=-2) *
                                          encoding,
                                          axis=-1)
            else:
                return tf.gather(x,
                                 encoding,
                                 axis=-1,
                                 batch_dims=x.shape.rank - 1)

        x0 = get_slice(x_data, lower_encoding)
        x1 = get_slice(x_data, upper_encoding)
        dx = x1 - x0

        y0 = get_slice(y_data, lower_encoding)
        y1 = get_slice(y_data, upper_encoding)
        dy = y1 - y0

        spline_coeffs0 = get_slice(spline_coeffs, lower_encoding)
        spline_coeffs1 = get_slice(spline_coeffs, upper_encoding)

        # `tf.where` differentiates both branches, so dividing by a zero `dx`
        # would yield NaN gradients even though the forward value is masked.
        # Divide by 1 in those cells instead; the value is then replaced by 0.
        has_width = dx > 0
        safe_dx = tf.where(has_width, dx, tf.ones_like(dx))
        t = (x_values - x0) / safe_dx
        t = tf.where(has_width, t, tf.zeros_like(t))
        df = ((t + 1.0) * spline_coeffs1 * 2.0) - (
            (t - 2.0) * spline_coeffs0 * 2.0)
        df1 = df * t * (t - 1) / 6.0
        result = y0 + (t * dy) + (dx * dx * df1)
        # Use constant extrapolation outside the domain
        upper_bound = tf.expand_dims(tf.reduce_max(x_data, -1),
                                     -1) + tf.zeros_like(result)
        lower_bound = tf.expand_dims(tf.reduce_min(x_data, -1),
                                     -1) + tf.zeros_like(result)
        result = tf.where(
            tf.logical_and(x_values <= upper_bound, x_values >= lower_bound),
            result, tf.where(x_values > upper_bound, y0, y1))
        return result
def find_interval_index(query_xs,
                        interval_lower_xs,
                        last_interval_is_closed=False,
                        dtype=None,
                        name=None):
    """Function to find the index of the interval where query points lie.

    Given a list of adjacent half-open intervals [x_0, x_1), [x_1, x_2), ...,
    [x_{n-1}, x_n), [x_n, inf), described by a list [x_0, x_1, ..., x_{n-1}, x_n],
    return the index of the interval containing each query point. If x >= x_n,
    n is returned, and if x < x_0, -1 is returned. If `last_interval_is_closed`
    is set to `True`, the interval [x_{n-1}, x_n] is interpreted as closed
    (including x_n): a query equal to x_n then maps to n - 1, and only queries
    strictly greater than x_n map to n.

    #### Example

    ```python
    interval_lower_xs = [0.25, 0.5, 1.0, 2.0, 3.0]
    query_xs = [0.25, 3.0, 5.0, 0.0, 0.5, 0.8]
    result = find_interval_index(query_xs, interval_lower_xs)
    # result == [0, 4, 4, -1, 1, 1]
    result = find_interval_index(query_xs, interval_lower_xs,
                                 last_interval_is_closed=True)
    # result == [0, 3, 4, -1, 1, 1]
    ```

    Args:
      query_xs: Rank 1 real `Tensor` of any size, the list of x coordinates for
        which the interval index is to be found. The values do not need to be
        sorted.
      interval_lower_xs: Rank 1 real `Tensor` of the same dtype as `query_xs`
        (its size may differ from that of `query_xs`). The values x_0, ..., x_n
        that define the interval starts. They must be sorted in non-decreasing
        order (strictly increasing for every interval to be non-empty); this is
        not validated.
      last_interval_is_closed: If set to `True`, the last interval is interpreted
        as closed.
      dtype: Optional `tf.Dtype`. If supplied, the dtype for `query_xs` and
        `interval_lower_xs`.
        Default value: None which maps to the default dtype inferred by TensorFlow
          (float32).
      name: Optional name of the operation.

    Returns:
      A tensor that matches the shape of `query_xs` with dtype=int32 containing
      the indices of the intervals containing query points, using the notation
      above. `-1` means the query point lies before all intervals (x < x_0) and
      `n` means that the point lies in the last half-open interval [x_n, inf)
      (if `last_interval_is_closed` is `False`) or that the point lies to the
      right of all intervals, x > x_n (if `last_interval_is_closed` is `True`).
    """
    with tf.compat.v1.name_scope(
            name,
            default_name='find_interval_index',
            values=[query_xs, interval_lower_xs, last_interval_is_closed]):
        # TODO(b/138988951): add ability to validate that intervals are increasing.
        # TODO(b/138988951): validate that if last_interval_is_closed, input size
        # must be > 1.
        query_xs = tf.convert_to_tensor(query_xs, dtype=dtype)
        interval_lower_xs = tf.convert_to_tensor(interval_lower_xs,
                                                 dtype=dtype)

        # Result assuming that last interval is half-open: the number of
        # interval starts `<= x`, minus one.
        indices = tf.searchsorted(interval_lower_xs, query_xs,
                                  side='right') - 1

        # Handling the branch if the last interval is closed.
        last_index = tf.shape(interval_lower_xs)[-1] - 1
        last_x = tf.gather(interval_lower_xs, [last_index], axis=-1)
        # should_cap is a tensor true where a cell is true iff indices is the last
        # index at that cell and the query x <= the right boundary of the last
        # interval (i.e. the query equals x_n).
        should_cap = tf.logical_and(tf.equal(indices, last_index),
                                    tf.less_equal(query_xs, last_x))

        # cap to last_index if the query x is not in the last interval, otherwise,
        # cap to last_index - 1.
        caps = last_index - tf.cast(should_cap, dtype=tf.dtypes.int32)

        # `tf.where` (rather than a Python `if`) keeps this working when
        # `last_interval_is_closed` is passed as a boolean tensor.
        return tf.compat.v1.where(last_interval_is_closed,
                                  tf.minimum(indices, caps), indices)
 def bizday_back(x):
     """Returns `x + count - 1`, where `count` is the searchsorted position.

     `count` is `tf.searchsorted(bizday_at_holidays, x, side='left')`, i.e. for
     a sorted `bizday_at_holidays` the number of its entries strictly below
     each element of `x`.

     NOTE(review): `bizday_at_holidays` comes from the enclosing scope and is
     presumably sorted; `x` is assumed to be an integer tensor whose dtype
     matches the searchsorted output (int32 by default) — confirm at the
     caller.
     """
     num_strictly_below = tf.searchsorted(bizday_at_holidays, x, side='left')
     return x + num_strictly_below - 1
Exemple #25
0
def find_bins(x,
              edges,
              extend_lower_interval=False,
              extend_upper_interval=False,
              dtype=None,
              name=None):
    """Bin values into discrete intervals.

    Given `edges = [c0, ..., cK]`, defining intervals
    `I0 = [c0, c1)`, `I1 = [c1, c2)`, ..., `I_{K-1} = [c_{K-1}, cK]`,
    this function returns `bins`, such that:
    `edges[bins[i]] <= x[i] < edges[bins[i] + 1]`
    (with `<=` on the right for the last interval, which is closed).

    Args:
      x:  Numeric `N-D` `Tensor` with `N > 0`.
      edges:  `Tensor` of same `dtype` as `x`.  The first dimension indexes edges
        of intervals.  Must either be `1-D` or have
        `x.shape[1:] == edges.shape[1:]`.  If `rank(edges) > 1`, `edges[k]`
        designates a shape `edges.shape[1:]` `Tensor` of bin edges for the
        corresponding dimensions of `x`.
      extend_lower_interval:  Python `bool`.  If `True`, extend the lowest
        interval `I0` to `(-inf, c1)`.
      extend_upper_interval:  Python `bool`.  If `True`, extend the upper
        interval `I_{K-1}` to `[c_{K-1}, +inf)`.
      dtype: The output type (`int32` or `int64`). `Default value:` the dtype
        inferred for `x` and `edges` (with `tf.float32` as the hint).
        This affects the output values when `x` is below/above the intervals,
        which will be `-1`/`K` for `int` types and `NaN` for `float`s.
        At indices where `x` is `NaN`, neither fill value is applied (every
        comparison with `NaN` is `False`); the output there is the clipped
        index returned by `tf.searchsorted`.
      name:  A Python string name to prepend to created ops. Default: 'find_bins'

    Returns:
      bins: `Tensor` with same `shape` as `x` and `dtype`.
        Has whole number values.  `bins[i] = k` means the `x[i]` falls into the
        `kth` bin, ie, `edges[bins[i]] <= x[i] < edges[bins[i] + 1]`.

    Raises:
      ValueError:  If `edges.shape[0]` is determined to be less than 2.

    #### Examples

    Cut a `1-D` array

    ```python
    x = [0., 5., 6., 10., 20.]
    edges = [0., 5., 10.]
    tfp.stats.find_bins(x, edges)
    ==> [0., 1., 1., 1., np.nan]
    ```

    Cut `x` into its deciles

    ```python
    x = tf.random.uniform(shape=(100, 200))
    decile_edges = tfp.stats.quantiles(x, num_quantiles=10)
    bins = tfp.stats.find_bins(x, edges=decile_edges)
    bins.shape
    ==> (100, 200)
    tf.reduce_mean(tf.cast(tf.equal(bins, 0.), tf.float32))
    ==> approximately 0.1
    tf.reduce_mean(tf.cast(tf.equal(bins, 1.), tf.float32))
    ==> approximately 0.1
    ```

    """
    # TFP users may be surprised to see the "action" in the leftmost dim of
    # edges, rather than the rightmost (event) dim.  Why?
    # 1. Most likely you created edges by getting quantiles over samples, and
    #    quantile/percentile return these edges in the leftmost (sample) dim.
    # 2. Say you have event_shape = [5], then we expect the bin will be different
    #    for all 5 events, so the index of the bin should not be in the event dim.
    with tf1.name_scope(name, default_name='find_bins', values=[x, edges]):
        in_type = dtype_util.common_dtype([x, edges], dtype_hint=tf.float32)
        edges = tf.convert_to_tensor(value=edges, name='edges', dtype=in_type)
        x = tf.convert_to_tensor(value=x, name='x', dtype=in_type)

        num_edges = tf.compat.dimension_value(edges.shape[0])
        if num_edges is not None and num_edges < 2:
            raise ValueError(
                'First dimension of `edges` must have length > 1 to index 1 or '
                'more bin. Found: {}'.format(edges.shape))

        # With 1-D `edges` every element of `x` shares the same edges, so `x`
        # is flattened to 1-D and restored at the end. An `x` of unknown static
        # rank is flattened as well: reshaping to `[-1]` and back is harmless
        # for any rank, whereas `None > 1` would raise a `TypeError`.
        x_rank = x.shape.ndims
        flattening_x = edges.shape.ndims == 1 and (x_rank is None or x_rank > 1)

        if flattening_x:
            x_orig_shape = tf.shape(input=x)
            x = tf.reshape(x, [-1])

        if dtype is None:
            dtype = in_type
        dtype = tf.as_dtype(dtype)

        # Move first dims into the rightmost.
        x_permed = distribution_util.rotate_transpose(x, shift=-1)
        edges_permed = distribution_util.rotate_transpose(edges, shift=-1)

        # If...
        #   x_permed = [0, 1, 6., 10]
        #   edges = [0, 5, 10.]
        #   ==> almost_output = [1, 1, 2, 3]
        # i.e. for each value, the number of edges that are `<=` it.
        searchsorted_type = dtype if dtype in [tf.int32, tf.int64] else None
        almost_output_permed = tf.searchsorted(sorted_sequence=edges_permed,
                                               values=x_permed,
                                               side='right',
                                               out_type=searchsorted_type)
        # Move the rightmost dims back to the leftmost.
        almost_output = tf.cast(
            distribution_util.rotate_transpose(almost_output_permed, shift=1),
            dtype)

        # In above example, we want [0, 0, 1, 1]: subtracting 1 gives
        # [0, 0, 1, 2], and clipping to [0, K - 1] puts values equal to the last
        # edge into the last (closed) interval.
        bins = tf.clip_by_value(almost_output - 1, tf.cast(0, dtype),
                                tf.cast(tf.shape(input=edges)[0] - 2, dtype))

        if not extend_lower_interval:
            low_fill = np.nan if dtype.is_floating else -1
            bins = tf.where(x < tf.expand_dims(edges[0], 0),
                            tf.cast(low_fill, dtype), bins)

        if not extend_upper_interval:
            # For integer outputs the fill is K == len(edges) - 1.
            up_fill = np.nan if dtype.is_floating else tf.shape(
                input=edges)[0] - 1
            bins = tf.where(x > tf.expand_dims(edges[-1], 0),
                            tf.cast(up_fill, dtype), bins)

        if flattening_x:
            bins = tf.reshape(bins, x_orig_shape)

        return bins
def swaption_price(*,
                   expiries,
                   floating_leg_start_times,
                   floating_leg_end_times,
                   fixed_leg_payment_times,
                   floating_leg_daycount_fractions,
                   fixed_leg_daycount_fractions,
                   fixed_leg_coupon,
                   reference_rate_fn,
                   dim,
                   mean_reversion,
                   volatility,
                   notional=None,
                   is_payer_swaption=None,
                   use_analytic_pricing=True,
                   num_samples=1,
                   random_type=None,
                   seed=None,
                   skip=0,
                   time_step=None,
                   dtype=None,
                   name=None):
  """Calculates the price of European Swaptions using the Hull-White model.

  A European Swaption is a contract that gives the holder an option to enter a
  swap contract at a future date at a prespecified fixed rate. A swaption that
  grants the holder to pay fixed rate and receive floating rate is called a
  payer swaption while the swaption that grants the holder to receive fixed and
  pay floating payments is called the receiver swaption. Typically the start
  date (or the inception date) of the swap concides with the expiry of the
  swaption. Mid-curve swaptions are currently not supported (b/160061740).

  Analytic pricing of swaptions is performed using the Jamshidian decomposition
  [1].

  #### References:
    [1]: D. Brigo, F. Mercurio. Interest Rate Models-Theory and Practice.
    Second Edition. 2007.

  #### Example
  The example shows how value a batch of 1y x 1y and 1y x 2y swaptions using the
  Hull-White model.

  ````python
  import numpy as np
  import tensorflow.compat.v2 as tf
  import tf_quant_finance as tff

  dtype = tf.float64

  expiries = [1.0, 1.0]
  float_leg_start_times = [[1.0, 1.25, 1.5, 1.75, 2.0, 2.0, 2.0, 2.0],
                            [1.0, 1.25, 1.5, 1.75, 2.0, 2.25, 2.5, 2.75]]
  float_leg_end_times = [[1.25, 1.5, 1.75, 2.0, 2.0, 2.0, 2.0, 2.0],
                          [1.25, 1.5, 1.75, 2.0, 2.25, 2.5, 2.75, 3.0]]
  fixed_leg_payment_times = [[1.25, 1.5, 1.75, 2.0, 2.0, 2.0, 2.0, 2.0],
                          [1.25, 1.5, 1.75, 2.0, 2.25, 2.5, 2.75, 3.0]]
  float_leg_daycount_fractions = [[0.25, 0.25, 0.25, 0.25, 0.0, 0.0, 0.0, 0.0],
                              [0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25]]
  fixed_leg_daycount_fractions = [[0.25, 0.25, 0.25, 0.25, 0.0, 0.0, 0.0, 0.0],
                              [0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25]]
  fixed_leg_coupon = [[0.011, 0.011, 0.011, 0.011, 0.0, 0.0, 0.0, 0.0],
                      [0.011, 0.011, 0.011, 0.011, 0.011, 0.011, 0.011, 0.011]]
  zero_rate_fn = lambda x: 0.01 * tf.ones_like(x, dtype=dtype)
  price = tff.models.hull_white.swaption_price(
      expiries=expiries,
      floating_leg_start_times=float_leg_start_times,
      floating_leg_end_times=float_leg_end_times,
      fixed_leg_payment_times=fixed_leg_payment_times,
      floating_leg_daycount_fractions=float_leg_daycount_fractions,
      fixed_leg_daycount_fractions=fixed_leg_daycount_fractions,
      fixed_leg_coupon=fixed_leg_coupon,
      reference_rate_fn=zero_rate_fn,
      notional=100.,
      dim=1,
      mean_reversion=[0.03],
      volatility=[0.02],
      dtype=dtype)
  # Expected value: [[0.7163243383624043], [1.4031415262337608]] # shape = (2,1)
  ````

  Args:
    expiries: A real `Tensor` of any shape and dtype. The time to
      expiration of the swaptions. The shape of this input determines the number
      (and shape) of swaptions to be priced and the shape of the output.
    floating_leg_start_times: A real `Tensor` of the same dtype as `expiries`.
      The times when accrual begins for each payment in the floating leg. The
      shape of this input should be `expiries.shape + [m]` where `m` denotes
      the number of floating payments in each leg.
    floating_leg_end_times: A real `Tensor` of the same dtype as `expiries`.
      The times when accrual ends for each payment in the floating leg. The
      shape of this input should be `expiries.shape + [m]` where `m` denotes
      the number of floating payments in each leg.
    fixed_leg_payment_times: A real `Tensor` of the same dtype as `expiries`.
      The payment times for each payment in the fixed leg. The shape of this
      input should be `expiries.shape + [n]` where `n` denotes the number of
      fixed payments in each leg.
    floating_leg_daycount_fractions: A real `Tensor` of the same dtype and
      compatible shape as `floating_leg_start_times`. The daycount fractions
      for each payment in the floating leg.
    fixed_leg_daycount_fractions: A real `Tensor` of the same dtype and
      compatible shape as `fixed_leg_payment_times`. The daycount fractions
      for each payment in the fixed leg.
    fixed_leg_coupon: A real `Tensor` of the same dtype and compatible shape
      as `fixed_leg_payment_times`. The fixed rate for each payment in the
      fixed leg.
    reference_rate_fn: A Python callable that accepts expiry time as a real
      `Tensor` and returns a `Tensor` of shape `input_shape + [dim]`. Returns
      the continuously compounded zero rate at the present time for the input
      expiry time.
    dim: A Python scalar which corresponds to the number of Hull-White Models
      to be used for pricing.
    mean_reversion: A real positive `Tensor` of shape `[dim]` or a Python
      callable. The callable can be one of the following:
      (a) A left-continuous piecewise constant object (e.g.,
      `tff.math.piecewise.PiecewiseConstantFunc`) that has a property
      `is_piecewise_constant` set to `True`. In this case the object should
      have a method `jump_locations(self)` that returns a `Tensor` of shape
      `[dim, num_jumps]` or `[num_jumps]`. In the first case,
      `mean_reversion(t)` should return a `Tensor` of shape `[dim] + t.shape`,
      and in the second, `t.shape + [dim]`, where `t` is a rank 1 `Tensor` of
      the same `dtype` as the output. See example in the class docstring.
      (b) A callable that accepts scalars (stands for time `t`) and returns a
      `Tensor` of shape `[dim]`.
      Corresponds to the mean reversion rate.
    volatility: A real positive `Tensor` of the same `dtype` as
      `mean_reversion` or a callable with the same specs as above.
      Corresponds to the lond run price variance.
    notional: An optional `Tensor` of same dtype and compatible shape as
      `strikes`specifying the notional amount for the underlying swap.
       Default value: None in which case the notional is set to 1.
    is_payer_swaption: A boolean `Tensor` of a shape compatible with `expiries`.
      Indicates whether the swaption is a payer (if True) or a receiver
      (if False) swaption. If not supplied, payer swaptions are assumed.
    use_analytic_pricing: A Python boolean specifying if analytic valuation
      should be performed. Analytic valuation is only supported for constant
      `mean_reversion` and piecewise constant `volatility`. If the input is
      `False`, then valuation using Monte-Carlo simulations is performed.
      Default value: The default value is `True`.
    num_samples: Positive scalar `int32` `Tensor`. The number of simulation
      paths during Monte-Carlo valuation. This input is ignored during analytic
      valuation.
      Default value: The default value is 1.
    random_type: Enum value of `RandomType`. The type of (quasi)-random
      number generator to use to generate the simulation paths. This input is
      relevant only for Monte-Carlo valuation and ignored during analytic
      valuation.
      Default value: `None` which maps to the standard pseudo-random numbers.
    seed: Seed for the random number generator. The seed is only relevant if
      `random_type` is one of
      `[STATELESS, PSEUDO, HALTON_RANDOMIZED, PSEUDO_ANTITHETIC,
        STATELESS_ANTITHETIC]`. For `PSEUDO`, `PSEUDO_ANTITHETIC` and
      `HALTON_RANDOMIZED` the seed should be an Python integer. For
      `STATELESS` and  `STATELESS_ANTITHETIC `must be supplied as an integer
      `Tensor` of shape `[2]`. This input is relevant only for Monte-Carlo
      valuation and ignored during analytic valuation.
      Default value: `None` which means no seed is set.
    skip: `int32` 0-d `Tensor`. The number of initial points of the Sobol or
      Halton sequence to skip. Used only when `random_type` is 'SOBOL',
      'HALTON', or 'HALTON_RANDOMIZED', otherwise ignored.
      Default value: `0`.
    time_step: Scalar real `Tensor`. Maximal distance between time grid points
      in Euler scheme. Relevant when Euler scheme is used for simulation. This
      input is ignored during analytic valuation.
      Default value: `None`.
    dtype: The default dtype to use when converting values to `Tensor`s.
      Default value: `None` which means that default dtypes inferred by
      TensorFlow are used.
    name: Python string. The name to give to the ops created by this function.
      Default value: `None` which maps to the default name
      `hw_swaption_price`.

  Returns:
    A `Tensor` of real dtype and shape  expiries.shape + [dim] containing the
    computed swaption prices. For swaptions that have. reset in the past
    (expiries<0), the function sets the corresponding option prices to 0.0.
  """
  # TODO(b/160061740): Extend the functionality to support mid-curve swaptions.
  name = name or 'hw_swaption_price'
  del floating_leg_daycount_fractions
  with tf.name_scope(name):
    expiries = tf.convert_to_tensor(expiries, dtype=dtype, name='expiries')
    dtype = dtype or expiries.dtype
    float_leg_start_times = tf.convert_to_tensor(
        floating_leg_start_times, dtype=dtype, name='float_leg_start_times')
    float_leg_end_times = tf.convert_to_tensor(
        floating_leg_end_times, dtype=dtype, name='float_leg_end_times')
    fixed_leg_payment_times = tf.convert_to_tensor(
        fixed_leg_payment_times, dtype=dtype, name='fixed_leg_payment_times')
    fixed_leg_daycount_fractions = tf.convert_to_tensor(
        fixed_leg_daycount_fractions, dtype=dtype,
        name='fixed_leg_daycount_fractions')
    fixed_leg_coupon = tf.convert_to_tensor(
        fixed_leg_coupon, dtype=dtype, name='fixed_leg_coupon')
    notional = tf.convert_to_tensor(notional, dtype=dtype, name='notional')
    notional = tf.expand_dims(
        tf.broadcast_to(notional, expiries.shape), axis=-1)
    if is_payer_swaption is None:
      is_payer_swaption = True
    is_payer_swaption = tf.convert_to_tensor(
        is_payer_swaption, dtype=tf.bool, name='is_payer_swaption')

    output_shape = expiries.shape.as_list() + [dim]
    # Add a dimension corresponding to multiple cashflows in a swap
    if expiries.shape.rank == fixed_leg_payment_times.shape.rank - 1:
      expiries = tf.expand_dims(expiries, axis=-1)
    elif expiries.shape.rank < fixed_leg_payment_times.shape.rank - 1:
      raise ValueError('Swaption expiries not specified for all swaptions '
                       'in the batch. Expected rank {} but received {}.'.format(
                           fixed_leg_payment_times.shape.rank - 1,
                           expiries.shape.rank))

    # Expected shape: batch_shape + [m], same as fixed_leg_payment_times.shape
    # We need to explicitly use tf.repeat because we need to price
    # batch_shape + [m] bond options with different strikes along the last
    # dimension.
    expiries = tf.repeat(
        expiries, fixed_leg_payment_times.shape.as_list()[-1], axis=-1)

    if use_analytic_pricing:
      return _analytic_valuation(expiries, float_leg_start_times,
                                 float_leg_end_times, fixed_leg_payment_times,
                                 fixed_leg_daycount_fractions,
                                 fixed_leg_coupon, reference_rate_fn,
                                 dim, mean_reversion, volatility, notional,
                                 is_payer_swaption, output_shape, dtype,
                                 name + '_analytic_valyation')

    # Monte-Carlo pricing
    model = vector_hull_white.VectorHullWhiteModel(
        dim,
        mean_reversion,
        volatility,
        initial_discount_rate_fn=reference_rate_fn,
        dtype=dtype)

    if time_step is None:
      raise ValueError('`time_step` must be provided for simulation '
                       'based bond option valuation.')

    sim_times, _ = tf.unique(tf.reshape(expiries, shape=[-1]))
    longest_expiry = tf.reduce_max(sim_times)
    sim_times, _ = tf.unique(tf.concat([sim_times, tf.range(
        time_step, longest_expiry, time_step)], axis=0))
    sim_times = tf.sort(sim_times, name='sort_sim_times')

    maturities = fixed_leg_payment_times
    swaptionlet_shape = maturities.shape
    tau = maturities - expiries

    curve_times_builder, _ = tf.unique(tf.reshape(tau, shape=[-1]))
    curve_times = tf.sort(curve_times_builder, name='sort_curve_times')

    p_t_tau, r_t = model.sample_discount_curve_paths(
        times=sim_times,
        curve_times=curve_times,
        num_samples=num_samples,
        random_type=random_type,
        seed=seed,
        skip=skip)

    dt = tf.concat(
        [tf.convert_to_tensor([0.0], dtype=dtype),
         sim_times[1:] - sim_times[:-1]], axis=0)
    dt = tf.expand_dims(tf.expand_dims(dt, axis=-1), axis=0)
    discount_factors_builder = tf.math.exp(-r_t * dt)
    # Transpose before (and after) because we want the cumprod along axis=1
    # and `matvec` operates on the last axis.
    discount_factors_builder = tf.transpose(
        utils.cumprod_using_matvec(
            tf.transpose(discount_factors_builder, [0, 2, 1])), [0, 2, 1])

    # make discount factors the same shape as `p_t_tau`. This involves adding
    # an extra dimenstion (corresponding to `curve_times`).
    discount_factors_builder = tf.expand_dims(
        discount_factors_builder,
        axis=1)
    # tf.repeat is needed because we will use gather_nd later on this tensor.
    discount_factors_simulated = tf.repeat(
        discount_factors_builder, p_t_tau.shape.as_list()[1], axis=1)

    # `sim_times` and `curve_times` are sorted for simulation. We need to
    # select the indices corresponding to our input.
    sim_time_index = tf.searchsorted(sim_times, tf.reshape(expiries, [-1]))
    curve_time_index = tf.searchsorted(curve_times, tf.reshape(tau, [-1]))

    gather_index = _prepare_indices(
        tf.range(0, num_samples), curve_time_index, sim_time_index,
        tf.range(0, dim))

    # The shape after `gather_nd` will be `(num_samples*num_swaptionlets*dim,)`
    payoff_discount_factors_builder = tf.gather_nd(
        discount_factors_simulated, gather_index)
    # Reshape to `[num_samples] + swaptionlet.shape + [dim]`
    payoff_discount_factors = tf.reshape(
        payoff_discount_factors_builder,
        [num_samples] + swaptionlet_shape + [dim])
    payoff_bond_price_builder = tf.gather_nd(p_t_tau, gather_index)
    payoff_bond_price = tf.reshape(
        payoff_bond_price_builder, [num_samples] + swaptionlet_shape + [dim])

    # Add an axis corresponding to `dim`
    fixed_leg_pv = tf.expand_dims(
        fixed_leg_coupon * fixed_leg_daycount_fractions,
        axis=-1) * payoff_bond_price
    # Sum fixed coupon payments within each swap
    fixed_leg_pv = tf.math.reduce_sum(fixed_leg_pv, axis=-2)
    float_leg_pv = 1.0 - payoff_bond_price[..., -1, :]
    payoff_swap = payoff_discount_factors[..., -1, :] * (
        float_leg_pv - fixed_leg_pv)
    payoff_swap = tf.where(is_payer_swaption, payoff_swap, -1.0 * payoff_swap)
    payoff_swaption = tf.math.maximum(payoff_swap, 0.0)
    option_value = tf.reshape(
        tf.math.reduce_mean(payoff_swaption, axis=0), output_shape)

    return notional * option_value
    def forward_rates(self,
                      market: pmd.ProcessedMarketData,
                      name: Optional[str] = None
                      ) -> Tuple[types.DateTensor, types.FloatTensor]:
        """Returns forward rates for the floating leg.

        Each coupon period is classified relative to the valuation date:
        periods whose end date precedes the valuation date get a zero rate,
        periods that started before the valuation date (but have not ended)
        use the past fixing, and periods starting on or after the valuation
        date use the forward rate implied by the reference curve.

        Args:
          market: An instance of `ProcessedMarketData`.
          name: Python str. The name to give to the ops created by this
            function.
            Default value: `None` which maps to
            `self._name + '_forward_rates'`.

        Returns:
          A tuple of two `Tensor`s of shape `batch_shape + [num_cashflows]`
          containing the dates and the corresponding forward rates for each
          stream based on the input market data.
        """
        scope_name = name or (self._name + "_forward_rates")
        with tf.name_scope(scope_name):
            curve = get_discount_curve(self._reference_curve_type,
                                       market, self._reference_mask)
            val_date = dateslib.convert_to_date_tensor(market.date)

            start_ords = self._coupon_start_dates.ordinal()
            end_ords = self._coupon_end_dates.ordinal()
            val_ords = val_date.ordinal()
            # Give the valuation-date ordinal the coupon batch shape plus a
            # trailing unit axis so `tf.searchsorted` works row by row.
            batch_zeros = tf.zeros(tf.shape(start_ords)[:-1], dtype=tf.int32)
            val_ords = val_ords + tf.expand_dims(batch_zeros, axis=-1)

            # Index of the last coupon period starting strictly before the
            # valuation date, clamped below at the first period.
            last_started = tf.searchsorted(start_ords, val_ords) - 1
            period_idx = tf.maximum(last_started, 0)

            # Fixings are assumed to coincide with coupon start dates.
            # TODO(b/177047910): add fixing settlement dates.
            n_batch_dims = len(start_ords.shape) - 1
            # Both have shape `batch_shape + [1]`.
            fixing_start_ords = tf.gather(start_ords, period_idx,
                                          batch_dims=n_batch_dims)
            fixing_end_ords = tf.gather(end_ords, period_idx,
                                        batch_dims=n_batch_dims)
            # Past fixing for the in-progress period; shape batch_shape + [1].
            past_fixing = _get_fixings(
                dateslib.dates_from_ordinals(fixing_start_ords),
                dateslib.dates_from_ordinals(fixing_end_ords),
                self._reference_curve_type,
                self._reference_mask,
                market)

            curve_rates = curve.forward_rate(
                self._accrual_start_date,
                self._accrual_end_date,
                day_count_fraction=self._daycount_fractions)
            # Periods with a non-positive daycount fraction get a zero rate.
            # Shape batch_shape + [num_cashflows].
            curve_rates = tf.where(self._daycount_fractions > 0.,
                                   curve_rates,
                                   tf.zeros_like(curve_rates))

            # Ended before valuation -> 0; started on/after valuation -> curve;
            # otherwise the rate is already fixed -> past fixing.
            already_ended = self._coupon_end_dates < val_date
            not_yet_started = self._coupon_start_dates >= val_date
            rates = tf.where(
                already_ended,
                tf.constant(0, dtype=self._dtype),
                tf.where(not_yet_started, curve_rates, past_fixing))
            return self._coupon_end_dates, rates
def bond_option_price(
        *,
        strikes,
        expiries,
        maturities,
        discount_rate_fn,
        dim,
        mean_reversion,
        volatility,
        # TODO(b/159040541) Add correlation as an input.
        is_call_options=True,
        use_analytic_pricing=True,
        num_samples=1,
        random_type=None,
        seed=None,
        skip=0,
        time_step=None,
        dtype=None,
        name=None):
    """Prices European options on zero-coupon bonds with the Hull-White model.

    A bond option gives its holder the right to exchange, at the option expiry,
    a zero-coupon bond for a fixed price (the strike of the option). The
    underlying bond matures after the option expires. Writing `P(t, T)` for the
    time-`t` price of a zero-coupon bond maturing at `T`, a call option that
    expires at `T0` pays

    ```None
    payoff = max(P(T0, T) - X, 0)
    ```
    where `X` is the strike price of the option; a put pays
    `max(X - P(T0, T), 0)`.

    #### Example

    ````python
    import numpy as np
    import tensorflow.compat.v2 as tf
    import tf_quant_finance as tff

    dtype = tf.float64

    discount_rate_fn = lambda x: 0.01 * tf.ones_like(x, dtype=dtype)
    expiries = np.array([1.0])
    maturities = np.array([5.0])
    strikes = np.exp(-0.01 * maturities) / np.exp(-0.01 * expiries)
    price = tff.models.hull_white.bond_option_price(
        strikes=strikes,
        expiries=expiries,
        maturities=maturities,
        dim=1,
        mean_reversion=[0.03],
        volatility=[0.02],
        discount_rate_fn=discount_rate_fn,
        use_analytic_pricing=True,
        dtype=dtype)
    # Expected value: [[0.02817777]]
    ````

    Args:
      strikes: A real `Tensor` of any shape and dtype. The strike price of the
        options. Its shape determines the number (and shape) of the options to
        be priced and the shape of the output.
      expiries: A real `Tensor` of the same dtype and compatible shape as
        `strikes`. The time to expiry of each bond option.
      maturities: A real `Tensor` of the same dtype and compatible shape as
        `strikes`. The time to maturity of the underlying zero-coupon bonds.
      discount_rate_fn: A Python callable that accepts expiry time as a real
        `Tensor` and returns a `Tensor` of shape `input_shape + dim`. Computes
        the zero-coupon bond yield at the present time for the input expiry
        time.
      dim: A Python scalar giving the number of Hull-White models used for
        pricing.
      mean_reversion: A real positive `Tensor` of shape `[dim]` or a Python
        callable. The callable can be one of the following:
        (a) A left-continuous piecewise constant object (e.g.,
        `tff.math.piecewise.PiecewiseConstantFunc`) that has a property
        `is_piecewise_constant` set to `True`. In this case the object should
        have a method `jump_locations(self)` that returns a `Tensor` of shape
        `[dim, num_jumps]` or `[num_jumps]`. In the first case,
        `mean_reversion(t)` should return a `Tensor` of shape
        `[dim] + t.shape`, and in the second, `t.shape + [dim]`, where `t` is
        a rank 1 `Tensor` of the same `dtype` as the output.
        (b) A callable that accepts scalars (stands for time `t`) and returns
        a `Tensor` of shape `[dim]`.
        Corresponds to the mean reversion rate.
      volatility: A real positive `Tensor` of the same `dtype` as
        `mean_reversion` or a callable with the same specs as above.
        Corresponds to the long run price variance.
      is_call_options: A boolean `Tensor` of a shape compatible with
        `strikes`. `True` marks a call option and `False` a put option.
        Default value: `True` (all options are calls).
      use_analytic_pricing: A Python boolean. If `True`, analytic valuation is
        performed; this is only supported for constant `mean_reversion` and
        piecewise constant `volatility`. If `False`, Monte-Carlo simulation is
        used instead.
      num_samples: Positive scalar `int32` `Tensor`. Number of simulation
        paths for Monte-Carlo valuation; ignored for analytic valuation.
        Default value: 1.
      random_type: Enum value of `RandomType`. The type of (quasi)-random
        number generator used to generate simulation paths. Only relevant for
        Monte-Carlo valuation.
        Default value: `None` which maps to the standard pseudo-random numbers.
      seed: Seed for the random number generator. Only relevant if
        `random_type` is one of
        `[STATELESS, PSEUDO, HALTON_RANDOMIZED, PSEUDO_ANTITHETIC,
          STATELESS_ANTITHETIC]`. For `PSEUDO`, `PSEUDO_ANTITHETIC` and
        `HALTON_RANDOMIZED` the seed should be a Python integer. For
        `STATELESS` and `STATELESS_ANTITHETIC` it must be an integer `Tensor`
        of shape `[2]`. Only relevant for Monte-Carlo valuation.
        Default value: `None` which means no seed is set.
      skip: `int32` 0-d `Tensor`. The number of initial points of the Sobol or
        Halton sequence to skip. Used only when `random_type` is 'SOBOL',
        'HALTON', or 'HALTON_RANDOMIZED', otherwise ignored.
        Default value: `0`.
      time_step: Scalar real `Tensor`. Maximal distance between time grid
        points of the simulation grid. Required for Monte-Carlo valuation and
        ignored for analytic valuation.
        Default value: `None`.
      dtype: The default dtype to use when converting values to `Tensor`s.
        Default value: `None` which means that default dtypes inferred by
        TensorFlow are used.
      name: Python string. The name to give to the ops created by this
        function.
        Default value: `None` which maps to `hw_bond_option_price`.

    Returns:
      A `Tensor` of real dtype and shape `strikes.shape + [dim]` containing
      the computed option prices.

    Raises:
      ValueError: If `use_analytic_pricing` is `False` and `time_step` is not
        provided.
    """
    if dtype is None:
        dtype = tf.convert_to_tensor([0.0]).dtype
    with tf.name_scope(name or 'hw_bond_option_price'):
        strikes = tf.convert_to_tensor(strikes, dtype=dtype, name='strikes')
        expiries = tf.convert_to_tensor(expiries, dtype=dtype, name='expiries')
        maturities = tf.convert_to_tensor(
            maturities, dtype=dtype, name='maturities')
        is_call_options = tf.convert_to_tensor(
            is_call_options, dtype=tf.bool, name='is_call_options')
        model = vector_hull_white.VectorHullWhiteModel(
            dim,
            mean_reversion=mean_reversion,
            volatility=volatility,
            initial_discount_rate_fn=discount_rate_fn,
            dtype=dtype)

        if use_analytic_pricing:
            return _analytic_valuation(discount_rate_fn, model, strikes,
                                       expiries, maturities, dim,
                                       is_call_options)

        if time_step is None:
            raise ValueError(
                '`time_step` must be provided for simulation based bond '
                'option valuation.')

        # Simulation grid: every distinct expiry plus a regular grid with
        # spacing `time_step` up to the longest expiry.
        expiry_grid, _ = tf.unique(tf.reshape(expiries, shape=[-1]))
        regular_grid = tf.range(time_step, tf.reduce_max(expiry_grid),
                                time_step)
        sim_grid, _ = tf.unique(
            tf.concat([expiry_grid, regular_grid], axis=0))
        sim_grid = tf.sort(sim_grid, name='sort_sim_times')

        # Remaining life of the underlying bond at each option expiry.
        tau = maturities - expiries
        tau_grid, _ = tf.unique(tf.reshape(tau, shape=[-1]))
        tau_grid = tf.sort(tau_grid, name='sort_curve_times')

        bond_paths, short_rate_paths = model.sample_discount_curve_paths(
            times=sim_grid,
            curve_times=tau_grid,
            num_samples=num_samples,
            random_type=random_type,
            seed=seed,
            skip=skip)

        # Step sizes between consecutive simulation times (first step is 0),
        # laid out as `[1, num_times, 1]` to broadcast against the rates.
        steps = tf.concat(
            [tf.convert_to_tensor([0.0], dtype=dtype),
             sim_grid[1:] - sim_grid[:-1]],
            axis=0)
        steps = tf.reshape(steps, [1, -1, 1])

        # Cumulative product of per-step discount factors along the time axis;
        # shape `(num_samples, len(times), dim)`. The matvec helper works on
        # the last axis, so time is moved there and back.
        step_discounts = tf.math.exp(-short_rate_paths * steps)
        path_discounts = tf.transpose(
            _cumprod_using_matvec(tf.transpose(step_discounts, [0, 2, 1])),
            [0, 2, 1])

        # Add a `curve_times` axis and tile along it so the discount factors
        # have the same layout as `bond_paths` and share gather indices.
        num_curve_times = bond_paths.shape.as_list()[1]
        path_discounts = tf.repeat(
            tf.expand_dims(path_discounts, axis=1), num_curve_times, axis=1)

        # Positions of each option's expiry and remaining life on the sorted
        # simulation grids.
        expiry_pos = tf.searchsorted(sim_grid, tf.reshape(expiries, [-1]))
        tau_pos = tf.searchsorted(tau_grid, tf.reshape(tau, [-1]))
        gather_index = _prepare_indices(
            tf.range(0, num_samples), tau_pos, expiry_pos, tf.range(0, dim))

        # Each gather yields `num_samples * num_strikes * dim` values which are
        # laid out as `[num_samples] + strikes.shape + [dim]`.
        sampled_shape = [num_samples] + strikes.shape + [dim]
        discount_at_expiry = tf.reshape(
            tf.gather_nd(path_discounts, gather_index), sampled_shape)
        bond_at_expiry = tf.reshape(
            tf.gather_nd(bond_paths, gather_index), sampled_shape)

        # Bring option flags and strikes to the same rank as the samples.
        payoff_shape = [1] + strikes.shape + [1]
        is_call = tf.reshape(
            tf.broadcast_to(is_call_options, strikes.shape), payoff_shape)
        strike_grid = tf.reshape(strikes, payoff_shape)

        call_payoff = tf.math.maximum(bond_at_expiry - strike_grid, 0.0)
        put_payoff = tf.math.maximum(strike_grid - bond_at_expiry, 0.0)
        payoff = tf.where(is_call, call_payoff, put_payoff)
        return tf.math.reduce_mean(discount_at_expiry * payoff, axis=0)
def bermudan_swaption_price(*,
                            exercise_times,
                            floating_leg_start_times,
                            floating_leg_end_times,
                            fixed_leg_payment_times,
                            floating_leg_daycount_fractions,
                            fixed_leg_daycount_fractions,
                            fixed_leg_coupon,
                            reference_rate_fn,
                            dim,
                            mean_reversion,
                            volatility,
                            notional=None,
                            is_payer_swaption=None,
                            lsm_basis=None,
                            num_samples=100,
                            random_type=None,
                            seed=None,
                            skip=0,
                            time_step=None,
                            dtype=None,
                            name=None):
  """Calculates the price of Bermudan Swaptions using the Hull-White model.

  A Bermudan Swaption is a contract that gives the holder an option to enter a
  swap contract on a set of future exercise dates. The exercise dates are
  typically the fixing dates (or a subset thereof) of the underlying swap. If
  `T_N` denotes the final payoff date and `T_i, i = {1,...,n}` denote the set
  of exercise dates, then if the option is exercised at `T_i`, the holder is
  left with a swap with first fixing date equal to `T_i` and maturity `T_N`.

  Simulation based pricing of Bermudan swaptions is performed using the least
  squares Monte-carlo approach [1].

  #### References:
    [1]: D. Brigo, F. Mercurio. Interest Rate Models-Theory and Practice.
    Second Edition. 2007.

  #### Example
  The example shows how to value a batch of 5-no-call-1 and 5-no-call-2
  swaptions using the Hull-White model.

  ````python
  import numpy as np
  import tensorflow.compat.v2 as tf
  import tf_quant_finance as tff

  dtype = tf.float64

  exercise_swaption_1 = [1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5]
  exercise_swaption_2 = [2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.0]
  exercise_times = [exercise_swaption_1, exercise_swaption_2]

  float_leg_start_times_1y = [1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5]
  float_leg_start_times_18m = [1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0]
  float_leg_start_times_2y = [2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.0]
  float_leg_start_times_30m = [2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.0, 5.0]
  float_leg_start_times_3y = [3.0, 3.5, 4.0, 4.5, 5.0, 5.0, 5.0, 5.0]
  float_leg_start_times_42m = [3.5, 4.0, 4.5, 5.0, 5.0, 5.0, 5.0, 5.0]
  float_leg_start_times_4y = [4.0, 4.5, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0]
  float_leg_start_times_54m = [4.5, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0]
  float_leg_start_times_5y = [5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0]

  float_leg_start_times_swaption_1 = [float_leg_start_times_1y,
                                      float_leg_start_times_18m,
                                      float_leg_start_times_2y,
                                      float_leg_start_times_30m,
                                      float_leg_start_times_3y,
                                      float_leg_start_times_42m,
                                      float_leg_start_times_4y,
                                      float_leg_start_times_54m]

  float_leg_start_times_swaption_2 = [float_leg_start_times_2y,
                                      float_leg_start_times_30m,
                                      float_leg_start_times_3y,
                                      float_leg_start_times_42m,
                                      float_leg_start_times_4y,
                                      float_leg_start_times_54m,
                                      float_leg_start_times_5y,
                                      float_leg_start_times_5y]
  float_leg_start_times = [float_leg_start_times_swaption_1,
                         float_leg_start_times_swaption_2]

  float_leg_end_times = np.clip(np.array(float_leg_start_times) + 0.5, 0.0, 5.0)

  fixed_leg_payment_times = float_leg_end_times
  float_leg_daycount_fractions = (np.array(float_leg_end_times) -
                                  np.array(float_leg_start_times))
  fixed_leg_daycount_fractions = float_leg_daycount_fractions
  fixed_leg_coupon = 0.011 * np.ones_like(fixed_leg_payment_times)
  zero_rate_fn = lambda x: 0.01 * tf.ones_like(x, dtype=dtype)
  price = bermudan_swaption_price(
      exercise_times=exercise_times,
      floating_leg_start_times=float_leg_start_times,
      floating_leg_end_times=float_leg_end_times,
      fixed_leg_payment_times=fixed_leg_payment_times,
      floating_leg_daycount_fractions=float_leg_daycount_fractions,
      fixed_leg_daycount_fractions=fixed_leg_daycount_fractions,
      fixed_leg_coupon=fixed_leg_coupon,
      reference_rate_fn=zero_rate_fn,
      notional=100.,
      dim=1,
      mean_reversion=[0.03],
      volatility=[0.01],
      num_samples=1000000,
      time_step=0.1,
      random_type=tff.math.random.RandomType.PSEUDO_ANTITHETIC,
      seed=0,
      dtype=dtype)
  # Expected value: [1.8913050118443016, 1.6618681421434984] # shape = (2,)
  ````

  Args:
    exercise_times: A real `Tensor` of shape `batch_shape + [num_exercise]`.
      The times corresponding to exercise dates of the swaptions.
      `num_exercise` corresponds to the number of exercise dates for the
      Bermudan swaption. The shape of this input determines the number (and
      shape) of Bermudan swaptions to be priced and the shape of the output.
      Exercise times may be given in any order across the batch.
    floating_leg_start_times: A real `Tensor` of the same dtype as
      `exercise_times`. The times when accrual begins for each payment in the
      floating leg upon exercise of the option. The shape of this input should
      be `exercise_times.shape + [m]` where `m` denotes the number of floating
      payments in each leg of the underlying swap until the swap maturity.
      Currently unused by the implementation.
    floating_leg_end_times: A real `Tensor` of the same dtype as
      `exercise_times`. The times when accrual ends for each payment in the
      floating leg upon exercise of the option. The shape of this input should
      be `exercise_times.shape + [m]` where `m` denotes the number of floating
      payments in each leg of the underlying swap until the swap maturity.
      Currently unused by the implementation.
    fixed_leg_payment_times: A real `Tensor` of the same dtype as
      `exercise_times`. The payment times for each payment in the fixed leg.
      The shape of this input should be `exercise_times.shape + [n]` where `n`
      denotes the number of fixed payments in each leg of the underlying swap
      until the swap maturity.
    floating_leg_daycount_fractions: A real `Tensor` of the same dtype and
      compatible shape as `floating_leg_start_times`. The daycount fractions
      for each payment in the floating leg. Currently unused by the
      implementation.
    fixed_leg_daycount_fractions: A real `Tensor` of the same dtype and
      compatible shape as `fixed_leg_payment_times`. The daycount fractions
      for each payment in the fixed leg.
    fixed_leg_coupon: A real `Tensor` of the same dtype and compatible shape
      as `fixed_leg_payment_times`. The fixed rate for each payment in the
      fixed leg.
    reference_rate_fn: A Python callable that accepts expiry time as a real
      `Tensor` and returns a `Tensor` of shape `input_shape + [dim]`. Returns
      the continuously compounded zero rate at the present time for the input
      expiry time.
    dim: A Python scalar which corresponds to the number of Hull-White Models
      to be used for pricing. Only `dim == 1` is currently supported.
    mean_reversion: A real positive `Tensor` of shape `[dim]` or a Python
      callable. The callable can be one of the following:
      (a) A left-continuous piecewise constant object (e.g.,
      `tff.math.piecewise.PiecewiseConstantFunc`) that has a property
      `is_piecewise_constant` set to `True`. In this case the object should
      have a method `jump_locations(self)` that returns a `Tensor` of shape
      `[dim, num_jumps]` or `[num_jumps]`. In the first case,
      `mean_reversion(t)` should return a `Tensor` of shape `[dim] + t.shape`,
      and in the second, `t.shape + [dim]`, where `t` is a rank 1 `Tensor` of
      the same `dtype` as the output.
      (b) A callable that accepts scalars (stands for time `t`) and returns a
      `Tensor` of shape `[dim]`.
      Corresponds to the mean reversion rate.
    volatility: A real positive `Tensor` of the same `dtype` as
      `mean_reversion` or a callable with the same specs as above.
      Corresponds to the long run price variance.
    notional: An optional `Tensor` of the same dtype as `exercise_times`,
      broadcastable against the output prices, specifying the notional amount
      for the underlying swap.
      Default value: None in which case the notional is set to 1.
    is_payer_swaption: A boolean `Tensor` of a shape compatible with
      `exercise_times`. Indicates whether the swaption is a payer (if True) or
      a receiver (if False) swaption.
      Default value: None in which case payer swaptions are assumed.
    lsm_basis: A Python callable specifying the basis to be used in the LSM
      algorithm. The callable must accept a `Tensor` of shape
      `[num_samples, dim]` and output a `Tensor` of shape `[m, num_samples]`
      where `m` is the number of basis functions used.
      Default value: `None`, in which case a polynomial basis of order 2 is
      used.
    num_samples: Positive scalar `int32` `Tensor`. The number of simulation
      paths used in the Monte-Carlo valuation.
      Default value: The default value is 100.
    random_type: Enum value of `RandomType`. The type of (quasi)-random
      number generator to use to generate the simulation paths.
      Default value: `None` which maps to the standard pseudo-random numbers.
    seed: Seed for the random number generator. The seed is only relevant if
      `random_type` is one of
      `[STATELESS, PSEUDO, HALTON_RANDOMIZED, PSEUDO_ANTITHETIC,
        STATELESS_ANTITHETIC]`. For `PSEUDO`, `PSEUDO_ANTITHETIC` and
      `HALTON_RANDOMIZED` the seed should be a Python integer. For
      `STATELESS` and `STATELESS_ANTITHETIC` it must be supplied as an integer
      `Tensor` of shape `[2]`.
      Default value: `None` which means no seed is set.
    skip: `int32` 0-d `Tensor`. The number of initial points of the Sobol or
      Halton sequence to skip. Used only when `random_type` is 'SOBOL',
      'HALTON', or 'HALTON_RANDOMIZED', otherwise ignored.
      Default value: `0`.
    time_step: Scalar real `Tensor`. Maximal distance between time grid points
      of the simulation grid. Must be provided.
      Default value: `None`.
    dtype: The default dtype to use when converting values to `Tensor`s.
      Default value: `None` which means that default dtypes inferred by
      TensorFlow are used.
    name: Python string. The name to give to the ops created by this function.
      Default value: `None` which maps to the default name
      `hw_bermudan_swaption_price`.

  Returns:
    A `Tensor` of real dtype containing the computed swaption prices, one per
    swaption in `batch_shape = exercise_times.shape[:-1]` (see the example
    above, where the output has shape `(2,)`).

  Raises:
    (a) `ValueError` if exercise_times.rank is less than
    fixed_leg_payment_times.rank - 1, which would mean exercise times are not
    specified for all swaptions.
    (b) `ValueError` if `time_step` is not specified for Monte-Carlo
    simulations.
    (c) `ValueError` if `dim` > 1.
  """
  if dim > 1:
    raise ValueError('dim > 1 is currently not supported.')

  name = name or 'hw_bermudan_swaption_price'
  # The floating leg PV is computed below as `1 - P(t, T_N)` from simulated
  # bond prices, so the floating leg schedule inputs are not needed.
  del floating_leg_daycount_fractions, floating_leg_start_times
  del floating_leg_end_times
  with tf.name_scope(name):
    exercise_times = tf.convert_to_tensor(
        exercise_times, dtype=dtype, name='exercise_times')
    dtype = dtype or exercise_times.dtype
    fixed_leg_payment_times = tf.convert_to_tensor(
        fixed_leg_payment_times, dtype=dtype, name='fixed_leg_payment_times')
    fixed_leg_daycount_fractions = tf.convert_to_tensor(
        fixed_leg_daycount_fractions, dtype=dtype,
        name='fixed_leg_daycount_fractions')
    fixed_leg_coupon = tf.convert_to_tensor(
        fixed_leg_coupon, dtype=dtype, name='fixed_leg_coupon')
    # Honor the documented default of a unit notional;
    # `tf.convert_to_tensor(None)` would raise.
    if notional is None:
      notional = 1.0
    notional = tf.convert_to_tensor(notional, dtype=dtype, name='notional')
    if is_payer_swaption is None:
      is_payer_swaption = True
    is_payer_swaption = tf.convert_to_tensor(
        is_payer_swaption, dtype=tf.bool, name='is_payer_swaption')

    if lsm_basis is None:
      basis_fn = lsm_v2.make_polynomial_basis(2)
    else:
      basis_fn = lsm_basis

    batch_shape = exercise_times.shape.as_list()[:-1] or [1]
    unique_exercise_times, exercise_time_index = tf.unique(
        tf.reshape(exercise_times, shape=[-1]))
    # `tf.unique` keeps first-occurrence order, which is not necessarily
    # increasing (e.g. when a later-starting swaption appears first in the
    # batch). The code below needs increasing order: the last element is taken
    # as the longest exercise time, and the LSM call treats position `k` along
    # the time axis as the `k`-th exercise date. Sort the unique times and
    # remap the per-swaption indices accordingly (a no-op if already sorted).
    sort_order = tf.argsort(unique_exercise_times)
    unique_exercise_times = tf.gather(unique_exercise_times, sort_order)
    exercise_time_index = tf.gather(
        tf.math.invert_permutation(sort_order), exercise_time_index)
    exercise_time_index = tf.reshape(
        exercise_time_index, shape=exercise_times.shape)

    # Add a dimension corresponding to multiple cashflows in a swap
    if exercise_times.shape.rank == fixed_leg_payment_times.shape.rank - 1:
      exercise_times = tf.expand_dims(exercise_times, axis=-1)
    elif exercise_times.shape.rank < fixed_leg_payment_times.shape.rank - 1:
      raise ValueError('Swaption exercise times not specified for all '
                       'swaptions in the batch. Expected rank '
                       '{} but received {}.'.format(
                           fixed_leg_payment_times.shape.rank - 1,
                           exercise_times.shape.rank))

    exercise_times = tf.repeat(
        exercise_times, fixed_leg_payment_times.shape.as_list()[-1], axis=-1)

    # Monte-Carlo pricing
    model = vector_hull_white.VectorHullWhiteModel(
        dim,
        mean_reversion,
        volatility,
        initial_discount_rate_fn=reference_rate_fn,
        dtype=dtype)

    if time_step is None:
      raise ValueError('`time_step` must be provided for LSM valuation.')

    sim_times = unique_exercise_times
    # `unique_exercise_times` is sorted above, so its last entry is the max.
    longest_exercise_time = sim_times[-1]
    sim_times, _ = tf.unique(tf.concat([sim_times, tf.range(
        time_step, longest_exercise_time, time_step)], axis=0))
    sim_times = tf.sort(sim_times, name='sort_sim_times')

    maturities = fixed_leg_payment_times
    maturities_shape = maturities.shape
    tau = maturities - exercise_times

    curve_times_builder, _ = tf.unique(tf.reshape(tau, shape=[-1]))
    curve_times = tf.sort(curve_times_builder, name='sort_curve_times')

    # Simulate short rates and discount factors.
    p_t_tau, r_t = model.sample_discount_curve_paths(
        times=sim_times,
        curve_times=curve_times,
        num_samples=num_samples,
        random_type=random_type,
        seed=seed,
        skip=skip)

    dt = tf.concat(
        [tf.convert_to_tensor([0.0], dtype=dtype),
         sim_times[1:] - sim_times[:-1]], axis=0)
    dt = tf.expand_dims(tf.expand_dims(dt, axis=-1), axis=0)
    discount_factors_builder = tf.math.exp(-r_t * dt)
    # Transpose before (and after) because we want the cumprod along axis=1
    # and `matvec` operates on the last axis.
    discount_factors_builder = tf.transpose(
        utils.cumprod_using_matvec(
            tf.transpose(discount_factors_builder, [0, 2, 1])), [0, 2, 1])

    # Make discount factors the same shape as `p_t_tau`. This involves adding
    # an extra dimension (corresponding to `curve_times`).
    discount_factors_builder = tf.expand_dims(
        discount_factors_builder,
        axis=1)
    # tf.repeat is needed because we will use gather_nd later on this tensor.
    discount_factors_simulated = tf.repeat(
        discount_factors_builder, p_t_tau.shape.as_list()[1], axis=1)

    # `sim_times` and `curve_times` are sorted for simulation. We need to
    # select the indices corresponding to our input.
    sim_time_index = tf.searchsorted(
        sim_times, tf.reshape(exercise_times, [-1]))
    curve_time_index = tf.searchsorted(curve_times, tf.reshape(tau, [-1]))

    gather_index = _prepare_indices(
        tf.range(0, num_samples), curve_time_index, sim_time_index,
        tf.range(0, dim))

    # TODO(b/167421126): Replace `tf.gather_nd` with `tf.gather`.
    payoff_bond_price_builder = tf.gather_nd(p_t_tau, gather_index)
    payoff_bond_price = tf.reshape(
        payoff_bond_price_builder, [num_samples] + maturities_shape + [dim])

    # Add an axis corresponding to `dim`
    fixed_leg_pv = tf.expand_dims(
        fixed_leg_coupon * fixed_leg_daycount_fractions,
        axis=-1) * payoff_bond_price
    # Sum fixed coupon payments within each swap to calculate the swap payoff
    # at each exercise time.
    fixed_leg_pv = tf.math.reduce_sum(fixed_leg_pv, axis=-2)
    float_leg_pv = 1.0 - payoff_bond_price[..., -1, :]
    payoff_swap = float_leg_pv - fixed_leg_pv
    payoff_swap = tf.where(is_payer_swaption, payoff_swap, -1.0 * payoff_swap)

    # Get the short rate simulations for the set of unique exercise times
    sim_time_index = tf.searchsorted(sim_times, unique_exercise_times)
    short_rate = tf.gather(r_t, sim_time_index, axis=1)

    # Currently the payoffs are computed on exercise times of each option.
    # They need to be mapped to the short rate simulation times, which is a
    # union of all exercise times.
    is_exercise_time, payoff_swap = _map_payoff_to_sim_times(
        exercise_time_index, payoff_swap, num_samples)

    # Transpose so that `time_index` is the leading dimension
    # (for XLA compatibility)
    perm = [is_exercise_time.shape.rank - 1] + list(
        range(is_exercise_time.shape.rank - 1))
    is_exercise_time = tf.transpose(is_exercise_time, perm=perm)
    payoff_swap = tf.transpose(payoff_swap, perm=perm)

    # Time to call LSM
    def _payoff_fn(rt, time_index):
      del rt
      result = tf.where(is_exercise_time[time_index] > 0,
                        tf.nn.relu(payoff_swap[time_index]), 0.0)
      return tf.reshape(result, shape=[num_samples] + batch_shape)

    discount_factors_simulated = tf.gather(
        discount_factors_simulated, sim_time_index, axis=2)

    option_value = lsm_v2.least_square_mc(
        short_rate, tf.range(0, tf.shape(short_rate)[1]),
        _payoff_fn,
        basis_fn,
        discount_factors=discount_factors_simulated[:, -1:, :, 0],
        dtype=dtype)

    return notional * option_value
Exemple #30
0
def bs_lsm_price(spots: types.FloatTensor,
                 expiry_times: types.FloatTensor,
                 strikes: types.FloatTensor,
                 volatility: types.FloatTensor,
                 discount_factors: types.FloatTensor,
                 num_samples: int = 100000,
                 num_exercise_times: int = 100,
                 basis_fn=None,
                 seed: Tuple[int, int] = (1, 2),
                 is_call_option: types.BoolTensor = True,
                 num_calibration_samples: int = None,
                 dtype: types.Dtype = None,
                 name: str = None):
    """Computes American option price via LSM under Black-Scholes model.

    All options are simulated with a single driftless, unit-volatility
    geometric Brownian motion in normalized ("variance") time
    `t * volatility**2`. The per-option spot and risk-free growth are
    re-applied inside the payoff, and per-option discounting is passed to the
    Longstaff-Schwartz routine.

    Args:
      spots: A rank 1 real `Tensor` with spot prices.
      expiry_times: A `Tensor` of the same shape and dtype as `spots`
        representing expiry values of the options.
      strikes: A `Tensor` of the same shape and dtype as `spots` representing
        strike price of the options.
      volatility: A `Tensor` of the same shape and dtype as `spots`
        representing volatility values of the options.
      discount_factors: A `Tensor` of the same shape and dtype as `spots`
        representing discount factors at the expiry times.
      num_samples: Number of Monte Carlo samples.
      num_exercise_times: Number of equally spaced exercise times for American
        options. The expiry times of all options are added to this grid.
      basis_fn: Callable from a `Tensor` of shape
        `[num_samples, num_exercise_times, 1]` (corresponding to Monte Carlo
        samples) and a positive integer `Tensor` (representing a current
        time index) to a `Tensor` of shape `[basis_size, num_samples]` of the
        same dtype as `spots`. The result is the design matrix used in the
        regression of the continuation value of options.
        This is the same argument as in `lsm_algorithm.least_square_mc`.
        Default value: `None`, which maps to a polynomial basis
        `lsm_algorithm.make_polynomial_basis_v2(2)`.
      seed: A tuple of 2 integers setting global and local seed of the Monte
        Carlo sampler.
      is_call_option: A bool `Tensor`. `True` for call options, `False` for
        put options.
      num_calibration_samples: An optional integer less or equal to
        `num_samples`. The number of sampled trajectories used for the LSM
        regression step.
        Default value: `None`, which means that all samples are used for
          regression.
      dtype: `tf.Dtype` of the input and output real `Tensor`s.
        Default value: `None` which maps to `float64`.
      name: Python str. The name to give to the ops created by this function.
        Default value: `None` which maps to 'bs_lsm_price'.

    Returns:
      A `Tensor` of the same shape and dtype as `spots` representing American
      option prices.
    """
    dtype = dtype or tf.float64
    name = name or "bs_lsm_price"
    with tf.name_scope(name):
        strikes = tf.convert_to_tensor(strikes, dtype=dtype, name="strikes")
        spots = tf.convert_to_tensor(spots, dtype=dtype, name="spots")
        volatility = tf.convert_to_tensor(volatility,
                                          dtype=dtype,
                                          name="volatility")
        expiry_times = tf.convert_to_tensor(expiry_times,
                                            dtype=dtype,
                                            name="expiry_times")
        discount_factors = tf.convert_to_tensor(discount_factors,
                                                dtype=dtype,
                                                name="discount_factors")
        is_call_option = tf.convert_to_tensor(is_call_option,
                                              dtype=tf.bool,
                                              name="is_call_option")
        # Continuously compounded rate implied by the discount factors; must be
        # computed from the *un-normalized* expiry times.
        risk_free_rate = -tf.math.log(discount_factors) / expiry_times
        # Normalize expiry times to variance time t * sigma**2 so that one
        # sigma=1 GBM serves every option.
        var = volatility**2
        expiry_times = expiry_times * var
        # Rate per unit of normalized time: rate_per_var * (t * var) == r * t.
        rate_per_var = risk_free_rate / var

        gbm = models.GeometricBrownianMotion(mu=0.0, sigma=1.0, dtype=dtype)
        max_time = tf.reduce_max(expiry_times)

        # Sorted grid of `num_exercise_times` equally spaced exercise times
        # plus all (normalized) expiry times.
        times = tf.sort(
            tf.concat([
                tf.linspace(tf.constant(0.0, dtype), max_time,
                            num_exercise_times), expiry_times
            ],
                      axis=0))
        # Samples for all options
        samples = gbm.sample_paths(
            times,
            initial_state=1.0,
            num_samples=num_samples,
            seed=seed,
            random_type=math.random.RandomType.STATELESS_ANTITHETIC)
        # Grid index of each option's expiry time.
        indices = tf.searchsorted(times, expiry_times)
        indices_ext = tf.expand_dims(indices, axis=-1)

        # Loop-invariant per-option quantities, given a trailing axis so they
        # broadcast against the `num_samples` axis inside the payoff.
        spots_ext = tf.expand_dims(spots, axis=-1)
        strikes_ext = tf.expand_dims(strikes, axis=-1)
        rate_per_var_ext = tf.expand_dims(rate_per_var, axis=-1)
        call_put = tf.expand_dims(is_call_option, axis=-1)

        # Payoff function takes all the samples of shape
        # [num_paths, num_times, dim] and returns a `Tensor` of
        # shape [num_paths, num_strikes]. This corresponds to a
        # payoff at the present time.
        def _payoff_fn(sample_paths, time_index):
            current_samples = tf.transpose(sample_paths, [1, 2, 0])[time_index]
            # exp(r * t): restores the risk-free drift removed from the
            # driftless GBM (times are in variance time, hence rate / var).
            growth = tf.math.exp(rate_per_var_ext * times[time_index])
            payoff = strikes_ext - growth * spots_ext * current_samples
            payoff = tf.where(call_put, tf.nn.relu(-payoff),
                              tf.nn.relu(payoff))
            # Since the pricing is happening on the grid, the payoff of options
            # that have already expired is set to `0` to indicate that they
            # must not be exercised after expiry.
            res = tf.where(time_index > indices_ext,
                           tf.constant(0, dtype=dtype), payoff)
            return tf.transpose(res)

        if basis_fn is None:
            # Polynomial basis with 2 functions
            basis_fn = lsm_algorithm.make_polynomial_basis_v2(2)

        # Every grid point is an exercise opportunity.
        exercise_times = tf.range(tf.shape(times)[0])
        # Longstaff-Schwartz algorithm. Discount factors exp(-r * t) are laid
        # out as [1, num_options, num_times].
        return lsm_algorithm.least_square_mc_v2(
            sample_paths=samples,
            exercise_times=exercise_times,
            payoff_fn=_payoff_fn,
            basis_fn=basis_fn,
            discount_factors=tf.math.exp(
                -tf.reshape(rate_per_var, [1, -1, 1]) * times),
            num_calibration_samples=num_calibration_samples)