def get_discount_curve(
    discount_curve_types: List[Union[curve_types_lib.RiskFreeCurve,
                                     curve_types_lib.RateIndexCurve]],
    market: pmd.ProcessedMarketData,
    mask: List[int]) -> rate_curve.RateCurve:
  """Builds a batched discount curve.

  Given a list of discount curves and an integer mask, creates a discount curve
  object to compute discount factors against the list of discount curves.

  #### Example
  ```none
  curve_types = [RiskFreeCurve("USD"), RiskFreeCurve("AUD")]
  # A mask to price a batch of 7 instruments with the corresponding discount
  # curves ["USD", "AUD", "AUD", "AUD", "USD", "USD", "AUD"].
  mask = [0, 1, 1, 1, 0, 0, 1]
  market = MarketDataDict(...)
  get_discount_curve(curve_types, market, mask)
  # Returns a RateCurve object that can compute discount factors for a
  # batch of 7 dates.
  ```

  Args:
    discount_curve_types: A list of curve types.
    market: an instance of the processed market data.
    mask: An integer mask. `mask[i]` is the position in `discount_curve_types`
      of the curve used for the i-th element of the batch.

  Returns:
    An instance of `RateCurve`.
  """
  discount_curves = [
      market.yield_curve(curve_type) for curve_type in discount_curve_types
  ]
  discounts = []
  dates = []
  interpolation_method = None
  interpolate_rates = None
  for curve in discount_curves:
    discount, date = curve.discount_factors_and_dates()
    discounts.append(discount)
    dates.append(date)
    # Only the settings of the last curve are kept (see the TODO below).
    interpolation_method = curve.interpolation_method
    interpolate_rates = curve.interpolate_rates

  # Curves may have different numbers of nodes. Pad BOTH the discount factors
  # and the dates to a common length so that the stacked tensors stay aligned.
  # Previously the padded dates were discarded and the unpadded list stacked.
  all_discounts = tf.stack(pad.pad_tensors(discounts), axis=0)
  all_dates = dateslib.DateTensor.stack(pad.pad_date_tensors(dates), axis=0)

  # Select, for every batch element, the curve chosen by `mask`.
  prepare_discounts = tf.gather(all_discounts, mask)
  prepare_dates = dateslib.dates_from_ordinals(
      tf.gather(all_dates.ordinal(), mask))
  # All curves are assumed to have the same interpolation method
  # TODO(b/168411153): Extend to the case with multiple curve configs.
  discount_curve = rate_curve.RateCurve(
      prepare_dates,
      prepare_discounts,
      market.date,
      interpolator=interpolation_method,
      interpolate_rates=interpolate_rates)
  return discount_curve
def _get_fixings(start_dates,
                 end_dates,
                 reference_curve_types,
                 reference_mask,
                 market):
  """Computes year-fraction-weighted fixings for a list of reference curves.

  For every reference curve, the fixing returned by `market.fixings` is
  multiplied by the year fraction between the matching start and end dates.
  When more than one curve is supplied, `reference_mask` routes each element of
  the date tensors' leading dimension to a curve. The per-curve results are
  then put back into the original element order.

  NOTE(review): another `_get_fixings` with a different signature follows this
  definition. If both live in the same module, the later one shadows this one.
  Confirm which of the two the callers expect.

  Args:
    start_dates: A `DateTensor` of accrual start dates.
    end_dates: A `DateTensor` of accrual end dates, aligned with `start_dates`.
    reference_curve_types: A list of reference curve types. Position `i`
      corresponds to mask value `i`.
    reference_mask: Integer tensor of curve positions. It is only consulted
      when more than one curve is supplied.
    market: Processed market data exposing `fixings` and `dtype`.

  Returns:
    A `Tensor` of fixings multiplied by their year fractions.
  """
  curve_count = len(reference_curve_types)
  is_multi_curve = curve_count > 1
  if is_multi_curve:
    # Positions (along the leading axis) routed to each reference curve.
    positions = [
        tf.squeeze(tf.where(tf.equal(reference_mask, curve_id)), -1)
        for curve_id in range(curve_count)
    ]
  else:
    positions = [0]

  start_ordinals = start_dates.ordinal()
  end_ordinals = end_dates.ordinal()
  per_curve_fixings = []
  for position, curve_type in zip(positions, reference_curve_types):
    if not is_multi_curve:
      # A single curve covers every element; no need to regroup the dates.
      curve_start, curve_end = start_dates, end_dates
    else:
      curve_start = dateslib.dates_from_ordinals(
          tf.gather(start_ordinals, position))
      curve_end = dateslib.dates_from_ordinals(
          tf.gather(end_ordinals, position))

    rate, daycount_convention = market.fixings(curve_start, curve_type)
    if daycount_convention is None:
      # NOTE(review): with no daycount the fixing is zeroed out; confirm intent.
      accrual = 0.0
    else:
      daycount_fn = market_data_utils.get_daycount_fn(
          daycount_convention, dtype=market.dtype)
      accrual = daycount_fn(start_date=curve_start, end_date=curve_end)
    per_curve_fixings.append(rate * accrual)

  padded = pad.pad_tensors(per_curve_fixings)
  merged_positions = tf.concat(positions, axis=0)
  merged_fixings = tf.concat(padded, axis=0)
  if not is_multi_curve:
    return merged_fixings
  # Undo the per-curve grouping so results follow the original element order.
  return tf.gather(merged_fixings, tf.argsort(merged_positions))
def _get_fixings(start_dates,
                 reference_curve_types,
                 reset_frequencies,
                 reference_mask,
                 market):
  """Computes fixings for a list of reference curves.

  The elements of `start_dates` are split by `reference_mask` into one group
  per reference curve. Fixings are requested from `market.fixings` for each
  group, and the per-curve results are reassembled in the original order.

  Args:
    start_dates: A `DateTensor` of fixing start dates. Its leading dimension is
      the one indexed by `reference_mask`.
    reference_curve_types: A list of reference curve types. Position `i`
      corresponds to mask value `i`.
    reset_frequencies: An object exposing `quantity()`, presumably a period
      tensor of reset frequencies (TODO confirm).
    reference_mask: Integer tensor mapping each element of the leading
      dimension of `start_dates` to a position in `reference_curve_types`.
    market: Processed market data exposing `fixings`.

  Returns:
    A `Tensor` of fixings ordered like the leading dimension of `start_dates`.
  """
  # Loop-invariant values: compute once instead of once per curve.
  start_dates_ordinal = start_dates.ordinal()
  reset_quant = reset_frequencies.quantity()
  # Do not use gather if only one reset frequency is supplied.
  # NOTE(review): a quantity of rank <= 1 is treated as shared by all elements;
  # confirm that per-instrument frequencies always have rank > 1.
  gather_resets = reset_quant.shape.rank > 1

  split_indices = [
      tf.squeeze(tf.where(tf.equal(reference_mask, i)), -1)
      for i in range(len(reference_curve_types))
  ]
  fixings = []
  for idx, reference_curve_type in zip(split_indices, reference_curve_types):
    start_date = dateslib.dates_from_ordinals(
        tf.gather(start_dates_ordinal, idx))
    curve_reset_quant = (
        tf.gather(reset_quant, idx) if gather_resets else reset_quant)
    fixings.append(
        market.fixings(start_date, reference_curve_type, curve_reset_quant))
  fixings = pad.pad_tensors(fixings)
  all_indices = tf.concat(split_indices, axis=0)
  all_fixings = tf.concat(fixings, axis=0)
  # Undo the per-curve grouping so results follow the original element order.
  return tf.gather(all_fixings, tf.argsort(all_indices))