Example #1
def lu_solve(lower_upper, perm, rhs, validate_args=False, name=None):
    """Solves systems of linear eqns `A X = RHS`, given LU factorizations.

  Note: this function does not verify that the implied matrix is actually
  invertible; this is not checked even when `validate_args=True`.

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`.
    rhs: Matrix-shaped float `Tensor` representing targets for which to solve;
      `A X = RHS`. To handle vector cases, use:
      `lu_solve(..., rhs[..., tf.newaxis])[..., 0]`.
    validate_args: Python `bool` indicating whether arguments should be checked
      for correctness. Note: this function does not verify the implied matrix is
      actually invertible, even when `validate_args=True`.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_solve').

  Returns:
    x: The `X` in `A @ X = RHS`.

  #### Examples

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp

  x = [[[1., 2],
        [3, 4]],
       [[7, 8],
        [3, 4]]]
  inv_x = tfp.math.lu_solve(*tf.linalg.lu(x), rhs=tf.eye(2))
  tf.debugging.assert_near(tf.linalg.inv(x), inv_x)
  # ==> no exception is raised; inv_x matches the matrix inverse.
  ```

  """

    with tf.name_scope(name or 'lu_solve'):
        lower_upper = tf.convert_to_tensor(lower_upper,
                                           dtype_hint=tf.float32,
                                           name='lower_upper')
        perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm')
        rhs = tf.convert_to_tensor(rhs,
                                   dtype_hint=lower_upper.dtype,
                                   name='rhs')

        assertions = _lu_solve_assertions(lower_upper, perm, rhs,
                                          validate_args)
        if assertions:
            with tf.control_dependencies(assertions):
                lower_upper = tf.identity(lower_upper)
                perm = tf.identity(perm)
                rhs = tf.identity(rhs)

        if (tensorshape_util.rank(rhs.shape) == 2
                and tensorshape_util.rank(perm.shape) == 1):
            # Both rhs and perm have scalar batch_shape.
            permuted_rhs = tf.gather(rhs, perm, axis=-2)
        else:
            # Either rhs or perm have non-scalar batch_shape or we can't determine
            # this information statically.
            rhs_shape = tf.shape(rhs)
            broadcast_batch_shape = tf.broadcast_dynamic_shape(
                rhs_shape[:-2],
                tf.shape(perm)[:-1])
            d, m = rhs_shape[-2], rhs_shape[-1]
            rhs_broadcast_shape = tf.concat([broadcast_batch_shape, [d, m]],
                                            axis=0)

            # Tile out rhs.
            broadcast_rhs = tf.broadcast_to(rhs, rhs_broadcast_shape)
            broadcast_rhs = tf.reshape(broadcast_rhs, [-1, d, m])

            # Tile out perm and add batch indices.
            broadcast_perm = tf.broadcast_to(perm, rhs_broadcast_shape[:-1])
            broadcast_perm = tf.reshape(broadcast_perm, [-1, d])
            broadcast_batch_size = tf.reduce_prod(broadcast_batch_shape)
            broadcast_batch_indices = tf.broadcast_to(
                tf.range(broadcast_batch_size)[:, tf.newaxis],
                [broadcast_batch_size, d])
            broadcast_perm = tf.stack(
                [broadcast_batch_indices, broadcast_perm], axis=-1)

            permuted_rhs = tf.gather_nd(broadcast_rhs, broadcast_perm)
            permuted_rhs = tf.reshape(permuted_rhs, rhs_broadcast_shape)

        lower = tf.linalg.set_diag(
            tf.linalg.band_part(lower_upper, num_lower=-1, num_upper=0),
            tf.ones(tf.shape(lower_upper)[:-1], dtype=lower_upper.dtype))
        return tf.linalg.triangular_solve(
            lower_upper,  # Only upper is accessed.
            tf.linalg.triangular_solve(lower, permuted_rhs),
            lower=False)
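
As a quick usage sketch of the vector-shaped `rhs` trick mentioned in the Args above (solving `A x = b` for a single right-hand-side vector; the matrix values here are made up for illustration):

```python
import tensorflow as tf
import tensorflow_probability as tfp

a = tf.constant([[4., 3.],
                 [6., 3.]])
b = tf.constant([10., 12.])
# Expand b to a single-column matrix, solve, then drop the column dimension.
x = tfp.math.lu_solve(*tf.linalg.lu(a), rhs=b[..., tf.newaxis])[..., 0]
# ==> [1., 2.], since 4*1 + 3*2 = 10 and 6*1 + 3*2 = 12.
```
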
Example #2
def _coord_grid_to_mesh_grid(coord_grid):
    if len(coord_grid) == 1:
        return tf.expand_dims(coord_grid[0], -1)
    return tf.stack(values=tf.meshgrid(*coord_grid, indexing="ij"), axis=-1)
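
A small usage sketch, assuming `coord_grid` is a list of 1-D coordinate axes (values made up): two axes of lengths 2 and 3 yield a `[2, 3, 2]` tensor of coordinate pairs.

```python
import tensorflow as tf

coord_grid = [tf.constant([0., 1.]), tf.constant([10., 20., 30.])]
grid = _coord_grid_to_mesh_grid(coord_grid)
print(grid.shape)   # (2, 3, 2): each entry is a coordinate pair
print(grid[1, 2])   # [ 1. 30.], the point at axis-0 index 1, axis-1 index 2
```
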
Example #3
    def call(self, x, training=False):
        x_flat = tf.reshape(x, shape=(-1, self.depth))

        # Split each input vector into one segment per head.
        x_flat_split = tf.split(x_flat, self.num_heads, axis=1)
        x_flat = tf.concat(x_flat_split, axis=0)

        if training:
            # Figure out which centroids we want to keep, and which we want to
            # restart.
            n = x_flat.shape[0]
            keep = self.counts * self.k > self.restart_threshold * n
            restart = tf.math.logical_not(keep)

            # Replace centroids to restart with elements from the batch, using samples
            # from a uniform distribution as a fallback in case we need to restart
            # more centroids than we have elements in the batch.
            restart_idx = tf.squeeze(tf.where(restart), -1)
            n_replace = tf.minimum(tf.shape(restart_idx)[0], x_flat.shape[0])
            e_restart = tf.tensor_scatter_nd_update(
                tf.random.uniform([self.k, self.depth // self.num_heads]),
                tf.expand_dims(restart_idx[:n_replace], 1),
                tf.random.shuffle(x_flat)[:n_replace])

            # Compute the values of the centroids we want to keep by dividing the
            # summed vectors by the corresponding counts.
            e = tf.where(
                tf.expand_dims(keep, 1),
                tf.math.divide_no_nan(self.sums,
                                      tf.expand_dims(self.counts, 1)),
                e_restart)

        else:
            # If not training, just use the centroids as is with no restarts.
            e = tf.math.divide_no_nan(self.sums,
                                      tf.expand_dims(self.counts, 1))

        # Compute distance between each input vector and each cluster center.
        distances = (tf.expand_dims(tf.reduce_sum(x_flat**2, axis=1), 1) -
                     2 * tf.matmul(x_flat, tf.transpose(e)) +
                     tf.expand_dims(tf.reduce_sum(e**2, axis=1), 0))

        # Find nearest cluster center for each input vector.
        c = tf.argmin(distances, axis=1)

        # Quantize input vectors with straight-through estimator.
        z = tf.nn.embedding_lookup(e, c)
        z_split = tf.split(z, self.num_heads, axis=0)
        z = tf.concat(z_split, axis=1)
        z = tf.reshape(z, tf.shape(x))
        z = x + tf.stop_gradient(z - x)

        if training:
            # Compute cluster counts and vector sums over the batch.
            oh = tf.one_hot(indices=c, depth=self.k)
            counts = tf.reduce_sum(oh, axis=0)
            sums = tf.matmul(oh, x_flat, transpose_a=True)

            # Apply exponential moving average to cluster counts and vector sums.
            self.counts.assign_sub((1 - self.gamma) * (self.counts - counts))
            self.sums.assign_sub((1 - self.gamma) * (self.sums - sums))

        c_split = tf.split(c, self.num_heads, axis=0)
        c = tf.stack(c_split, axis=1)
        c = tf.reshape(c,
                       tf.concat([tf.shape(x)[:-1], [self.num_heads]], axis=0))

        return z, c
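
The line `z = x + tf.stop_gradient(z - x)` above is the straight-through estimator: the forward pass sees the quantized `z`, while gradients flow to `x` as if the quantization were the identity. A minimal standalone illustration:

```python
import tensorflow as tf

x = tf.constant([0.3, 1.7])
with tf.GradientTape() as tape:
    tape.watch(x)
    z = tf.round(x)                      # non-differentiable quantization
    z_ste = x + tf.stop_gradient(z - x)  # forward: z, backward: identity w.r.t. x
    loss = tf.reduce_sum(z_ste ** 2)
print(z_ste.numpy())                     # [0. 2.], the quantized values
print(tape.gradient(loss, x).numpy())    # [0. 4.], i.e. 2 * z_ste passed straight through
```
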
Example #4
def set_negative_scores(scores, indices):
    indices_2d = tf.stack(
        [tf.range(bsz, dtype=indices.dtype), indices], axis=1)
    return tf.tensor_scatter_nd_update(
        scores, indices_2d, tf.fill(tf.shape(indices), -1.0))
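
`bsz` is taken from the enclosing scope (the batch size). A self-contained sketch of the same scatter pattern with explicit, made-up values:

```python
import tensorflow as tf

scores = tf.constant([[0.2, 0.5, 0.9],
                      [0.7, 0.1, 0.4]])
indices = tf.constant([2, 0])    # one column index per row to overwrite with -1
bsz = tf.shape(scores)[0]
indices_2d = tf.stack([tf.range(bsz, dtype=indices.dtype), indices], axis=1)
out = tf.tensor_scatter_nd_update(
    scores, indices_2d, tf.fill(tf.shape(indices), -1.0))
# out ==> [[ 0.2,  0.5, -1. ],
#          [-1. ,  0.1,  0.4]]
```
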
Example #5
def update_confusion_matrix_variables(variables_to_update,
                                      y_true,
                                      y_pred,
                                      thresholds,
                                      top_k=None,
                                      class_id=None,
                                      sample_weight=None,
                                      multi_label=False,
                                      label_weights=None):
    """Returns op to update the given confusion matrix variables.

  For every pair of values in y_true and y_pred:

  true_positives: y_true == True and y_pred > thresholds
  false_negatives: y_true == True and y_pred <= thresholds
  true_negatives: y_true == False and y_pred <= thresholds
  false_positives: y_true == False and y_pred > thresholds

  The results will be weighted and added together. When multiple thresholds are
  provided, we will repeat the same for every threshold.

  For estimation of these metrics over a stream of data, the function creates an
  `update_op` operation that updates the given variables.

  If `sample_weight` is `None`, weights default to 1.
  Use weights of 0 to mask values.

  Args:
    variables_to_update: Dictionary with 'tp', 'fn', 'tn', 'fp' as valid keys
      and corresponding variables to update as values.
    y_true: A `Tensor` whose shape matches `y_pred`. Will be cast to `bool`.
    y_pred: A floating point `Tensor` of arbitrary shape and whose values are in
      the range `[0, 1]`.
    thresholds: A float value, float tensor, python list, or tuple of float
      thresholds in `[0, 1]`, or NEG_INF (used when top_k is set).
    top_k: Optional int, indicates that the positive labels should be limited to
      the top k predictions.
    class_id: Optional int, limits the prediction and labels to the class
      specified by this argument.
    sample_weight: Optional `Tensor` whose rank is either 0, or the same rank as
      `y_true`, and must be broadcastable to `y_true` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `y_true` dimension).
    multi_label: Optional boolean indicating whether multidimensional
      prediction/labels should be treated as multilabel responses, or flattened
      into a single label. When True, the values of `variables_to_update` must
      have a second dimension equal to the number of labels in y_true and
      y_pred, and those tensors must not be RaggedTensors.
    label_weights: (optional) tensor of non-negative weights for multilabel
      data. The weights are applied when calculating TP, FP, FN, and TN without
      explicit multilabel handling (i.e. when the data is to be flattened).

  Returns:
    Update op.

  Raises:
    ValueError: If `y_pred` and `y_true` have mismatched shapes, or if
      `sample_weight` is not `None` and its shape doesn't match `y_pred`, or if
      `variables_to_update` contains invalid keys.
  """
    if multi_label and label_weights is not None:
        raise ValueError(
            '`label_weights` for multilabel data should be handled '
            'outside of `update_confusion_matrix_variables` when '
            '`multi_label` is True.')
    if variables_to_update is None:
        return
    if not any(key
               for key in variables_to_update if key in list(ConfusionMatrix)):
        raise ValueError(
            'Please provide at least one valid confusion matrix '
            'variable to update. Valid variable key options are: "{}". '
            'Received: "{}"'.format(list(ConfusionMatrix),
                                    variables_to_update.keys()))

    variable_dtype = list(variables_to_update.values())[0].dtype

    y_true = tf.cast(y_true, dtype=variable_dtype)
    y_pred = tf.cast(y_pred, dtype=variable_dtype)
    thresholds = tf.convert_to_tensor(thresholds, dtype=variable_dtype)
    num_thresholds = thresholds.shape[0]
    if multi_label:
        one_thresh = tf.equal(tf.cast(1, dtype=tf.int32),
                              tf.rank(thresholds),
                              name='one_set_of_thresholds_cond')
    else:
        [y_pred, y_true], _ = ragged_assert_compatible_and_get_flat_values(
            [y_pred, y_true], sample_weight)
        one_thresh = tf.cast(True, dtype=tf.bool)

    invalid_keys = [
        key for key in variables_to_update if key not in list(ConfusionMatrix)
    ]
    if invalid_keys:
        raise ValueError(
            'Invalid keys: {}. Valid variable key options are: "{}"'.format(
                invalid_keys, list(ConfusionMatrix)))

    with tf.control_dependencies([
            tf.compat.v1.assert_greater_equal(
                y_pred,
                tf.cast(0.0, dtype=y_pred.dtype),
                message='predictions must be >= 0'),
            tf.compat.v1.assert_less_equal(y_pred,
                                           tf.cast(1.0, dtype=y_pred.dtype),
                                           message='predictions must be <= 1')
    ]):
        if sample_weight is None:
            y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(
                y_pred, y_true)
        else:
            sample_weight = tf.cast(sample_weight, dtype=variable_dtype)
            y_pred, y_true, sample_weight = (
                losses_utils.squeeze_or_expand_dimensions(
                    y_pred, y_true, sample_weight=sample_weight))
    y_pred.shape.assert_is_compatible_with(y_true.shape)

    if top_k is not None:
        y_pred = _filter_top_k(y_pred, top_k)
    if class_id is not None:
        y_true = y_true[..., class_id]
        y_pred = y_pred[..., class_id]

    pred_shape = tf.compat.v1.shape(y_pred)
    num_predictions = pred_shape[0]
    if y_pred.shape.ndims == 1:
        num_labels = 1
    else:
        num_labels = tf.raw_ops.Prod(input=pred_shape[1:], axis=0)
    thresh_label_tile = tf.compat.v1.cond(one_thresh, lambda: num_labels,
                                          lambda: tf.cast(1, dtype=tf.int32))

    # Reshape predictions and labels, adding a dim for thresholding.
    if multi_label:
        predictions_extra_dim = tf.compat.v1.expand_dims(y_pred, 0)
        labels_extra_dim = tf.compat.v1.expand_dims(
            tf.cast(y_true, dtype=tf.bool), 0)
    else:
        # Flatten predictions and labels when not multilabel.
        predictions_extra_dim = tf.reshape(y_pred, [1, -1])
        labels_extra_dim = tf.reshape(tf.cast(y_true, dtype=tf.bool), [1, -1])

    # Tile the thresholds for every prediction.
    if multi_label:
        thresh_pretile_shape = [num_thresholds, 1, -1]
        thresh_tiles = [1, num_predictions, thresh_label_tile]
        data_tiles = [num_thresholds, 1, 1]
    else:
        thresh_pretile_shape = [num_thresholds, -1]
        thresh_tiles = [1, num_predictions * num_labels]
        data_tiles = [num_thresholds, 1]

    thresh_tiled = tf.tile(tf.reshape(thresholds, thresh_pretile_shape),
                           tf.stack(thresh_tiles))

    # Tile the predictions for every threshold.
    preds_tiled = tf.tile(predictions_extra_dim, data_tiles)

    # Compare predictions and threshold.
    pred_is_pos = tf.greater(preds_tiled, thresh_tiled)

    # Tile labels by number of thresholds
    label_is_pos = tf.tile(labels_extra_dim, data_tiles)

    if sample_weight is not None:
        sample_weight = tf.__internal__.ops.broadcast_weights(
            tf.cast(sample_weight, dtype=variable_dtype), y_pred)
        weights_tiled = tf.tile(tf.reshape(sample_weight, thresh_tiles),
                                data_tiles)
    else:
        weights_tiled = None

    if label_weights is not None and not multi_label:
        label_weights = tf.compat.v1.expand_dims(label_weights, 0)
        label_weights = tf.__internal__.ops.broadcast_weights(
            label_weights, y_pred)
        label_weights_tiled = tf.tile(tf.reshape(label_weights, thresh_tiles),
                                      data_tiles)
        if weights_tiled is None:
            weights_tiled = label_weights_tiled
        else:
            weights_tiled = tf.multiply(weights_tiled, label_weights_tiled)

    update_ops = []

    def weighted_assign_add(label, pred, weights, var):
        label_and_pred = tf.cast(tf.logical_and(label, pred), dtype=var.dtype)
        if weights is not None:
            label_and_pred *= tf.cast(weights, dtype=var.dtype)
        return var.assign_add(tf.reduce_sum(label_and_pred, 1))

    loop_vars = {
        ConfusionMatrix.TRUE_POSITIVES: (label_is_pos, pred_is_pos),
    }
    update_tn = ConfusionMatrix.TRUE_NEGATIVES in variables_to_update
    update_fp = ConfusionMatrix.FALSE_POSITIVES in variables_to_update
    update_fn = ConfusionMatrix.FALSE_NEGATIVES in variables_to_update

    if update_fn or update_tn:
        pred_is_neg = tf.logical_not(pred_is_pos)
        loop_vars[ConfusionMatrix.FALSE_NEGATIVES] = (label_is_pos,
                                                      pred_is_neg)

    if update_fp or update_tn:
        label_is_neg = tf.logical_not(label_is_pos)
        loop_vars[ConfusionMatrix.FALSE_POSITIVES] = (label_is_neg,
                                                      pred_is_pos)
        if update_tn:
            loop_vars[ConfusionMatrix.TRUE_NEGATIVES] = (label_is_neg,
                                                         pred_is_neg)

    for matrix_cond, (label, pred) in loop_vars.items():

        if matrix_cond in variables_to_update:
            update_ops.append(
                weighted_assign_add(label, pred, weights_tiled,
                                    variables_to_update[matrix_cond]))

    return tf.group(update_ops)
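
For intuition, a plain-TF sketch (outside this helper, with made-up values) of how a single threshold splits predictions into the four confusion-matrix buckets described in the docstring:

```python
import tensorflow as tf

y_true = tf.constant([True, True, False, False])
y_pred = tf.constant([0.9, 0.3, 0.8, 0.1])
pred_is_pos = y_pred > 0.5

tp = tf.reduce_sum(tf.cast(y_true & pred_is_pos, tf.float32))    # 1.0
fn = tf.reduce_sum(tf.cast(y_true & ~pred_is_pos, tf.float32))   # 1.0
fp = tf.reduce_sum(tf.cast(~y_true & pred_is_pos, tf.float32))   # 1.0
tn = tf.reduce_sum(tf.cast(~y_true & ~pred_is_pos, tf.float32))  # 1.0
```
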
Example #6
    def call(self, inputs):
        if self.conditional_inputs is None and self.conditional_outputs is None:
            covariance_matrix = self.covariance_fn(inputs, inputs)
            # Tile locations so output has shape [units, batch_size]. Covariance will
            # broadcast to [units, batch_size, batch_size], and we perform
            # shape manipulations to get a random variable over [batch_size, units].
            loc = self.mean_fn(inputs)
            loc = tf.tile(loc[tf.newaxis], [self.units] + [1] * len(loc.shape))
        else:
            knn = self.covariance_fn(inputs, inputs)
            knm = self.covariance_fn(inputs, self.conditional_inputs)
            kmm = self.covariance_fn(self.conditional_inputs,
                                     self.conditional_inputs)
            kmm = tf.linalg.set_diag(
                kmm,
                tf.linalg.diag_part(kmm) + tf.keras.backend.epsilon())
            kmm_tril = tf.linalg.cholesky(kmm)
            kmm_tril_operator = tf.linalg.LinearOperatorLowerTriangular(
                kmm_tril)
            knm_operator = tf.linalg.LinearOperatorFullMatrix(knm)

            # TODO(trandustin): Vectorize linear algebra for multiple outputs. For
            # now, we do each separately and stack to obtain a locations Tensor of
            # shape [units, batch_size].
            loc = []
            for conditional_outputs_unit in tf.unstack(
                    self.conditional_outputs, axis=-1):
                center = conditional_outputs_unit - self.mean_fn(
                    self.conditional_inputs)
                loc_unit = knm_operator.matvec(
                    kmm_tril_operator.solvevec(
                        kmm_tril_operator.solvevec(center), adjoint=True))
                loc.append(loc_unit)
            loc = tf.stack(loc) + self.mean_fn(inputs)[tf.newaxis]

            covariance_matrix = knn
            covariance_matrix -= knm_operator.matmul(
                kmm_tril_operator.solve(kmm_tril_operator.solve(
                    knm, adjoint_arg=True),
                                        adjoint=True))

        covariance_matrix = tf.linalg.set_diag(
            covariance_matrix,
            tf.linalg.diag_part(covariance_matrix) +
            tf.keras.backend.epsilon())

        # Form a multivariate normal random variable with batch_shape units and
        # event_shape batch_size. Then make it be independent across the units
        # dimension. Then transpose its dimensions so it is [batch_size, units].
        random_variable = (
            generated_random_variables.MultivariateNormalFullCovariance(
                loc=loc, covariance_matrix=covariance_matrix))
        random_variable = generated_random_variables.Independent(
            random_variable.distribution, reinterpreted_batch_ndims=1)
        bijector = tfp.bijectors.Inline(
            forward_fn=lambda x: tf.transpose(x, perm=[1, 0]),
            inverse_fn=lambda y: tf.transpose(y, perm=[1, 0]),
            forward_event_shape_fn=lambda input_shape: input_shape[::-1],
            forward_event_shape_tensor_fn=lambda input_shape: input_shape[::-1],
            inverse_log_det_jacobian_fn=lambda y: tf.cast(0, y.dtype),
            forward_min_event_ndims=2)
        random_variable = generated_random_variables.TransformedDistribution(
            random_variable.distribution, bijector=bijector)
        return random_variable
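
The conditional branch above is the usual GP posterior computed through the Cholesky factor of `kmm`. A self-contained sketch of the same algebra on generic kernel matrices (the function and argument names here are illustrative, not part of this layer's API):

```python
import tensorflow as tf

def gp_posterior(knn, knm, kmm, y_centered, jitter=1e-6):
    """Posterior mean Knm Kmm^-1 y and covariance Knn - Knm Kmm^-1 Kmn."""
    kmm = kmm + jitter * tf.eye(tf.shape(kmm)[-1], dtype=kmm.dtype)
    chol = tf.linalg.cholesky(kmm)
    # Kmm^-1 y via two triangular solves with the Cholesky factor.
    alpha = tf.linalg.cholesky_solve(chol, y_centered[..., tf.newaxis])
    mean = tf.squeeze(tf.matmul(knm, alpha), axis=-1)
    v = tf.linalg.triangular_solve(chol, tf.linalg.matrix_transpose(knm))
    cov = knn - tf.matmul(v, v, transpose_a=True)
    return mean, cov
```
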
Example #7
    def compute_logits(self, batch_features: tf.Tensor,
                       target_token_seq: tf.Tensor,
                       training: bool) -> tf.Tensor:
        """
        Implements a language model, where each output is conditional on the current
        input and inputs processed so far.

        Args:
            token_ids: int32 tensor of shape [B, T], storing integer IDs of tokens.
            training: Flag indicating if we are currently training (used to toggle dropout)

        Returns:
            tf.float32 tensor of shape [B, T, V], storing the distribution over output symbols
            for each timestep for each batch element.
        """
        num_graphs = tf.cast(batch_features["num_graphs_in_batch"], tf.float32)

        if self.hyperparameters["encoder_type"] == "seq":
            enc_hidden = self.seq_encoder.initialize_hidden_state(
                batch_features["num_graphs_in_batch"])
            enc_output, enc_hidden = self.seq_encoder(
                batch_features["source_seq"], enc_hidden)

            dec_hidden = enc_hidden
            dec_input = tf.expand_dims(
                [self.vocab_target.get_id_or_unk("%START%")] *
                batch_features["num_graphs_in_batch"], 1)

        elif self.hyperparameters["encoder_type"] == "graph":
            enc_hidden, enc_output = self.graph_encoder(batch_features,
                                                        training=training)

            enc_output = tf.split(enc_output,
                                  batch_features["graph_to_num_nodes"])
            enc_output = [
                tf.concat([
                    out,
                    tf.zeros([
                        self.hyperparameters["max_node_num"] - out.shape[0],
                        self.hyperparameters["token_embedding_size"]
                    ])
                ], 0) for out in enc_output
            ]
            enc_output = tf.stack(enc_output)

            dec_hidden_h = self.linear_h(enc_hidden)
            dec_hidden_c = self.linear_c(enc_hidden)

            if self.hyperparameters["rnn_cell"] == "LSTM":
                dec_hidden = [dec_hidden_h, dec_hidden_c]
            else:
                dec_hidden = dec_hidden_h
            dec_input = tf.expand_dims(
                [self.vocab_target.get_id_or_unk("%START%")] *
                enc_hidden.shape[0], 1)

        elif self.hyperparameters["encoder_type"] == "graph+seq":
            enc_hidden, enc_output = self.graph_encoder(batch_features,
                                                        training=training)

            enc_output = tf.split(enc_output,
                                  batch_features["graph_to_num_nodes"])
            enc_output = [
                tf.concat([
                    out[:source_len, :],
                    tf.zeros([
                        200 - out[:source_len, :].shape[0],
                        self.hyperparameters["token_embedding_size"]
                    ])
                ], 0) for out, source_len in zip(
                    enc_output, batch_features["source_len"])
            ]
            enc_output = tf.stack(enc_output)

            hidden_h = self.linear_h(enc_hidden)
            hidden_c = self.linear_c(enc_hidden)

            if self.hyperparameters["rnn_cell"] == "LSTM":
                hidden = [hidden_h, hidden_c]
            else:
                hidden = hidden_h

            enc_output, dec_hidden = self.seq_encoder(enc_output,
                                                      hidden,
                                                      embedd=False)

            dec_input = tf.expand_dims(
                [self.vocab_target.get_id_or_unk("%START%")] *
                enc_hidden.shape[0], 1)

        elif self.hyperparameters["encoder_type"] == "seq+graph":
            enc_hidden = self.seq_encoder.initialize_hidden_state(
                batch_features["num_graphs_in_batch"])
            seq_enc_output, seq_enc_hidden = self.seq_encoder(
                batch_features["source_seq"], enc_hidden)

            enc_hidden, enc_output = self.graph_encoder(
                batch_features,
                training=training,
                seq_enc_output=seq_enc_output)

            enc_output = tf.split(enc_output,
                                  batch_features["graph_to_num_nodes"])
            enc_output = [
                tf.concat([
                    out,
                    tf.zeros([
                        self.hyperparameters["max_node_num"] - out.shape[0],
                        self.hyperparameters["token_embedding_size"]
                    ])
                ], 0) for out in enc_output
            ]
            enc_output = tf.stack(enc_output)

            hidden_h = self.linear_h(enc_hidden)

            if self.hyperparameters["rnn_cell"] == "LSTM":
                dec_hidden = [hidden_h, seq_enc_hidden[1]]
            else:
                dec_hidden = hidden_h

            dec_input = tf.expand_dims(
                [self.vocab_target.get_id_or_unk("%START%")] *
                enc_hidden.shape[0], 1)

        if training and random.random() > 0.5:
            # Use teacher forcing
            predictions, dec_hidden = self.decoder(target_token_seq[:, :-1],
                                                   dec_hidden, enc_output)
            return predictions
        else:
            # The predicted ID is fed back into the model
            for t in range(1, self.hyperparameters["max_seq_length"]):
                predictions, dec_hidden = self.decoder(dec_input, dec_hidden,
                                                       enc_output)
                predicted_ids = tf.argmax(predictions[:, 0, :], 1)
                dec_input = tf.expand_dims(predicted_ids, 1)
                new_logits = tf.expand_dims(predictions[:, 0, :], 1)
                if t == 1:
                    results = new_logits
                else:
                    results = tf.concat([results, new_logits], 1)

        return results
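
The non-teacher-forcing branch above is a standard greedy feed-back loop. A compact sketch of the same pattern with a generic decoder (the decoder interface assumed here mirrors the one used above: `decoder(inputs, hidden, enc_output) -> (logits [B, 1, V], new_hidden)`):

```python
import tensorflow as tf

def greedy_decode(decoder, dec_input, dec_hidden, enc_output, max_seq_length):
    logits_per_step = []
    for _ in range(1, max_seq_length):
        predictions, dec_hidden = decoder(dec_input, dec_hidden, enc_output)
        step_logits = predictions[:, 0, :]                             # [B, V]
        dec_input = tf.expand_dims(tf.argmax(step_logits, axis=1), 1)  # feed back
        logits_per_step.append(tf.expand_dims(step_logits, 1))
    return tf.concat(logits_per_step, axis=1)   # [B, max_seq_length - 1, V]
```
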
Example #8
    def __init__(self,
                 num_timesteps,
                 coefficients,
                 level_scale,
                 initial_state_prior,
                 observation_noise_scale=0.,
                 name=None,
                 **linear_gaussian_ssm_kwargs):
        """Build a state space model implementing an autoregressive process.

    Args:
      num_timesteps: Scalar `int` `Tensor` number of timesteps to model
        with this distribution.
      coefficients: `float` `Tensor` of shape `concat(batch_shape, [order])`
        defining the autoregressive coefficients. The coefficients are defined
        backwards in time: `coefficients[0] * level[t] + coefficients[1] *
        level[t-1] + ... + coefficients[order-1] * level[t-order+1]`.
      level_scale: Scalar (any additional dimensions are treated as batch
        dimensions) `float` `Tensor` indicating the standard deviation of the
        transition noise at each step.
      initial_state_prior: instance of `tfd.MultivariateNormal`
        representing the prior distribution on latent states.  Must have
        event shape `[order]`.
      observation_noise_scale: Scalar (any additional dimensions are
        treated as batch dimensions) `float` `Tensor` indicating the standard
        deviation of the observation noise.
        Default value: 0.
      name: Python `str` name prefixed to ops created by this class.
        Default value: "AutoregressiveStateSpaceModel".
      **linear_gaussian_ssm_kwargs: Optional additional keyword arguments
        passed to the base `tfd.LinearGaussianStateSpaceModel` constructor.
    """
        parameters = dict(locals())
        parameters.update(linear_gaussian_ssm_kwargs)
        del parameters['linear_gaussian_ssm_kwargs']
        with tf.name_scope(name or 'AutoregressiveStateSpaceModel') as name:

            # The initial state prior determines the dtype of sampled values.
            # Other model parameters must have the same dtype.
            dtype = initial_state_prior.dtype

            coefficients = tf.convert_to_tensor(value=coefficients,
                                                name='coefficients',
                                                dtype=dtype)
            level_scale = tf.convert_to_tensor(value=level_scale,
                                               name='level_scale',
                                               dtype=dtype)
            observation_noise_scale = tf.convert_to_tensor(
                value=observation_noise_scale,
                name='observation_noise_scale',
                dtype=dtype)

            order = tf.compat.dimension_value(coefficients.shape[-1])
            if order is None:
                raise ValueError(
                    'Autoregressive coefficients must have static shape.')

            self._order = order
            self._coefficients = coefficients
            self._level_scale = level_scale

            super(AutoregressiveStateSpaceModel, self).__init__(
                num_timesteps=num_timesteps,
                transition_matrix=make_ar_transition_matrix(coefficients),
                transition_noise=tfd.MultivariateNormalDiag(
                    scale_diag=tf.stack([level_scale] +
                                        [tf.zeros_like(level_scale)] *
                                        (self.order - 1),
                                        axis=-1)),
                observation_matrix=tf.concat([
                    tf.ones([1, 1], dtype=dtype),
                    tf.zeros([1, self.order - 1], dtype=dtype)
                ],
                                             axis=-1),
                observation_noise=tfd.MultivariateNormalDiag(
                    scale_diag=observation_noise_scale[..., tf.newaxis]),
                initial_state_prior=initial_state_prior,
                name=name,
                **linear_gaussian_ssm_kwargs)
            self._parameters = parameters
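
For intuition on `make_ar_transition_matrix`, the transition matrix of an AR(p) process is the companion matrix of its coefficients; a sketch of that construction for an AR(2) model (the helper in this codebase may differ in details such as batching):

```python
import tensorflow as tf

coefficients = tf.constant([0.8, -0.2])   # AR(2) coefficients, most recent first
order = coefficients.shape[-1]
transition = tf.concat(
    [coefficients[tf.newaxis, :],                          # first row: AR weights
     tf.eye(order - 1, order, dtype=coefficients.dtype)],  # shift older levels down
    axis=0)
# transition ==> [[ 0.8, -0.2],
#                 [ 1. ,  0. ]]
```
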
Example #9
def softquantiles(x,
                  quantiles,
                  quantile_width=None,
                  axis=-1,
                  may_squeeze=True,
                  **kwargs):
    """Computes soft quantiles via optimal transport.

  This operator takes advantage of the fact that an exhaustive softsort is not
  required to recover a single quantile. Instead, one can transport all
  input values in x onto only 3 weighted values. Target weights are adjusted so
  that those values in x that are transported to the middle value in the target
  vector y correspond to those concentrating around the quantile of interest.

  This idea generalizes to more quantiles, interleaving small weights on the
  quantile indices and bigger weights in between, corresponding to the gap from
  one desired quantile to the next one.

  Args:
   x: Tensor<float> of any shape.
   quantiles: list<float> the quantiles to be returned. It can also be a single
     float.
   quantile_width: (float) mass given to the bucket supposed to attract points
     whose value concentrate around the desired quantile value. Bigger width
     means that we allow the soft quantile to be a mixture of more points
     further away from the quantile. If None, the width is set at 1/n where n is
     the number of values considered (the size along the 'axis').
   axis: (int) the axis along which to compute the quantile.
   may_squeeze: (bool) should we squeeze the output tensor in case of a single
     quantile.
   **kwargs: see SoftQuantilizer for possible extra parameters.

  Returns:
    A Tensor<float> similar to the input tensor, but the axis dimension is
    replaced by the number of quantiles specified in the quantiles list.
    Hence, if only one quantile is requested (`quantiles` is a float), only one
    value in that axis is returned. When several quantiles are requested, the
    tensor will have that many values in that axis.

  Raises:
    tf.errors.InvalidArgumentError when the quantiles and quantile width are not
    correct, namely quantiles are either not in sorted order or the
    quantile_width is too large.
  """
    if isinstance(quantiles, float):
        quantiles = [quantiles]
    quantiles = tf.constant(quantiles, tf.float32)

    # Preprocesses submitted quantiles to check that they satisfy elementary
    # constraints.
    valid_quantiles = tf.boolean_mask(
        quantiles, tf.logical_and(quantiles > 0.0, quantiles < 1.0))
    num_quantiles = tf.shape(valid_quantiles)[0]

    # Includes values on both ends of [0,1].
    extended_quantiles = tf.concat([[0.0], valid_quantiles, [1.0]], axis=0)

    # Builds filler_weights in between the target quantiles.
    filler_weights = extended_quantiles[1:] - extended_quantiles[:-1]
    if quantile_width is None:
        quantile_width = tf.reduce_min(
            tf.concat([
                filler_weights,
                [1.0 / tf.cast(tf.shape(x)[axis], dtype=x.dtype)]
            ],
                      axis=0))

    # Takes into account quantile_width in the definition of weights
    shift = -tf.ones(tf.shape(filler_weights), dtype=x.dtype)
    shift = shift + 0.5 * (tf.one_hot(0, num_quantiles + 1) +
                           tf.one_hot(num_quantiles, num_quantiles + 1))
    filler_weights = filler_weights + quantile_width * shift

    assert_op = tf.Assert(tf.reduce_all(filler_weights >= 0.0),
                          [filler_weights])
    with tf.control_dependencies([assert_op]):
        # Adds one more value to have tensors of the same shape to interleave them.
        quantile_weights = tf.ones(num_quantiles + 1) * quantile_width

        # Interleaves the filler_weights with the quantile weights.
        weights = tf.reshape(
            tf.stack([filler_weights, quantile_weights], axis=1), (-1, ))[:-1]

        # Sends only the positive weights to the softsort operator.
        positive_weights = tf.boolean_mask(weights, weights > 0.0)
        all_quantiles = softsort(x,
                                 direction='ASCENDING',
                                 axis=axis,
                                 target_weights=positive_weights,
                                 **kwargs)

        # Recovers the indices corresponding to the desired quantiles.
        odds = tf.math.floormod(tf.range(weights.shape[0], dtype=tf.float32),
                                2)
        positives = tf.cast(weights > 0.0, tf.float32)
        indices = tf.cast(tf.math.cumsum(positives) * odds, dtype=tf.int32)
        indices = tf.boolean_mask(indices, indices > 0) - 1
        result = tf.gather(all_quantiles, indices, axis=axis)

        # In the specific case where we want a single quantile, squeezes the
        # quantile dimension.
        can_squeeze = tf.equal(tf.shape(result)[axis], 1)
        if tf.math.logical_and(can_squeeze, may_squeeze):
            result = tf.squeeze(result, axis=axis)
        return result
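
A short usage sketch based on the documented signature (random input values for illustration):

```python
import tensorflow as tf

x = tf.random.uniform((3, 100))
# Soft median of each row; the quantile axis is squeezed, so the shape is [3].
soft_median = softquantiles(x, 0.5, axis=-1)
# Soft first and third quartiles of each row; shape [3, 2].
soft_quartiles = softquantiles(x, [0.25, 0.75], axis=-1)
```
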
Example #10
    def plot_transfer(self,
                      intervals,
                      figsize=None,
                      soft_round=None,
                      **kwargs):
        if not len(intervals) == self.ndim_source == self.ndim_latent == 1:
            raise ValueError("This method is only defined for 1D models.")
        if soft_round is None:
            soft_round = self.soft_round[1]

        x = [
            tf.linspace(float(i[0]), float(i[1]), int(i[2])) for i in intervals
        ]
        x = tf.meshgrid(*x, indexing="ij")
        x = tf.stack(x, axis=-1)

        y_hat, _, _ = self.encode_decode(x, False, False, soft_round, **kwargs)

        y = self.analysis(x)
        # We feed y here so we can visualize the full behavior of the synthesis
        # transform (not just at the quantized latent values).
        x_hat = self.synthesis(y)

        x = np.squeeze(x.numpy(), -1)
        y = np.squeeze(y.numpy(), -1)
        x_hat = np.squeeze(x_hat.numpy(), -1)
        y_hat = np.squeeze(y_hat.numpy(), -1)

        ylim = np.min(y), np.max(y)

        boundaries = np.nonzero(y_hat[1:] != y_hat[:-1])[0]
        lboundaries = (y_hat[boundaries] + y_hat[boundaries + 1]) / 2
        dboundaries = (x[boundaries] + x[boundaries + 1]) / 2

        lcodebook = np.unique(y_hat)
        dcodebook = self.synthesis(lcodebook[:, None]).numpy()
        dcodebook = np.squeeze(dcodebook, -1)
        mask = np.logical_and(ylim[0] < lcodebook, lcodebook < ylim[1])
        lcodebook = lcodebook[mask]
        dcodebook = dcodebook[mask]

        plt.figure(figsize=figsize or (16, 14))
        plt.plot(x, y, label="analysis transform")
        plt.plot(x_hat, y, label="synthesis transform")

        plt.gca().set_aspect("equal", "box")
        # Flip y axis if latent space is reversed.
        if y[0] > y[-1]:
            plt.gca().invert_yaxis()
        plt.xticks(dcodebook)
        plt.yticks(lcodebook)
        plt.grid(False)
        plt.xlabel("source space")
        plt.ylabel("latent space")

        xmin = plt.axis()[0]
        ymin = plt.axis()[2]
        for x, y in zip(dcodebook, lcodebook):
            plt.plot([xmin, x, x], [y, y, ymin], "black", lw=1)
            plt.plot([x], [y],
                     "black",
                     marker="o",
                     ms=5,
                     lw=1,
                     label="codebook" if x == dcodebook[0] else None)
        for x, y in zip(dboundaries, lboundaries):
            plt.plot([xmin, x, x], [y, y, ymin], "black", lw=1, ls=":")
            plt.plot([x], [y],
                     "black",
                     marker="o",
                     ms=3,
                     lw=1,
                     ls=":",
                     label="boundaries" if x == dboundaries[0] else None)

        plt.legend(loc="upper left")
Example #11
    def plot_jacobians(self,
                       which,
                       intervals,
                       arrow_intervals,
                       scale=2,
                       figsize=None):
        if not (len(intervals) == len(arrow_intervals) == self.ndim_source ==
                self.ndim_latent == 2):
            raise ValueError("This method is only defined for 2D models.")
        if which not in ("analysis", "synthesis"):
            raise ValueError("`which` must be 'analysis' or 'synthesis'.")

        data = [
            tf.linspace(float(i[0]), float(i[1]), int(i[2])) for i in intervals
        ]
        data = tf.meshgrid(*data, indexing="ij")
        data = tf.stack(data, axis=-1)
        data_dist = self.source.prob(data).numpy()

        if which == "analysis":
            arrow_data = [
                tf.linspace(float(i[0]), float(i[1]), int(i[2]))
                for i in arrow_intervals
            ]
            arrow_data = tf.meshgrid(*arrow_data, indexing="ij")
            arrow_data = tf.stack(arrow_data, axis=-1)
            arrow_data = tf.reshape(arrow_data, (-1, arrow_data.shape[-1]))
            with tf.GradientTape(watch_accessed_variables=False) as tape:
                tape.watch(arrow_data)
                arrow_latents = self.analysis(arrow_data)
            # First dimension is batch, second is latent dim, third is source dim.
            jacobian = tape.batch_jacobian(arrow_latents, arrow_data)
            jacobian = tf.linalg.inv(jacobian)
            jacobian = tf.transpose(jacobian, (0, 2, 1))
        else:
            arrow_latents = [
                tf.linspace(float(i[0]), float(i[1]), int(i[2]))
                for i in arrow_intervals
            ]
            arrow_latents = tf.meshgrid(*arrow_latents, indexing="ij")
            arrow_latents = tf.stack(arrow_latents, axis=-1)
            arrow_latents = tf.reshape(arrow_latents,
                                       (-1, arrow_latents.shape[-1]))
            with tf.GradientTape(watch_accessed_variables=False) as tape:
                tape.watch(arrow_latents)
                arrow_data = self.synthesis(arrow_latents)
            jacobian = tape.batch_jacobian(arrow_data, arrow_latents)
            jacobian = tf.transpose(jacobian, (0, 2, 1))

        google_pink = (0xf4 / 255, 0x39 / 255, 0xa0 / 255)
        google_purple = (0xa1 / 255, 0x42 / 255, 0xf4 / 255)

        plt.figure(figsize=figsize or (16, 14))
        plt.imshow(data_dist,
                   vmin=0,
                   vmax=data_dist.max(),
                   origin="lower",
                   extent=(data[0, 0, 1], data[0, -1, 1],
                           data[0, 0, 0], data[-1, 0, 0]))
        plt.quiver(
            arrow_data[:, 1],
            arrow_data[:, 0],
            jacobian[:, 0, 1],
            jacobian[:, 0, 0],
            pivot="tail",
            angles="xy",
            headlength=4,
            headaxislength=4,
            units="dots",
            color=google_pink,
            scale_units="xy",
            scale=scale,
        )
        plt.quiver(
            arrow_data[:, 1],
            arrow_data[:, 0],
            jacobian[:, 1, 1],
            jacobian[:, 1, 0],
            pivot="tail",
            angles="xy",
            headlength=4,
            headaxislength=4,
            units="dots",
            color=google_purple,
            scale_units="xy",
            scale=scale,
        )
        plt.axis("image")
        plt.grid(False)
        plt.xlim(data[0, 0, 1], data[0, -1, 1])
        plt.ylim(data[0, 0, 0], data[-1, 0, 0])
        plt.xlabel("source dimension 1")
        plt.ylabel("source dimension 2")
Example #12
def _stack(*ts):
    return tf.stack(_conform(ts), axis=-1)
Example #13
    def loop_tree_doubling(self, step_size, momentum_state_memory,
                           current_step_meta_info, iter_, initial_step_state,
                           initial_step_metastate):
        """Main loop for tree doubling."""
        with tf.name_scope('loop_tree_doubling'):
            batch_shape = prefer_static.shape(
                current_step_meta_info.init_energy)
            direction = tf.cast(tf.random.uniform(shape=batch_shape,
                                                  minval=0,
                                                  maxval=2,
                                                  dtype=tf.int32,
                                                  seed=self._seed_stream()),
                                dtype=tf.bool)

            tree_start_states = tf.nest.map_structure(
                lambda v: tf.where(  # pylint: disable=g-long-lambda
                    _rightmost_expand_to_rank(
                        direction, prefer_static.rank(v[1])), v[1], v[0]),
                initial_step_state)

            directions_expanded = [
                _rightmost_expand_to_rank(direction, prefer_static.rank(state))
                for state in tree_start_states.state
            ]

            integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
                self.target_log_prob_fn,
                step_sizes=[
                    tf.where(d, ss, -ss)
                    for d, ss in zip(directions_expanded, step_size)
                ],
                num_steps=self.unrolled_leapfrog_steps)

            [
                candidate_tree_state, tree_final_states, final_not_divergence,
                continue_tree_final, energy_diff_tree_sum,
                momentum_tree_cumsum, leapfrogs_taken
            ] = self._build_sub_tree(
                directions_expanded,
                integrator,
                current_step_meta_info,
                # num_steps_at_this_depth = 2**iter_ = 1 << iter_
                tf.bitwise.left_shift(1, iter_),
                tree_start_states,
                initial_step_metastate.continue_tree,
                initial_step_metastate.not_divergence,
                momentum_state_memory)

            last_candidate_state = initial_step_metastate.candidate_state
            tree_weight = candidate_tree_state.weight
            if MULTINOMIAL_SAMPLE:
                weight_sum = log_add_exp(tree_weight,
                                         last_candidate_state.weight)
                log_accept_thresh = tree_weight - last_candidate_state.weight
            else:
                weight_sum = tree_weight + last_candidate_state.weight
                log_accept_thresh = tf.math.log(
                    tf.cast(tree_weight, tf.float32) /
                    tf.cast(last_candidate_state.weight, tf.float32))
            log_accept_thresh = tf.where(tf.math.is_nan(log_accept_thresh),
                                         tf.zeros([], log_accept_thresh.dtype),
                                         log_accept_thresh)
            u = tf.math.log1p(-tf.random.uniform(shape=batch_shape,
                                                 dtype=log_accept_thresh.dtype,
                                                 seed=self._seed_stream()))
            is_sample_accepted = u <= log_accept_thresh

            choose_new_state = is_sample_accepted & continue_tree_final

            new_candidate_state = TreeDoublingStateCandidate(
                state=[
                    tf.where(  # pylint: disable=g-complex-comprehension
                        _rightmost_expand_to_rank(choose_new_state,
                                                  prefer_static.rank(s0)), s0,
                        s1) for s0, s1 in zip(candidate_tree_state.state,
                                              last_candidate_state.state)
                ],
                target=tf.where(
                    _rightmost_expand_to_rank(
                        choose_new_state,
                        prefer_static.rank(candidate_tree_state.target)),
                    candidate_tree_state.target, last_candidate_state.target),
                target_grad_parts=[
                    tf.where(  # pylint: disable=g-complex-comprehension
                        _rightmost_expand_to_rank(choose_new_state,
                                                  prefer_static.rank(grad0)),
                        grad0, grad1) for grad0, grad1 in zip(
                            candidate_tree_state.target_grad_parts,
                            last_candidate_state.target_grad_parts)
                ],
                energy=tf.where(
                    _rightmost_expand_to_rank(
                        choose_new_state,
                        prefer_static.rank(candidate_tree_state.target)),
                    candidate_tree_state.energy, last_candidate_state.energy),
                weight=weight_sum)

            # Update the left/right endpoints of the trajectory, and check for
            # a trajectory-level U-turn.
            tree_otherend_states = tf.nest.map_structure(
                lambda v: tf.where(  # pylint: disable=g-long-lambda
                    _rightmost_expand_to_rank(
                        direction, prefer_static.rank(v[1])), v[0], v[1]),
                initial_step_state)

            new_step_state = tf.nest.pack_sequence_as(
                initial_step_state,
                [
                    tf.stack(
                        [  # pylint: disable=g-complex-comprehension
                            tf.where(
                                _rightmost_expand_to_rank(
                                    direction, prefer_static.rank(l)), r, l),
                            tf.where(
                                _rightmost_expand_to_rank(
                                    direction, prefer_static.rank(l)), l, r),
                        ],
                        axis=0)
                    for l, r in zip(tf.nest.flatten(tree_final_states),
                                    tf.nest.flatten(tree_otherend_states))
                ])

            if GENERALIZED_UTURN:
                state_diff = momentum_tree_cumsum
            else:
                state_diff = [s[1] - s[0] for s in new_step_state.state]

            no_u_turns_trajectory = has_not_u_turn(
                state_diff, [m[0] for m in new_step_state.momentum],
                [m[1] for m in new_step_state.momentum],
                log_prob_rank=len(batch_shape))

            new_step_metastate = TreeDoublingMetaState(
                candidate_state=new_candidate_state,
                is_accepted=choose_new_state
                | initial_step_metastate.is_accepted,
                energy_diff_sum=(energy_diff_tree_sum +
                                 initial_step_metastate.energy_diff_sum),
                continue_tree=continue_tree_final & no_u_turns_trajectory,
                not_divergence=final_not_divergence,
                leapfrog_count=(initial_step_metastate.leapfrog_count +
                                leapfrogs_taken))

            return iter_ + 1, new_step_state, new_step_metastate
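
A small sketch of the log-space acceptance test used above: a proposal is accepted with probability `min(1, exp(log_accept_thresh))` by comparing against the log of a uniform variate (`log1p(-uniform)` is a numerically stable log of a Uniform(0, 1] sample):

```python
import tensorflow as tf

log_accept_thresh = tf.constant([-0.1, -2.0, 0.5])   # made-up log acceptance ratios
u = tf.math.log1p(-tf.random.uniform(shape=[3]))     # log of a uniform sample
is_sample_accepted = u <= log_accept_thresh
# Entries with log_accept_thresh >= 0 are always accepted; negative entries are
# accepted with probability exp(log_accept_thresh).
```
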
Example #14
def lu_reconstruct(lower_upper, perm, validate_args=False, name=None):
    """The inverse LU decomposition, `X == lu_reconstruct(*tf.linalg.lu(X))`.

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`.
    validate_args: Python `bool` indicating whether arguments should be checked
      for correctness.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_reconstruct').

  Returns:
    x: The original input to `tf.linalg.lu`, i.e., `x` as in,
      `lu_reconstruct(*tf.linalg.lu(x))`.

  #### Examples

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp

  x = [[[3., 4], [1, 2]],
       [[7., 8], [3, 4]]]
  x_reconstructed = tfp.math.lu_reconstruct(*tf.linalg.lu(x))
  tf.debugging.assert_near(x, x_reconstructed)
  # ==> no exception is raised; the reconstruction matches x.
  ```

  """
    with tf.name_scope(name or 'lu_reconstruct'):
        lower_upper = tf.convert_to_tensor(lower_upper,
                                           dtype_hint=tf.float32,
                                           name='lower_upper')
        perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm')

        assertions = lu_reconstruct_assertions(lower_upper, perm,
                                               validate_args)
        if assertions:
            with tf.control_dependencies(assertions):
                lower_upper = tf.identity(lower_upper)
                perm = tf.identity(perm)

        shape = tf.shape(lower_upper)

        lower = tf.linalg.set_diag(
            tf.linalg.band_part(lower_upper, num_lower=-1, num_upper=0),
            tf.ones(shape[:-1], dtype=lower_upper.dtype))
        upper = tf.linalg.band_part(lower_upper, num_lower=0, num_upper=-1)
        x = tf.matmul(lower, upper)

        if (tensorshape_util.rank(lower_upper.shape) is None
                or tensorshape_util.rank(lower_upper.shape) != 2):
            # We either don't know the batch rank or there are >0 batch dims.
            batch_size = tf.reduce_prod(shape[:-2])
            d = shape[-1]
            x = tf.reshape(x, [batch_size, d, d])
            perm = tf.reshape(perm, [batch_size, d])
            perm = tf.map_fn(tf.math.invert_permutation, perm)
            batch_indices = tf.broadcast_to(
                tf.range(batch_size)[:, tf.newaxis], [batch_size, d])
            x = tf.gather_nd(x, tf.stack([batch_indices, perm], axis=-1))
            x = tf.reshape(x, shape)
        else:
            x = tf.gather(x, tf.math.invert_permutation(perm))

        tensorshape_util.set_shape(x, lower_upper.shape)
        return x
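
A small sketch of `tf.math.invert_permutation`, which undoes the pivoting above: it returns the permutation `q` with `q[perm[i]] = i`, so gathering with `q` reverses a gather with `perm`:

```python
import tensorflow as tf

perm = tf.constant([2, 0, 1])
inv = tf.math.invert_permutation(perm)    # ==> [1, 2, 0]
rows = tf.constant([[10.], [20.], [30.]])
shuffled = tf.gather(rows, perm)          # ==> [[30.], [10.], [20.]]
restored = tf.gather(shuffled, inv)       # ==> [[10.], [20.], [30.]]
```
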
Example #15
  def _sample_n(self, n, seed=None):
    loc, scale, low, high = self._loc_scale_low_high()
    batch_shape = self._batch_shape_tensor(
        loc=loc, scale=scale, low=low, high=high)
    sample_and_batch_shape = tf.concat([[n], batch_shape], 0)
    flat_batch_and_sample_shape = tf.stack([tf.reduce_prod(batch_shape), n])

    # TODO(b/162522020): Use this behavior unconditionally.
    if (tf.executing_eagerly() or
        not control_flow_util.GraphOrParentsInXlaContext(
            tf1.get_default_graph())):
      return tf.random.stateless_parameterized_truncated_normal(
          shape=sample_and_batch_shape,
          means=loc,
          stddevs=scale,
          minvals=low,
          maxvals=high,
          seed=samplers.sanitize_seed(seed))

    # In order to be reparameterizable we sample from a truncated normal with
    # zero mean and unit variance (using the standardized truncation bounds)
    # and then rescale by `scale` and shift by `loc` below.

    @tf.custom_gradient
    def _std_samples_with_gradients(lower, upper):
      """Standard truncated Normal with gradient support for low, high."""
      # Note: Unlike the convention in TFP, parameterized_truncated_normal
      # returns a tensor with the final dimension being the sample dimension.
      std_samples = random_ops.parameterized_truncated_normal(
          shape=flat_batch_and_sample_shape,
          means=0.0,
          stddevs=1.0,
          minvals=lower,
          maxvals=upper,
          dtype=self.dtype,
          seed=seed)

      def grad(dy):
        """Computes a derivative for the min and max parameters.

        This function implements the derivative wrt the truncation bounds, which
        get blocked by the sampler. We use a custom expression for numerical
        stability instead of automatic differentiation on CDF for implicit
        gradients.

        Args:
          dy: output gradients

        Returns:
           The standard normal samples and the gradients wrt the upper
           bound and lower bound.
        """
        # std_samples has an extra dimension (the sample dimension), expand
        # lower and upper so they broadcast along this dimension.
        # See note above regarding parameterized_truncated_normal, the sample
        # dimension is the final dimension.
        lower_broadcast = lower[..., tf.newaxis]
        upper_broadcast = upper[..., tf.newaxis]

        cdf_samples = ((special_math.ndtr(std_samples) -
                        special_math.ndtr(lower_broadcast)) /
                       (special_math.ndtr(upper_broadcast) -
                        special_math.ndtr(lower_broadcast)))

        # tiny, eps are tolerance parameters to ensure we stay away from giving
        # a zero arg to the log CDF expression.

        tiny = np.finfo(dtype_util.as_numpy_dtype(self.dtype)).tiny
        eps = np.finfo(dtype_util.as_numpy_dtype(self.dtype)).eps
        cdf_samples = tf.clip_by_value(cdf_samples, tiny, 1 - eps)

        du = tf.exp(0.5 * (std_samples**2 - upper_broadcast**2) +
                    tf.math.log(cdf_samples))
        dl = tf.exp(0.5 * (std_samples**2 - lower_broadcast**2) +
                    tf.math.log1p(-cdf_samples))

        # Reduce the gradient across the samples
        grad_u = tf.reduce_sum(dy * du, axis=-1)
        grad_l = tf.reduce_sum(dy * dl, axis=-1)
        return [grad_l, grad_u]

      return std_samples, grad

    std_low, std_high = self._standardized_low_and_high(
        low=low, high=high, loc=loc, scale=scale)
    low_high_shp = tf.broadcast_dynamic_shape(
        tf.shape(std_low), tf.shape(std_high))
    std_low = tf.broadcast_to(std_low, low_high_shp)
    std_high = tf.broadcast_to(std_high, low_high_shp)

    std_samples = _std_samples_with_gradients(
        tf.reshape(std_low, [-1]), tf.reshape(std_high, [-1]))

    # The returned shape is [flat_batch x n]
    std_samples = tf.transpose(std_samples, perm=[1, 0])

    std_samples = tf.reshape(std_samples, sample_and_batch_shape)
    return std_samples * scale[tf.newaxis] + loc[tf.newaxis]
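
A minimal sketch of the `tf.custom_gradient` pattern used above: the decorated function returns its forward value together with a `grad` function that maps upstream gradients `dy` to gradients with respect to the inputs (the surrogate gradient here is purely illustrative):

```python
import tensorflow as tf

@tf.custom_gradient
def clipped_square(x):
    y = tf.clip_by_value(x, -1.0, 1.0) ** 2
    def grad(dy):
        # Differentiate as if the clip were absent, an illustrative surrogate.
        return dy * 2.0 * x
    return y, grad

x = tf.constant(3.0)
with tf.GradientTape() as tape:
    tape.watch(x)
    y = clipped_square(x)
print(y.numpy(), tape.gradient(y, x).numpy())   # 1.0 6.0
```
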
Example #16
  def __init__(self,
               num_timesteps,
               level_scale,
               slope_scale,
               initial_state_prior,
               observation_noise_scale=0.,
               initial_step=0,
               validate_args=False,
               allow_nan_stats=True,
               name=None):
    """Build a state space model implementing a local linear trend.

    Args:
      num_timesteps: Scalar `int` `Tensor` number of timesteps to model
        with this distribution.
      level_scale: Scalar (any additional dimensions are treated as batch
        dimensions) `float` `Tensor` indicating the standard deviation of the
        level transitions.
      slope_scale: Scalar (any additional dimensions are treated as batch
        dimensions) `float` `Tensor` indicating the standard deviation of the
        slope transitions.
      initial_state_prior: instance of `tfd.MultivariateNormal`
        representing the prior distribution on latent states; must
        have event shape `[2]`.
      observation_noise_scale: Scalar (any additional dimensions are
        treated as batch dimensions) `float` `Tensor` indicating the standard
        deviation of the observation noise.
      initial_step: Optional scalar `int` `Tensor` specifying the starting
        timestep.
        Default value: 0.
      validate_args: Python `bool`. Whether to validate input
        with asserts. If `validate_args` is `False`, and the inputs are
        invalid, correct behavior is not guaranteed.
        Default value: `False`.
      allow_nan_stats: Python `bool`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
        Default value: `True`.
      name: Python `str` name prefixed to ops created by this class.
        Default value: "LocalLinearTrendStateSpaceModel".
    """

    with tf1.name_scope(name, 'LocalLinearTrendStateSpaceModel',
                                 [level_scale, slope_scale]) as name:

      # The initial state prior determines the dtype of sampled values.
      # Other model parameters must have the same dtype.
      dtype = initial_state_prior.dtype

      level_scale = tf.convert_to_tensor(
          value=level_scale, name='level_scale', dtype=dtype)
      slope_scale = tf.convert_to_tensor(
          value=slope_scale, name='slope_scale', dtype=dtype)
      observation_noise_scale = tf.convert_to_tensor(
          value=observation_noise_scale,
          name='observation_noise_scale',
          dtype=dtype)

      # Explicitly broadcast all parameters to the same batch shape. This
      # allows us to use `tf.stack` for a compact model specification.
      broadcast_batch_shape = dist_util.get_broadcast_shape(
          level_scale, slope_scale)
      broadcast_ones = tf.ones(broadcast_batch_shape, dtype=dtype)

      self._level_scale = level_scale
      self._slope_scale = slope_scale
      self._observation_noise_scale = observation_noise_scale

      # Construct a linear Gaussian state space model implementing the
      # local linear trend model. See "Mathematical Details" in the
      # class docstring for further explanation.
      super(LocalLinearTrendStateSpaceModel, self).__init__(
          num_timesteps=num_timesteps,
          transition_matrix=tf.constant(
              [[1., 1.], [0., 1.]], dtype=dtype, name='transition_matrix'),
          transition_noise=tfd.MultivariateNormalDiag(
              scale_diag=tf.stack(
                  [level_scale * broadcast_ones, slope_scale * broadcast_ones],
                  axis=-1),
              name='transition_noise'),
          observation_matrix=tf.constant(
              [[1., 0.]], dtype=dtype, name='observation_matrix'),
          observation_noise=tfd.MultivariateNormalDiag(
              scale_diag=observation_noise_scale[..., tf.newaxis],
              name='observation_noise'),
          initial_state_prior=initial_state_prior,
          initial_step=initial_step,
          allow_nan_stats=allow_nan_stats,
          validate_args=validate_args,
          name=name)
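
A minimal construction sketch (assuming the class above is exposed as
`tfp.sts.LocalLinearTrendStateSpaceModel`): scalar scales, a 2-dimensional
initial prior, then sample and score a trajectory.

```python
import tensorflow_probability as tfp
tfd = tfp.distributions

ssm = tfp.sts.LocalLinearTrendStateSpaceModel(
    num_timesteps=50,
    level_scale=0.5,
    slope_scale=0.5,
    initial_state_prior=tfd.MultivariateNormalDiag(scale_diag=[1., 1.]))

y = ssm.sample()        # One sampled series of shape [50, 1].
lp = ssm.log_prob(y)    # Scalar log-density of the sample under the model.
```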
Example #17
0
def evaluate(
    env,
    policy,
    num_episodes = 10,
    ctx_length = None,
    embed_training_window = None,
    state_mask_fn = None,  # pylint: disable=g-bare-generic
):
  """Evaluates the policy.

  Args:
    env: Environment to evaluate the policy on.
    policy: Policy to evaluate.
    num_episodes: Number of episodes to average the policy over.
    ctx_length: Number of previous steps used to compute the context.
    embed_training_window: Window size used during embedding training.
    state_mask_fn: State-masking function for partially observable envs.

  Returns:
    Averaged reward and a total number of steps.
  """
  total_timesteps = 0
  total_returns = 0.0

  def apply_mask(observation):
    if state_mask_fn:
      return tf.convert_to_tensor(state_mask_fn(observation.numpy()))
    return observation

  for _ in range(num_episodes):
    timestep = env.reset()
    if ctx_length:
      states = [apply_mask(timestep.observation) for _ in range(ctx_length)]
      actions = [
          tf.zeros(policy.action_spec.shape)[None, :] for _ in range(ctx_length)
      ]
      rewards = [[0.] for _ in range(ctx_length)]

    latent_action = None
    i = 0
    while not timestep.is_last():
      if embed_training_window and (i % embed_training_window == 0 or
                                    embed_training_window <= 2):
        latent_action = None
      if ctx_length:
        states.append(apply_mask(timestep.observation))
        if len(states) > ctx_length:
          states.pop(0)
          actions.pop(0)
          rewards.pop(0)
        action = policy.act(
            tf.stack(states, axis=1),
            actions=tf.stack(actions, axis=1),
            rewards=tf.stack(rewards, axis=1))
        actions.append(action)
      else:
        if embed_training_window:
          action, latent_action = policy.act(
              apply_mask(timestep.observation), latent_action=latent_action)
        else:
          action = policy.act(apply_mask(timestep.observation))

      timestep = env.step(action)
      if ctx_length:
        rewards.append(timestep.reward)

      total_returns += timestep.reward[0]
      total_timesteps += 1
      i += 1

  return total_returns / num_episodes, total_timesteps / num_episodes
Example #18
0
  def __init__(self,
               level_scale_prior=None,
               slope_scale_prior=None,
               initial_level_prior=None,
               initial_slope_prior=None,
               observed_time_series=None,
               name=None):
    """Specify a local linear trend model.

    Args:
      level_scale_prior: optional `tfd.Distribution` instance specifying a prior
        on the `level_scale` parameter. If `None`, a heuristic default prior is
        constructed based on the provided `observed_time_series`.
        Default value: `None`.
      slope_scale_prior: optional `tfd.Distribution` instance specifying a prior
        on the `slope_scale` parameter. If `None`, a heuristic default prior is
        constructed based on the provided `observed_time_series`.
        Default value: `None`.
      initial_level_prior: optional `tfd.Distribution` instance specifying a
        prior on the initial level. If `None`, a heuristic default prior is
        constructed based on the provided `observed_time_series`.
        Default value: `None`.
      initial_slope_prior: optional `tfd.Distribution` instance specifying a
        prior on the initial slope. If `None`, a heuristic default prior is
        constructed based on the provided `observed_time_series`.
        Default value: `None`.
      observed_time_series: optional `float` `Tensor` of shape
        `batch_shape + [T, 1]` (omitting the trailing unit dimension is also
        supported when `T > 1`), specifying an observed time series.
        Any priors not explicitly set will be given default values according to
        the scale of the observed time series (or batch of time series). May
        optionally be an instance of `tfp.sts.MaskedTimeSeries`, which includes
        a mask `Tensor` to specify timesteps with missing observations.
        Default value: `None`.
      name: the name of this model component.
        Default value: 'LocalLinearTrend'.
    """

    with tf1.name_scope(
        name, 'LocalLinearTrend', values=[observed_time_series]) as name:

      _, observed_stddev, observed_initial = (
          sts_util.empirical_statistics(observed_time_series)
          if observed_time_series is not None else (0., 1., 0.))

      # Heuristic default priors. Overriding these may dramatically
      # change inference performance and results.
      if level_scale_prior is None:
        level_scale_prior = tfd.LogNormal(
            loc=tf.math.log(.05 * observed_stddev),
            scale=3.,
            name='level_scale_prior')
      if slope_scale_prior is None:
        slope_scale_prior = tfd.LogNormal(
            loc=tf.math.log(.05 * observed_stddev),
            scale=3.,
            name='slope_scale_prior')
      if initial_level_prior is None:
        initial_level_prior = tfd.Normal(
            loc=observed_initial,
            scale=tf.abs(observed_initial) + observed_stddev,
            name='initial_level_prior')
      if initial_slope_prior is None:
        initial_slope_prior = tfd.Normal(
            loc=0., scale=observed_stddev, name='initial_slope_prior')

      tf.debugging.assert_same_float_dtype([
          level_scale_prior, slope_scale_prior, initial_level_prior,
          initial_slope_prior
      ])

      self._initial_state_prior = tfd.MultivariateNormalDiag(
          loc=tf.stack(
              [initial_level_prior.mean(),
               initial_slope_prior.mean()
              ], axis=-1),
          scale_diag=tf.stack([
              initial_level_prior.stddev(),
              initial_slope_prior.stddev()
          ], axis=-1))

      scaled_softplus = tfb.Chain([tfb.AffineScalar(scale=observed_stddev),
                                   tfb.Softplus()])
      super(LocalLinearTrend, self).__init__(
          parameters=[
              Parameter('level_scale', level_scale_prior, scaled_softplus),
              Parameter('slope_scale', slope_scale_prior, scaled_softplus)
          ],
          latent_size=2,
          name=name)
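
A minimal usage sketch (assuming the component above is exposed as
`tfp.sts.LocalLinearTrend`): the heuristic priors in the constructor are
derived from the scale of the observed series.

```python
import numpy as np
import tensorflow_probability as tfp

# A synthetic observed series; the default priors are scaled to it.
observed_time_series = np.cumsum(np.random.randn(100)).astype(np.float32)

trend = tfp.sts.LocalLinearTrend(
    observed_time_series=observed_time_series, name='trend')
model = tfp.sts.Sum([trend], observed_time_series=observed_time_series)
```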
Example #19
0
 def test_forward_gradient(self):
     t = tf.range(1, 3, dtype=tf.float32)  # Shape [2]
     func = lambda t: tf.stack([t, t**2, t**3], axis=0)  # Shape [3, 2]
     fwd_grad = self.evaluate(math.fwd_gradient(func, t))
     self.assertEqual(fwd_grad.shape, (3, 2))
     np.testing.assert_allclose(fwd_grad, [[1., 1.], [2., 4.], [3., 12.]])
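
An independent check of the expected values using core TensorFlow's
forward-mode API (`tf.autodiff.ForwardAccumulator`); this is not the
`math.fwd_gradient` implementation under test, just a sketch of the same
computation.

```python
import tensorflow as tf

t = tf.range(1, 3, dtype=tf.float32)                 # [1., 2.]
with tf.autodiff.ForwardAccumulator(
    primals=t, tangents=tf.ones_like(t)) as acc:
  y = tf.stack([t, t**2, t**3], axis=0)              # Shape [3, 2].
jvp = acc.jvp(y)
# jvp == [[1., 1.], [2., 4.], [3., 12.]], matching the assertion above.
```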
Example #20
0
def _batch_interp_with_gather_nd(x, x_ref_min, x_ref_max, y_ref, nd,
                                 fill_value, batch_dims):
    """N-D interpolation that works with leading batch dims."""
    dtype = x.dtype

    # In this function,
    # x.shape = [A1, ..., An, D, nd], where n = batch_dims
    # and
    # y_ref.shape = [A1, ..., An, C1, C2,..., Cnd, B1,...,BM]
    # y_ref[A1, ..., An, i1,...,ind] is a shape [B1,...,BM] Tensor with the value
    # at index [i1,...,ind] in the interpolation table.
    # x_ref_min and x_ref_max have shapes [A1, ..., An, nd].

    # ny[k] is number of y reference points in interp dim k.
    ny = tf.cast(tf.shape(y_ref)[batch_dims:batch_dims + nd], dtype)

    # Map [x_ref_min, x_ref_max] to [0, ny - 1].
    # This is the (fractional) index of x.
    # x_idx_unclipped[A1, ..., An, d, k] is the fractional index into dim k of
    # interpolation table for the dth x value.
    x_ref_min_expanded = tf.expand_dims(x_ref_min, axis=-2)
    x_ref_max_expanded = tf.expand_dims(x_ref_max, axis=-2)
    x_idx_unclipped = (ny - 1) * (x - x_ref_min_expanded) / (
        x_ref_max_expanded - x_ref_min_expanded)

    # Wherever x is NaN, x_idx_unclipped will be NaN as well.
    # Keep track of the nan indices here (so we can impute NaN later).
    # Also eliminate any NaN indices, since there is no NaN for int32 indices.
    nan_idx = tf.math.is_nan(x_idx_unclipped)
    x_idx_unclipped = tf.where(nan_idx, 0., x_idx_unclipped)

    # x_idx.shape = [A1, ..., An, D, nd]
    x_idx = tf.clip_by_value(x_idx_unclipped, tf.zeros((), dtype=dtype),
                             ny - 1)

    # Get the index above and below x_idx.
    # Naively we could set idx_below = floor(x_idx), idx_above = ceil(x_idx),
    # however, this results in idx_below == idx_above whenever x is on a grid.
    # This in turn results in y_ref_below == y_ref_above, and then the gradient
    # at this point is zero.  So here we 'jitter' one of idx_below, idx_above,
    # so that they are at different values.  This jittering does not affect the
    # interpolated value, but does make the gradient nonzero (unless of course
    # the y_ref values are the same).
    idx_below = tf.floor(x_idx)
    idx_above = tf.minimum(idx_below + 1, ny - 1)
    idx_below = tf.maximum(idx_above - 1, 0)

    # These are the values of y_ref corresponding to above/below indices.
    # idx_below_int32.shape = x.shape[:-1] + [nd]
    idx_below_int32 = tf.cast(idx_below, dtype=tf.int32)
    idx_above_int32 = tf.cast(idx_above, dtype=tf.int32)

    # idx_below_list is a length nd list of shape x.shape[:-1] int32 tensors.
    idx_below_list = tf.unstack(idx_below_int32, axis=-1)
    idx_above_list = tf.unstack(idx_above_int32, axis=-1)

    # Use t to get a convex combination of the below/above values.
    # t.shape = [A1, ..., An, D, nd]
    t = x_idx - idx_below

    # x, and tensors shaped like x, need to be added to, and selected with
    # (using tf.where) the output y.  This requires appending singletons.
    def _expand_x_fn(tensor):
        # Reshape tensor to tensor.shape + [1] * M.
        extended_shape = tf.concat([
            tf.shape(tensor),
            tf.ones_like(tf.shape(y_ref)[batch_dims + nd:])
        ],
                                   axis=0)
        return tf.reshape(tensor, extended_shape)

    # Now, t.shape = [A1, ..., An, D, nd] + [1] * (rank(y_ref) - nd - batch_dims)
    t = _expand_x_fn(t)
    s = 1 - t

    # Re-insert NaN wherever x was NaN.
    nan_idx = _expand_x_fn(nan_idx)
    t = tf.where(nan_idx, tf.constant(np.nan, dtype), t)

    terms = []
    # Our work above has located x's fractional index inside a cube of above/below
    # indices. The distance to the below indices is t, and to the above indices
    # is s.
    # Drawing lines from x to the cube walls, we get 2**nd smaller cubes. Each
    # term in the result is a product of a reference point, gathered from y_ref,
    # multiplied by a volume.  The volume is that of the cube opposite to the
    # reference point.  E.g. if the reference point is below x in every axis, the
    # volume is that of the cube with corner above x in every axis, s[0]*...*s[nd]
    # We could probably do this with one massive gather, but that would be very
    # unreadable and un-debuggable.  It also would create a large Tensor.
    for zero_ones_list in _binary_count(nd):
        gather_from_y_ref_idx = []
        opposite_volume_t_idx = []
        opposite_volume_s_idx = []
        for k, zero_or_one in enumerate(zero_ones_list):
            if zero_or_one == 0:
                # If the kth entry is 0, gather from the 'below' reference
                # point along axis k.
                gather_from_y_ref_idx.append(idx_below_list[k])
                # Now append the index to gather for computing opposite_volume.
                # This could be done by initializing opposite_volume to 1, then here:
                #  opposite_volume *= tf.gather(s, indices=k, axis=tf.rank(x) - 1)
                # but that puts a gather in the 'inner loop.'  Better to append the
                # index and do one larger gather down below.
                opposite_volume_s_idx.append(k)
            else:
                gather_from_y_ref_idx.append(idx_above_list[k])
                # Append an index to gather, having the same effect as
                #   opposite_volume *= tf.gather(t, indices=k, axis=tf.rank(x) - 1)
                opposite_volume_t_idx.append(k)

        # Compute opposite_volume (volume of cube opposite the ref point):
        # Recall t.shape = s.shape = [D, nd] + [1, ..., 1]
        # Gather from t and s along the 'nd' axis, which is rank(x) - 1.
        ov_axis = tf.rank(x) - 1
        opposite_volume = (tf.reduce_prod(
            tf.gather(t,
                      indices=tf.cast(opposite_volume_t_idx, dtype=tf.int32),
                      axis=ov_axis),
            axis=ov_axis) * tf.reduce_prod(tf.gather(
                s,
                indices=tf.cast(opposite_volume_s_idx, dtype=tf.int32),
                axis=ov_axis),
                                           axis=ov_axis))  # pyformat: disable

        y_ref_pt = tf.gather_nd(y_ref,
                                tf.stack(gather_from_y_ref_idx, axis=-1),
                                batch_dims=batch_dims)

        terms.append(y_ref_pt * opposite_volume)

    y = tf.math.add_n(terms)

    if tf.debugging.is_numeric_tensor(fill_value):
        # Recall x_idx_unclipped.shape = [D, nd],
        # so here we check if it was out of bounds in any of the nd dims.
        # Thus, oob_idx.shape = [D].
        oob_idx = tf.reduce_any(
            (x_idx_unclipped < 0) | (x_idx_unclipped > ny - 1), axis=-1)

        # Now, y.shape = [D, B1,...,BM], so we'll have to broadcast oob_idx.

        oob_idx = _expand_x_fn(oob_idx)  # Shape [D, 1,...,1]
        oob_idx |= tf.fill(tf.shape(y), False)
        y = tf.where(oob_idx, fill_value, y)
    return y
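
A usage sketch of this kind of interpolation (assuming the helper above backs
the public `tfp.math.batch_interp_regular_nd_grid`): interpolate exp(x) on a
regular 1-D grid.

```python
import tensorflow as tf
import tensorflow_probability as tfp

# 20 reference values of f(x) = exp(x) on a regular grid over [0, 10].
y_ref = tf.exp(tf.linspace(start=0., stop=10., num=20))

y = tfp.math.batch_interp_regular_nd_grid(
    x=[[6.0], [0.0], [3.3]],          # Three query points, each 1-D.
    x_ref_min=[0.], x_ref_max=[10.],
    y_ref=y_ref, axis=0)
# y is approximately [exp(6.0), exp(0.0), exp(3.3)].
```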
Example #21
0
        def collater_fn(batch: Dict[Text, tf.Tensor]) -> Dict[Text, tf.Tensor]:
            """Collater function for relation classification task. See BaseTask."""
            def flatten_bsz(tensor):
                return tf.reshape(tensor, [bsz])

            new_batch = {
                'text_ids': batch['text_ids'],
                'text_mask': batch['text_mask'],
                'classifier_target': flatten_bsz(batch['target']),
            }

            # Sample mentions across batch

            # We want to make sure that the subject / object mentions always have
            # priority when we sample `max_batch_mentions` out of all available
            # mentions. Additionally, we want these subject / object mentions to
            # stay in the same order as their samples. In other words, the first
            # sampled mention is the subject mention from the first sample, the
            # second is the object mention from the first sample, the third is
            # the subject mention from the second sample, etc.

            subj_index = flatten_bsz(batch['subject_mention_indices'])
            obj_index = flatten_bsz(batch['object_mention_indices'])

            # Adjust subject / object mention positions in individual samples to their
            # positions in flattened mentions.
            shift = tf.range(
                bsz, dtype=obj_index.dtype) * config.max_mentions_per_sample
            mention_target_indices = tf.reshape(
                tf.stack([subj_index + shift, obj_index + shift], axis=1),
                [-1])

            # Sample the rest of the mentions uniformly across batch
            scores = tf.random.uniform(shape=tf.shape(batch['mention_mask']))
            scores = scores * tf.cast(batch['mention_mask'], tf.float32)

            # We want to adjust scores for target mentions so they don't get
            # sampled a second time. We achieve this by making their scores
            # negative.
            def set_negative_scores(scores, indices):
                indices_2d = tf.stack(
                    [tf.range(bsz, dtype=indices.dtype), indices], axis=1)
                return tf.tensor_scatter_nd_update(
                    scores, indices_2d, tf.fill(tf.shape(indices), -1.0))

            # Note that since we're using 2D scores (not yet flattened for simplicity)
            # we use unadjusted `subj_index` and `obj_index`.
            scores = set_negative_scores(scores, subj_index)
            scores = set_negative_scores(scores, obj_index)

            # There are `2 * bsz` target mentions which were already chosen
            num_to_sample = tf.maximum(max_batch_mentions - 2 * bsz, 0)
            sampled_scores, sampled_indices = tf.math.top_k(tf.reshape(
                scores, [-1]),
                                                            num_to_sample,
                                                            sorted=True)

            # Note that negative scores indicate that we have double-sampled some of
            # the target mentions (we set their scores to negative right above).
            # In this case, we remove them.
            num_not_double_sampled = tf.reduce_sum(
                tf.cast(tf.not_equal(sampled_scores, -1), tf.int32))
            sampled_indices = sampled_indices[:num_not_double_sampled]

            # Combine target mentions (subject / object) with sampled mentions
            mention_target_indices = tf.cast(mention_target_indices,
                                             sampled_indices.dtype)
            sampled_indices = tf.concat(
                [mention_target_indices, sampled_indices], axis=0)

            sampled_indices = mention_preprocess_utils.dynamic_padding_1d(
                sampled_indices, max_batch_mentions)

            dtype = batch['mention_start_positions'].dtype
            mention_mask = tf.reshape(batch['mention_mask'],
                                      [n_candidate_mentions])
            new_batch['mention_mask'] = tf.gather(mention_mask,
                                                  sampled_indices)
            new_batch['mention_start_positions'] = tf.gather(
                tf.reshape(batch['mention_start_positions'],
                           [n_candidate_mentions]), sampled_indices)
            new_batch['mention_end_positions'] = tf.gather(
                tf.reshape(batch['mention_end_positions'],
                           [n_candidate_mentions]), sampled_indices)
            new_batch['mention_batch_positions'] = tf.gather(
                tf.repeat(tf.range(bsz, dtype=dtype),
                          config.max_mentions_per_sample), sampled_indices)

            new_batch['mention_target_indices'] = tf.range(2 * bsz,
                                                           dtype=dtype)
            new_batch['mention_subject_indices'] = tf.range(bsz,
                                                            dtype=dtype) * 2
            new_batch['mention_object_indices'] = tf.range(bsz,
                                                           dtype=dtype) * 2 + 1

            if config.get('max_length_with_entity_tokens') is not None:
                batch_with_entity_tokens = mention_preprocess_utils.add_entity_tokens(
                    text_ids=new_batch['text_ids'],
                    text_mask=new_batch['text_mask'],
                    mention_mask=new_batch['mention_mask'],
                    mention_batch_positions=new_batch[
                        'mention_batch_positions'],
                    mention_start_positions=new_batch[
                        'mention_start_positions'],
                    mention_end_positions=new_batch['mention_end_positions'],
                    new_length=config.max_length_with_entity_tokens,
                )
                # Update `text_ids`, `text_mask`, `mention_mask`, `mention_*_positions`
                new_batch.update(batch_with_entity_tokens)
                # Update `max_length`
                max_length = config.max_length_with_entity_tokens
            else:
                max_length = encoder_config.max_length

            new_batch['mention_target_batch_positions'] = tf.gather(
                new_batch['mention_batch_positions'],
                new_batch['mention_target_indices'])
            new_batch['mention_target_start_positions'] = tf.gather(
                new_batch['mention_start_positions'],
                new_batch['mention_target_indices'])
            new_batch['mention_target_end_positions'] = tf.gather(
                new_batch['mention_end_positions'],
                new_batch['mention_target_indices'])
            new_batch['mention_target_weights'] = tf.ones(2 * bsz)

            # Fake IDs -- some encoders (ReadTwice) need them
            new_batch['mention_target_ids'] = tf.zeros(2 * bsz)

            new_batch['segment_ids'] = tf.zeros_like(new_batch['text_ids'])

            position_ids = tf.expand_dims(tf.range(max_length), axis=0)
            new_batch['position_ids'] = tf.tile(position_ids, (bsz, 1))

            return new_batch
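
A self-contained sketch (with made-up sizes, not from the source) of the
masking-plus-`top_k` trick used above: already-chosen entries get negative
scores so they cannot be selected again.

```python
import tensorflow as tf

bsz = 2
scores = tf.random.uniform([bsz, 5])        # Candidate scores in [0, 1).
chosen = tf.constant([1, 3])                # One already-chosen index per row.
indices_2d = tf.stack([tf.range(bsz), chosen], axis=1)
scores = tf.tensor_scatter_nd_update(
    scores, indices_2d, tf.fill([bsz], -1.0))

# top_k over the flattened scores never returns the chosen entries,
# since their score (-1) is below every remaining score.
top_scores, top_idx = tf.math.top_k(tf.reshape(scores, [-1]), k=3)
```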
Example #22
0
 def _event_shape_tensor(self):
     dimension = self.scale_operator.domain_dimension_tensor()
     return tf.stack([dimension, dimension])
Example #23
0
def concatenate_batch_into_sample(batch):
    for feature in batch.keys():
        batch[feature] = tf.reshape(batch[feature], [1, -1])
    return batch


concatenated_examples = []

for batch in dataset:
    concatenated_examples.append(concatenate_batch_into_sample(batch))


feature_dict = {}

for feature in concatenated_examples[0].keys():
    feature_list = [example[feature] for example in concatenated_examples]
    feature_dict[feature] = tf.squeeze(tf.stack(
        feature_list, axis=0))

feature_dict["f0_hz"] = feature_dict["f0_hz"].numpy()
if INTONATION:
    for di in range(feature_dict["f0_hz"].shape[0]):
        feature_dict["f0_hz"][di, :] = intonate(
            feature_dict["f0_hz"][di, :])


dataset = tf.data.Dataset.from_tensor_slices(feature_dict)

ex = next(iter(dataset))


assert ex["audio"].shape[0] == 16000*16
Example #24
0
 def call(self, inputs):
     if not isinstance(inputs, (list, tuple)):
         raise ValueError(
             'A merge layer should be called on a list of inputs. '
             f'Received: inputs={inputs} (not a list of tensors)')
     if self._reshape_required:
         reshaped_inputs = []
         input_ndims = list(map(backend.ndim, inputs))
         if None not in input_ndims:
             # If ranks of all inputs are available,
             # we simply expand each of them at axis=1
             # until all of them have the same rank.
             max_ndim = max(input_ndims)
             for x in inputs:
                 x_ndim = backend.ndim(x)
                 for _ in range(max_ndim - x_ndim):
                     x = tf.expand_dims(x, axis=1)
                 reshaped_inputs.append(x)
             return self._merge_function(reshaped_inputs)
         else:
             # Transpose all inputs so that batch size is the last dimension.
             # (batch_size, dim1, dim2, ... ) -> (dim1, dim2, ... , batch_size)
             transposed = False
             for x in inputs:
                 x_ndim = backend.ndim(x)
                 if x_ndim is None:
                     x_shape = tf.shape(x)
                     batch_size = x_shape[0]
                     new_shape = backend.concatenate(
                         [x_shape[1:],
                          tf.expand_dims(batch_size, axis=-1)])
                     x_transposed = tf.reshape(
                         x,
                         tf.stack([batch_size,
                                   tf.reduce_prod(x_shape[1:])],
                                  axis=0))
                     x_transposed = tf.transpose(x_transposed, perm=(1, 0))
                     x_transposed = tf.reshape(x_transposed, new_shape)
                     reshaped_inputs.append(x_transposed)
                     transposed = True
                 elif x_ndim > 1:
                     dims = list(range(1, x_ndim)) + [0]
                     reshaped_inputs.append(tf.transpose(x, perm=dims))
                     transposed = True
                 else:
                     # We don't transpose inputs if they are 1D vectors or scalars.
                     reshaped_inputs.append(x)
             y = self._merge_function(reshaped_inputs)
             y_ndim = backend.ndim(y)
             if transposed:
                 # If inputs have been transposed, we have to transpose the output too.
                 if y_ndim is None:
                     y_shape = tf.shape(y)
                     y_ndim = tf.shape(y_shape)[0]
                     batch_size = y_shape[y_ndim - 1]
                     new_shape = backend.concatenate([
                         tf.expand_dims(batch_size, axis=-1),
                         y_shape[:y_ndim - 1]
                     ])
                     y = tf.reshape(y, (-1, batch_size))
                     y = tf.transpose(y, perm=(1, 0))
                     y = tf.reshape(y, new_shape)
                 elif y_ndim > 1:
                     dims = [y_ndim - 1] + list(range(y_ndim - 1))
                     y = tf.transpose(y, perm=dims)
             return y
     else:
         return self._merge_function(inputs)
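
A usage sketch: the `call` above appears to be the shared merge-layer base in
Keras (an assumption), so concrete merge layers such as `Add` route their
inputs through it.

```python
import tensorflow as tf

a = tf.keras.Input(shape=(4,))
b = tf.keras.Input(shape=(4,))
summed = tf.keras.layers.Add()([a, b])   # Element-wise merge of two inputs.
model = tf.keras.Model(inputs=[a, b], outputs=summed)
```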
Example #25
0
  def solve_nu_zeta(self,
                    dataset: dataset_lib.OffpolicyDataset,
                    target_policy: tf_policy.TFPolicy,
                    regularizer: float = 1e-6):
    """Solves for density ratios and then approximates target policy value.

    Args:
      dataset: The dataset to sample experience from.
      target_policy: The policy whose value we want to estimate.
      regularizer: A small constant to add to matrices before inverting them or
        to floats before taking square root.

    Returns:
      Estimated average per-step reward of the target policy.
    """

    if not hasattr(self, '_td_mat'):
      # Set up env_steps.
      episodes, valid_steps = dataset.get_all_episodes(
          limit=self._limit_episodes)
      total_num_steps_per_episode = tf.shape(valid_steps)[1] - 1
      num_episodes = tf.shape(valid_steps)[0]
      num_samples = num_episodes * total_num_steps_per_episode
      valid_and_not_last = tf.logical_and(valid_steps, episodes.discount > 0)
      valid_indices = tf.squeeze(
          tf.where(tf.reshape(valid_and_not_last[:, :-1], [-1])))

      initial_env_step = tf.nest.map_structure(
          lambda t: tf.squeeze(
              tf.reshape(
                  tf.repeat(
                      t[:, 0:1, ...],
                      axis=1,
                      repeats=total_num_steps_per_episode), [num_samples, -1])),
          episodes)
      initial_env_step = tf.nest.map_structure(
          lambda t: tf.gather(t, valid_indices), initial_env_step)
      tfagents_initial_env_step = dataset_lib.convert_to_tfagents_timestep(
          initial_env_step)

      env_step = tf.nest.map_structure(
          lambda t: tf.squeeze(
              tf.reshape(t[:, 0:total_num_steps_per_episode, ...],
                         [num_samples, -1])), episodes)
      env_step = tf.nest.map_structure(lambda t: tf.gather(t, valid_indices),
                                       env_step)
      tfagents_env_step = dataset_lib.convert_to_tfagents_timestep(env_step)

      next_env_step = tf.nest.map_structure(
          lambda t: tf.squeeze(
              tf.reshape(t[:, 1:total_num_steps_per_episode + 1, ...],
                         [num_samples, -1])), episodes)
      next_env_step = tf.nest.map_structure(
          lambda t: tf.gather(t, valid_indices), next_env_step)
      tfagents_next_env_step = dataset_lib.convert_to_tfagents_timestep(
          next_env_step)

      # get probabilities
      initial_target_probs = target_policy.distribution(
          tfagents_initial_env_step).action.probs_parameter()
      next_target_probs = target_policy.distribution(
          tfagents_next_env_step).action.probs_parameter()

      # First, get the nu_loss and data weights
      #current_nu_loss = self._get_nu_loss(initial_env_step, env_step,
      #                                    next_env_step, target_policy)
      #data_weight, _ = self._get_weights(current_nu_loss)

      # # debug only and to reproduce dual dice result, DELETE
      # data_weight = tf.ones_like(data_weight)

      state_action_count = self._get_state_action_counts(env_step)
      counts = tf.reduce_sum(tf.one_hot(state_action_count, self._dimension), 0)
      gamma_sample = tf.pow(self._gamma, tf.cast(env_step.step_num, tf.float32))

      # # debug only and to reproduce dual dice result, DELETE
      # gamma_sample = tf.ones_like(gamma_sample)

      # now we need to expand_dims to include action space in extra dimensions
      #data_weights = tf.reshape(data_weight, [-1, self._num_limits])
      # both are data sample weights for L2 problem, needs to be normalized later
      #gamma_data_weights = tf.reshape(gamma_sample, [-1, 1]) * data_weights

      initial_states = tf.tile(
          tf.reshape(initial_env_step.observation, [-1, 1]),
          [1, self._num_actions])
      initial_actions = tf.tile(
          tf.reshape(tf.range(self._num_actions), [1, -1]),
          [initial_env_step.observation.shape[0], 1])
      initial_nu_indices = self._get_index(initial_states, initial_actions)

      # linear term w.r.t. initial distribution
      #b_vec_2 = tf.stack([
      #    tf.reduce_sum(
      #        tf.reshape(
      #            data_weights[:, itr] / tf.reduce_sum(data_weights[:, itr]),
      #            [-1, 1]) * tf.reduce_sum(
      #                tf.one_hot(initial_nu_indices, self._dimension) *
      #                (1 - self._gamma) *
      #                tf.expand_dims(initial_target_probs, axis=-1),
      #                axis=1),
      #        axis=0) for itr in range(self._num_limits)
      #],
      #                   axis=0)

      next_states = tf.tile(
          tf.reshape(next_env_step.observation, [-1, 1]),
          [1, self._num_actions])
      next_actions = tf.tile(
          tf.reshape(tf.range(self._num_actions), [1, -1]),
          [next_env_step.observation.shape[0], 1])
      next_nu_indices = self._get_index(next_states, next_actions)
      next_nu_indices = tf.where(
          tf.expand_dims(next_env_step.is_absorbing(), -1),
          -1 * tf.ones_like(next_nu_indices), next_nu_indices)

      nu_indices = self._get_index(env_step.observation, env_step.action)

      target_log_probabilities = target_policy.distribution(
          tfagents_env_step).action.log_prob(env_step.action)
      if not self._solve_for_state_action_ratio:
        policy_ratio = tf.exp(target_log_probabilities -
                              env_step.get_log_probability())
      else:
        policy_ratio = tf.ones([
            target_log_probabilities.shape[0],
        ])
      policy_ratios = tf.tile(
          tf.reshape(policy_ratio, [-1, 1]), [1, self._num_actions])

      # the tabular feature vector
      a_vec = tf.one_hot(nu_indices, self._dimension) - tf.reduce_sum(
          self._gamma *
          tf.expand_dims(next_target_probs * policy_ratios, axis=-1) *
          tf.one_hot(next_nu_indices, self._dimension),
          axis=1)

      # linear term w.r.t. reward
      #b_vec_1 = tf.stack([
      #    tf.reduce_sum(
      #        tf.reshape(
      #            (gamma_data_weights[:, itr] /
      #             tf.reduce_sum(gamma_data_weights[:, itr])) * self._reward_fn(env_step), #/
      #            #tf.cast(state_action_count, tf.float32),
      #            [-1, 1]) * a_vec,
      #        axis=0) for itr in range(self._num_limits)
      #],
      #                   axis=0)
      # quadratic term of feature
      # Get weighted outer product by using einsum to save computing resource!
      #a_mat = tf.stack([
      #    tf.einsum(
      #        'ai, a, aj -> ij', a_vec,
      #        #1.0 / tf.cast(state_action_count, tf.float32),
      #        gamma_data_weights[:, itr] /
      #        tf.reduce_sum(gamma_data_weights[:, itr]),
      #        a_vec)
      #    for itr in range(self._num_limits)
      #],
      #                 axis=0)

      td_mat = tf.einsum('ai, a, aj -> ij',
                         tf.one_hot(nu_indices, self._dimension),
                         1.0 / tf.cast(state_action_count, tf.float32), a_vec)

      weighted_rewards = policy_ratio * self._reward_fn(env_step)

      bias = tf.reduce_sum(
          tf.one_hot(nu_indices, self._dimension) *
          tf.reshape(weighted_rewards, [-1, 1]) * 1.0 /
          tf.cast(state_action_count, tf.float32)[:, None],
          axis=0)

      # Initialize
      self._nu = np.ones_like(self._nu) * bias[:, None]
      self._nu2 = np.ones_like(self._nu2) * bias[:, None]

      self._a_vec = a_vec
      self._td_mat = td_mat
      self._bias = bias
      self._weighted_rewards = weighted_rewards
      self._state_action_count = state_action_count
      self._nu_indices = nu_indices
      self._initial_nu_indices = initial_nu_indices
      self._initial_target_probs = initial_target_probs
      self._gamma_sample = gamma_sample
      self._gamma_sample = tf.ones_like(gamma_sample)

    saddle_bellman_residuals = (
        tf.matmul(self._a_vec, self._nu) - self._weighted_rewards[:, None])
    saddle_bellman_residuals *= -1 * self._algae_alpha_sign
    saddle_zetas = tf.gather(self._zeta, self._nu_indices)
    saddle_initial_nu_values = tf.reduce_sum(  # Average over actions.
        self._initial_target_probs[:, :, None] *
        tf.gather(self._nu, self._initial_nu_indices),
        axis=1)
    saddle_init_nu_loss = ((1 - self._gamma) * saddle_initial_nu_values *
                           self._algae_alpha_sign)

    saddle_bellman_residuals2 = (
        tf.matmul(self._a_vec, self._nu2) - self._weighted_rewards[:, None])
    saddle_bellman_residuals2 *= 1 * self._algae_alpha_sign
    saddle_zetas2 = tf.gather(self._zeta2, self._nu_indices)
    saddle_initial_nu_values2 = tf.reduce_sum(  # Average over actions.
        self._initial_target_probs[:, :, None] *
        tf.gather(self._nu2, self._initial_nu_indices),
        axis=1)
    saddle_init_nu_loss2 = ((1 - self._gamma) * saddle_initial_nu_values2 * -1 *
                            self._algae_alpha_sign)

    saddle_loss = 0.5 * (
        saddle_init_nu_loss + saddle_bellman_residuals * saddle_zetas +
        -tf.math.abs(self._algae_alpha) * 0.5 * tf.square(saddle_zetas) +
        -saddle_init_nu_loss2 + -saddle_bellman_residuals2 * saddle_zetas2 +
        tf.math.abs(self._algae_alpha) * 0.5 * tf.square(saddle_zetas2))
    # Binary search to find best alpha.
    left = tf.constant([-8., -8.])
    right = tf.constant([32., 32.])
    for _ in range(16):
      mid = 0.5 * (left + right)
      self._alpha.assign(mid)
      weights, log_weights = self._get_weights(saddle_loss *
                                               self._gamma_sample[:, None])

      divergence = self._compute_divergence(weights, log_weights)
      divergence_violation = divergence - self._two_sided_limit
      left = tf.where(divergence_violation > 0., mid, left)
      right = tf.where(divergence_violation > 0., right, mid)
    self._alpha.assign(0.5 * (left + right))
    weights, log_weights = self._get_weights(saddle_loss *
                                             self._gamma_sample[:, None])

    gamma_data_weights = tf.stop_gradient(weights * self._gamma_sample[:, None])
    #print(tf.concat([gamma_data_weights, saddle_loss], axis=-1))
    avg_saddle_loss = (
        tf.reduce_sum(gamma_data_weights * saddle_loss, axis=0) /
        tf.reduce_sum(gamma_data_weights, axis=0))

    weighted_state_action_count = tf.reduce_sum(
        tf.one_hot(self._nu_indices, self._dimension)[:, :, None] *
        weights[:, None, :],
        axis=0)
    weighted_state_action_count = tf.gather(weighted_state_action_count,
                                            self._nu_indices)
    my_td_mat = tf.einsum(
        'ai, ab, ab, aj -> bij',
        tf.one_hot(self._nu_indices, self._dimension),
        #1.0 / tf.cast(self._state_action_count, tf.float32),
        1.0 / weighted_state_action_count,
        weights,
        self._a_vec)
    my_bias = tf.reduce_sum(
        tf.transpose(weights)[:, :, None] *
        tf.one_hot(self._nu_indices, self._dimension)[None, :, :] *
        tf.reshape(self._weighted_rewards, [1, -1, 1]) *
        #1.0 / tf.cast(self._state_action_count, tf.float32)[None, :, None],
        1.0 / tf.transpose(weighted_state_action_count)[:, :, None],
        axis=1)

    #print('hello', saddle_initial_nu_values[:1], saddle_zetas[:3],
    #      self._nu[:2], my_bias[:, :2], saddle_loss[:4])

    with tf.GradientTape(
        watch_accessed_variables=False, persistent=True) as tape:
      tape.watch([self._nu, self._nu2, self._alpha])
      bellman_residuals = tf.matmul(
          my_td_mat,
          tf.transpose(self._nu)[:, :, None]) - my_bias[:, :, None]
      bellman_residuals = tf.transpose(tf.squeeze(bellman_residuals, -1))
      bellman_residuals = tf.gather(bellman_residuals, self._nu_indices)
      initial_nu_values = tf.reduce_sum(  # Average over actions.
          self._initial_target_probs[:, :, None] *
          tf.gather(self._nu, self._initial_nu_indices),
          axis=1)

      bellman_residuals *= self._algae_alpha_sign

      init_nu_loss = ((1 - self._gamma) * initial_nu_values *
                      self._algae_alpha_sign)

      nu_loss = (
          tf.math.square(bellman_residuals) / 2.0 +
          tf.math.abs(self._algae_alpha) * init_nu_loss)

      loss = (
          gamma_data_weights * nu_loss /
          tf.reduce_sum(gamma_data_weights, axis=0, keepdims=True))

      bellman_residuals2 = tf.matmul(
          my_td_mat,
          tf.transpose(self._nu2)[:, :, None]) - my_bias[:, :, None]
      bellman_residuals2 = tf.transpose(tf.squeeze(bellman_residuals2, -1))
      bellman_residuals2 = tf.gather(bellman_residuals2, self._nu_indices)
      initial_nu_values2 = tf.reduce_sum(  # Average over actions.
          self._initial_target_probs[:, :, None] *
          tf.gather(self._nu2, self._initial_nu_indices),
          axis=1)

      bellman_residuals2 *= -1 * self._algae_alpha_sign

      init_nu_loss2 = ((1 - self._gamma) * initial_nu_values2 * -1 *
                       self._algae_alpha_sign)

      nu_loss2 = (
          tf.math.square(bellman_residuals2) / 2.0 +
          tf.math.abs(self._algae_alpha) * init_nu_loss2)

      loss2 = (
          gamma_data_weights * nu_loss2 /
          tf.reduce_sum(gamma_data_weights, axis=0, keepdims=True))

      divergence = self._compute_divergence(weights, log_weights)
      divergence_violation = divergence - self._two_sided_limit

      alpha_loss = (-tf.exp(self._alpha) *
                    tf.stop_gradient(divergence_violation))

      extra_loss = tf.reduce_sum(tf.math.square(self._nu[-1, :]))
      extra_loss2 = tf.reduce_sum(tf.math.square(self._nu2[-1, :]))
      nu_grad = tape.gradient(loss + extra_loss, [self._nu])[0]
      nu_grad2 = tape.gradient(loss2 + extra_loss2, [self._nu2])[0]
    avg_loss = tf.reduce_sum(
        0.5 * (loss - loss2) / tf.math.abs(self._algae_alpha), axis=0)
    nu_jacob = tape.jacobian(nu_grad, [self._nu])[0]
    nu_hess = tf.stack([nu_jacob[:, i, :, i] for i in range(self._num_limits)],
                       axis=0)

    nu_jacob2 = tape.jacobian(nu_grad2, [self._nu2])[0]
    nu_hess2 = tf.stack(
        [nu_jacob2[:, i, :, i] for i in range(self._num_limits)], axis=0)

    for idx, div in enumerate(divergence):
      tf.summary.scalar('divergence%d' % idx, div)

    #alpha_grads = tape.gradient(alpha_loss, [self._alpha])
    #alpha_grad_op = self._alpha_optimizer.apply_gradients(
    #    zip(alpha_grads, [self._alpha]))
    #self._alpha.assign(tf.minimum(8., tf.maximum(-8., self._alpha)))

    #print(self._alpha, tf.concat([weights, nu_loss], -1))
    #regularizer = 0.1
    nu_transformed = tf.transpose(
        tf.squeeze(
            tf.linalg.solve(nu_hess + regularizer * tf.eye(self._dimension),
                            tf.expand_dims(-tf.transpose(nu_grad), axis=-1))))
    self._nu = self._nu + 0.1 * nu_transformed
    nu_transformed2 = tf.transpose(
        tf.squeeze(
            tf.linalg.solve(nu_hess2 + regularizer * tf.eye(self._dimension),
                            tf.expand_dims(-tf.transpose(nu_grad2), axis=-1))))
    self._nu2 = self._nu2 + 0.1 * nu_transformed2

    print(avg_loss * self._algae_alpha_sign,
          avg_saddle_loss * self._algae_alpha_sign, self._nu[:2], divergence)
    #print(init_nu_loss[:8], init_nu_loss[-8:])
    #print(bellman_residuals[:8])
    #print(self._nu[:3], self._zeta[:3])

    zetas = tf.matmul(my_td_mat,
                      tf.transpose(self._nu)[:, :, None]) - my_bias[:, :, None]
    zetas = tf.transpose(tf.squeeze(zetas, -1))
    zetas *= -self._algae_alpha_sign
    zetas /= tf.math.abs(self._algae_alpha)
    self._zeta = self._zeta + 0.1 * (zetas - self._zeta)

    zetas2 = tf.matmul(my_td_mat,
                       tf.transpose(self._nu2)[:, :, None]) - my_bias[:, :,
                                                                      None]
    zetas2 = tf.transpose(tf.squeeze(zetas2, -1))
    zetas2 *= 1 * self._algae_alpha_sign
    zetas2 /= tf.math.abs(self._algae_alpha)
    self._zeta2 = self._zeta2 + 0.1 * (zetas2 - self._zeta2)

    #self._zeta = (
    #    tf.einsum('ij,ja-> ia', self._td_mat, self._nu) -
    #    tf.transpose(my_bias))
    #self._zeta *= -tf.reshape(self._algae_alpha_sign, [1, self._num_limits])
    #self._zeta /= tf.math.abs(self._algae_alpha)
    return [
        avg_saddle_loss * self._algae_alpha_sign,
        avg_loss * self._algae_alpha_sign, divergence
    ]
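
A standalone sketch of the elementwise bisection pattern used in the alpha
search above; the criterion here is a hypothetical stand-in, not the
divergence constraint from the source.

```python
import tensorflow as tf

f = lambda x: x**2 - 2.0             # Hypothetical criterion; root at sqrt(2).
left = tf.constant([0.0])
right = tf.constant([4.0])
for _ in range(16):
  mid = 0.5 * (left + right)
  violation = f(mid)
  # Keep the half-interval that still contains the sign change.
  left = tf.where(violation > 0., left, mid)
  right = tf.where(violation > 0., mid, right)
root = 0.5 * (left + right)          # ~1.4142
```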
Example #26
0
def per_replica_to_tensor(value):
    return tf.nest.map_structure(
        lambda per_replica: tf.stack(per_replica.values, axis=0), value)
Example #27
0
def barrier_price(*,
                  volatilities: types.RealTensor,
                  strikes: types.RealTensor,
                  expiries: types.RealTensor,
                  spots: types.RealTensor,
                  barriers: types.RealTensor,
                  rebates: types.RealTensor = None,
                  discount_rates: types.RealTensor = None,
                  dividend_rates: types.RealTensor = None,
                  is_barrier_down: types.BoolTensor = None,
                  is_knock_out: types.BoolTensor = None,
                  is_call_options: types.BoolTensor = None,
                  dtype: tf.DType = None,
                  name: str = None) -> types.RealTensor:
  """Prices barrier options in a Black-Scholes Model.

  Computes the prices of options with a single barrier in a Black-Scholes world,
  as described in Ref. [1]. Note that the barrier is applied continuously.

  #### Example

  This example is taken from Ref. [2], Page 154.

  ```python
  import numpy as np
  import tf_quant_finance as tff

  dtype = np.float32
  discount_rates = np.array([.08, .08], dtype=dtype)
  dividend_rates = np.array([.04, .04], dtype=dtype)
  spots = np.array([100., 100.], dtype=dtype)
  strikes = np.array([90., 90.], dtype=dtype)
  barriers = np.array([95., 95.], dtype=dtype)
  rebates = np.array([3., 3.], dtype=dtype)
  volatilities = np.array([.25, .25], dtype=dtype)
  expiries = np.array([.5, .5], dtype=dtype)
  is_barrier_down = np.array([True, False])
  is_knock_out = np.array([False, False])
  is_call_options = np.array([True, True])

  price = tff.black_scholes.barrier_price(
      volatilities=volatilities, strikes=strikes, expiries=expiries,
      spots=spots, barriers=barriers, rebates=rebates,
      discount_rates=discount_rates, dividend_rates=dividend_rates,
      is_barrier_down=is_barrier_down, is_knock_out=is_knock_out,
      is_call_options=is_call_options)

  # Expected output
  #  `Tensor` with values [9.024, 7.7627]
  ```

  #### References

  [1]: Lee Clewlow, Javier Llanos, Chris Strickland, Caracas Venezuela
    Pricing Exotic Options in a Black-Scholes World, 1994
    https://warwick.ac.uk/fac/soc/wbs/subjects/finance/research/wpaperseries/1994/94-54.pdf
  [2]: Espen Gaarder Haug, The Complete Guide to Option Pricing Formulas,
    2nd Edition, 1997

  Args:
    volatilities: Real `Tensor` of any shape and dtype. The volatilities to
      expiry of the options to price.
    strikes: A real `Tensor` of the same dtype and compatible shape as
      `volatilities`. The strikes of the options to be priced.
    expiries: A real `Tensor` of same dtype and compatible shape as
      `volatilities`. The expiry of each option. The units should be such that
      `expiry * volatility**2` is dimensionless.
    spots: A real `Tensor` of any shape that broadcasts to the shape of the
      `volatilities`. The current spot price of the underlying.
    barriers: A real `Tensor` of same dtype as the `volatilities` and of the
      shape that broadcasts with `volatilities`. The barriers of each option.
    rebates: A real `Tensor` of same dtype as the `volatilities` and of the
      shape that broadcasts with `volatilities`. For knockouts, this is a
      fixed cash payout in case the barrier is breached. For knockins, this is a
      fixed cash payout in case the barrier level is not breached. In the former
      case, the rebate is paid immediately on breach whereas in the latter, the
      rebate is paid at the expiry of the option.
      Default value: `None` which maps to no rebates.
    discount_rates: A real `Tensor` of same dtype as the
      `volatilities` and of the shape that broadcasts with `volatilities`.
      Discount rates, or risk free rates.
      Default value: `None`, equivalent to discount_rate = 0.
    dividend_rates: A real `Tensor` of same dtype as the
      `volatilities` and of the shape that broadcasts with `volatilities`. A
      continuous dividend rate paid by the underlier. If `None`, then
      defaults to zero dividends.
      Default value: `None`, equivalent to zero dividends.
    is_barrier_down: A boolean `Tensor` of a shape that broadcasts with
      `volatilities`. True if the barrier is below the asset price at
      expiration.
      Default value: `True`.
    is_knock_out: A boolean `Tensor` of a shape that broadcasts with
      `volatilities`. True if the option is a knock-out, else False.
      Default value: `True`.
    is_call_options: A boolean `Tensor` of a shape that broadcasts with
      `volatilities`. True if the option is a call, else False.
      Default value: `True`.
    dtype: Optional `tf.DType`. If supplied, the dtype to be used for conversion
      of any supplied non-`Tensor` arguments to `Tensor`.
      Default value: `None` which maps to the default dtype inferred by
      TensorFlow.
    name: str. The name for the ops created by this function.
      Default value: `None` which is mapped to the default name `barrier_price`.
  Returns:
    option_prices: A `Tensor` of the same shape as `spots`. The approximate
    price of the barrier options under the Black-Scholes model.
  """
  # The computation is done as in Ref [2], where each integral is split into
  # two matrices. The first matrix contains the algebraic terms and the second
  # matrix contains the probability distribution terms. Masks are used to
  # select the appropriate terms for each integral. A dot product of each row
  # in the matrices, combined with the masks, then yields the price of each
  # barrier option.
  with tf.name_scope(name or 'barrier_price'):
    spots = tf.convert_to_tensor(spots, dtype=dtype, name='spots')
    dtype = spots.dtype
    strikes = tf.convert_to_tensor(strikes, dtype=dtype, name='strikes')
    volatilities = tf.convert_to_tensor(
        volatilities, dtype=dtype, name='volatilities')
    expiries = tf.convert_to_tensor(expiries, dtype=dtype, name='expiries')
    barriers = tf.convert_to_tensor(barriers, dtype=dtype, name='barriers')
    if rebates is not None:
      rebates = tf.convert_to_tensor(rebates, dtype=dtype, name='rebates')
    else:
      rebates = tf.zeros_like(spots, dtype=dtype, name='rebates')

    # Convert all to tensor and enforce float dtype where required
    if discount_rates is not None:
      discount_rates = tf.convert_to_tensor(
          discount_rates, dtype=dtype, name='discount_rates')
    else:
      discount_rates = tf.convert_to_tensor(
          0.0, dtype=dtype, name='discount_rates')

    if dividend_rates is not None:
      dividend_rates = tf.convert_to_tensor(
          dividend_rates, dtype=dtype, name='dividend_rates')
    else:
      dividend_rates = tf.convert_to_tensor(
          0.0, dtype=dtype, name='dividend_rates')

    if is_barrier_down is None:
      is_barrier_down = tf.constant(1, name='is_barrier_down')
    else:
      is_barrier_down = tf.convert_to_tensor(is_barrier_down, dtype=tf.bool,
                                             name='is_barrier_down')
      is_barrier_down = tf.where(is_barrier_down, 1, 0)
    if is_knock_out is None:
      is_knock_out = tf.constant(1, name='is_knock_out')
    else:
      is_knock_out = tf.convert_to_tensor(is_knock_out, dtype=tf.bool,
                                          name='is_knock_out')
      is_knock_out = tf.where(is_knock_out, 1, 0)
    if is_call_options is None:
      is_call_options = tf.constant(1, name='is_call_options')
    else:
      is_call_options = tf.convert_to_tensor(is_call_options, dtype=tf.bool,
                                             name='is_call_options')
      is_call_options = tf.where(is_call_options, 1, 0)

    # Indices which range from 0-7 are used to select the appropriate
    # mask for each barrier
    indices = tf.bitwise.left_shift(
        is_barrier_down, 2) + tf.bitwise.left_shift(
            is_knock_out, 1) + is_call_options

    # Masks select the appropriate terms for integral approximations
    # Integrals are separated by algebraic terms and probability
    # distribution terms. This gives 12 different terms per matrix
    # (6 integrals, 2 terms each)
    # shape = [8, 12]
    mask_matrix_greater_strike = tf.constant([
        [1, 1, -1, -1, 0, 0, 1, 1, 1, 1, 0, 0],  # up and in put
        [1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],  # up and in call
        [0, 0, 1, 1, 0, 0, -1, -1, 0, 0, 1, 1],  # up and out put
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1],  # up and out call
        [0, 0, 1, 1, -1, -1, 1, 1, 0, 0, 1, 1],  # down and in put
        [0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0],  # down and in call
        [1, 1, -1, -1, 1, 1, -1, -1, 0, 0, 1, 1],  # down and out put
        [1, 1, 0, 0, -1, -1, 0, 0, 0, 0, 1, 1]])  # down and out call

    mask_matrix_lower_strike = tf.constant([
        [0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0],  # up and in put
        [0, 0, 1, 1, -1, -1, 1, 1, 1, 1, 0, 0],  # up and in call
        [1, 1, 0, 0, -1, -1, 0, 0, 0, 0, 1, 1],  # up and out put
        [1, 1, -1, -1, 1, 1, -1, -1, 0, 0, 1, 1],  # up and out call
        [1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],  # down and in put
        [1, 1, -1, -1, 0, 0, 1, 1, 1, 1, 0, 0],  # down and in call
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1],  # down and out put
        [0, 0, 1, 1, 0, 0, -1, -1, 0, 0, 1, 1]])  # down and out call

    # Create masks
    # Masks are shape [strikes.shape, 12]
    masks_lower = tf.gather(mask_matrix_lower_strike, indices, axis=0)
    masks_greater = tf.gather(mask_matrix_greater_strike, indices, axis=0)
    strikes_greater = tf.expand_dims(strikes > barriers, axis=-1)
    masks = tf.where(strikes_greater, masks_greater, masks_lower)
    masks = tf.cast(masks, dtype=dtype)
    one = tf.constant(1, dtype=dtype)
    call_or_put = tf.cast(tf.where(tf.equal(is_call_options, 0), -one, one),
                          dtype=dtype)
    below_or_above = tf.cast(tf.where(tf.equal(is_barrier_down, 0), -one, one),
                             dtype=dtype)

    # Calculate params for integrals
    sqrt_var = volatilities * tf.math.sqrt(expiries)
    mu = (discount_rates - dividend_rates) - ((volatilities**2) / 2)
    lamda = 1 + (mu / (volatilities**2))
    x = (tf.math.log(spots / strikes) / (sqrt_var)) + (lamda * sqrt_var)
    x1 = (tf.math.log(spots / barriers) / (sqrt_var)) + (lamda * sqrt_var)
    y = (tf.math.log((barriers**2) / (spots * strikes)) / (
        sqrt_var)) + (lamda * sqrt_var)
    y1 = (tf.math.log(barriers / spots) / (sqrt_var)) + (lamda * sqrt_var)
    b = ((mu**2) + (2 * (volatilities**2) * discount_rates)) / (volatilities**2)
    z = (tf.math.log(barriers / spots) / (sqrt_var)) + (b * sqrt_var)
    a = mu / (volatilities**2)

    # Other params used for integrals
    discount_factors = tf.math.exp(
        -discount_rates * expiries, name='discount_factors')
    barriers_ratio = tf.math.divide(barriers, spots, name='barriers_ratio')
    spots_term = call_or_put * spots * tf.math.exp(-dividend_rates * expiries)
    strikes_term = call_or_put * strikes * discount_factors

    # rank is used to stack elements and reduce_sum
    strike_rank = strikes.shape.rank

    # Construct the matrix of first and second algebraic terms for each
    # integral; shape = [strikes.shape, 12]
    terms_mat = tf.stack(
        (spots_term, -strikes_term,
         spots_term, -strikes_term,
         spots_term * (barriers_ratio**(2 * lamda)),
         -strikes_term * (barriers_ratio**((2 * lamda) - 2)),
         spots_term * (barriers_ratio**(2 * lamda)),
         -strikes_term * (barriers_ratio**((2 * lamda) - 2)),
         rebates * discount_factors,
         -rebates * discount_factors * (  # pylint: disable=invalid-unary-operand-type
             barriers_ratio**((2 * lamda) - 2)),
         rebates * (barriers_ratio**(a + b)),
         rebates * (barriers_ratio**(a - b))),
        name='term_matrix', axis=strike_rank)

    # Construct the matrix of first and second normal CDF arguments for each
    # integral; shape = [strikes.shape, 12]
    cdf_mat = tf.stack(
        (call_or_put * x,
         call_or_put * (x - sqrt_var),
         call_or_put * x1,
         call_or_put * (x1 - sqrt_var),
         below_or_above * y,
         below_or_above * (y - sqrt_var),
         below_or_above * y1,
         below_or_above * (y1 - sqrt_var),
         below_or_above * (x1 - sqrt_var),
         below_or_above * (y1 - sqrt_var),
         below_or_above * z,
         below_or_above * (z - (2 * b * sqrt_var))),
        name='cdf_matrix', axis=strike_rank)
    cdf_mat = _ncdf(cdf_mat)
    # Calculate and return the price of each option.
    return tf.reduce_sum(masks * terms_mat * cdf_mat, axis=strike_rank)
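The matrix of arguments is mapped through `_ncdf` before the final reduction. A
minimal stand-in for that helper, under the assumption that `_ncdf` is the
standard normal CDF expressed via the error function:

```python
import tensorflow as tf

# Standard normal CDF: Phi(x) = 0.5 * (1 + erf(x / sqrt(2))).
def _ncdf(x):
    return (tf.math.erf(x / tf.math.sqrt(tf.constant(2., x.dtype))) + 1.) / 2.

print(_ncdf(tf.constant([0., 1.96])).numpy())  # ~[0.5, 0.975]
```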
Example #28
    def _loop_tree_doubling(self, step_size, momentum_state_memory,
                            current_step_meta_info, iter_, initial_step_state,
                            initial_step_metastate, seed):
        """Main loop for tree doubling."""
        with tf.name_scope('loop_tree_doubling'):
            (direction_seed, subtree_seed, acceptance_seed,
             next_seed) = samplers.split_seed(seed, n=4)
            batch_shape = ps.shape(current_step_meta_info.init_energy)
            direction = tf.cast(samplers.uniform(shape=batch_shape,
                                                 minval=0,
                                                 maxval=2,
                                                 dtype=tf.int32,
                                                 seed=direction_seed),
                                dtype=tf.bool)

            tree_start_states = tf.nest.map_structure(
                lambda v: bu.where_left_justified_mask(direction, v[1], v[0]),
                initial_step_state)

            directions_expanded = [
                bu.left_justified_expand_dims_like(direction, state)
                for state in tree_start_states.state
            ]

            integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
                self.target_log_prob_fn,
                step_sizes=[
                    tf.where(d, ss, -ss)
                    for d, ss in zip(directions_expanded, step_size)
                ],
                num_steps=self.unrolled_leapfrog_steps)

            [
                candidate_tree_state, tree_final_states, final_not_divergence,
                continue_tree_final, energy_diff_tree_sum,
                momentum_subtree_cumsum, leapfrogs_taken
            ] = self._build_sub_tree(
                directions_expanded,
                integrator,
                current_step_meta_info,
                # num_steps_at_this_depth = 2**iter_ = 1 << iter_
                tf.bitwise.left_shift(1, iter_),
                tree_start_states,
                initial_step_metastate.continue_tree,
                initial_step_metastate.not_divergence,
                momentum_state_memory,
                seed=subtree_seed)

            last_candidate_state = initial_step_metastate.candidate_state

            energy_diff_sum = (energy_diff_tree_sum +
                               initial_step_metastate.energy_diff_sum)
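            # Illustrative note (not from the original source): when
            # MULTINOMIAL_SAMPLE is set, subtree weights are log-probabilities
            # combined via log_add_exp; otherwise they are integer leaf counts
            # (TREE_COUNT_DTYPE) combined by addition, the slice-sampling variant.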
            if MULTINOMIAL_SAMPLE:
                tree_weight = tf.where(
                    continue_tree_final, candidate_tree_state.weight,
                    tf.constant(-np.inf,
                                dtype=candidate_tree_state.weight.dtype))
                weight_sum = log_add_exp(tree_weight,
                                         last_candidate_state.weight)
                log_accept_thresh = tree_weight - last_candidate_state.weight
            else:
                tree_weight = tf.where(continue_tree_final,
                                       candidate_tree_state.weight,
                                       tf.zeros([], dtype=TREE_COUNT_DTYPE))
                weight_sum = tree_weight + last_candidate_state.weight
                log_accept_thresh = tf.math.log(
                    tf.cast(tree_weight, tf.float32) /
                    tf.cast(last_candidate_state.weight, tf.float32))
            log_accept_thresh = tf.where(tf.math.is_nan(log_accept_thresh),
                                         tf.zeros([], log_accept_thresh.dtype),
                                         log_accept_thresh)
            u = tf.math.log1p(-samplers.uniform(shape=batch_shape,
                                                dtype=log_accept_thresh.dtype,
                                                seed=acceptance_seed))
            is_sample_accepted = u <= log_accept_thresh

            choose_new_state = is_sample_accepted & continue_tree_final

            new_candidate_state = TreeDoublingStateCandidate(
                state=[
                    bu.where_left_justified_mask(choose_new_state, s0, s1)
                    for s0, s1 in zip(candidate_tree_state.state,
                                      last_candidate_state.state)
                ],
                target=bu.where_left_justified_mask(
                    choose_new_state, candidate_tree_state.target,
                    last_candidate_state.target),
                target_grad_parts=[
                    bu.where_left_justified_mask(choose_new_state, grad0,
                                                 grad1)
                    for grad0, grad1 in zip(
                        candidate_tree_state.target_grad_parts,
                        last_candidate_state.target_grad_parts)
                ],
                energy=bu.where_left_justified_mask(
                    choose_new_state, candidate_tree_state.energy,
                    last_candidate_state.energy),
                weight=weight_sum)

            for new_candidate_state_temp, old_candidate_state_temp in zip(
                    new_candidate_state.state, last_candidate_state.state):
                tensorshape_util.set_shape(new_candidate_state_temp,
                                           old_candidate_state_temp.shape)

            for new_candidate_grad_temp, old_candidate_grad_temp in zip(
                    new_candidate_state.target_grad_parts,
                    last_candidate_state.target_grad_parts):
                tensorshape_util.set_shape(new_candidate_grad_temp,
                                           old_candidate_grad_temp.shape)

            # Update the left/right endpoints of the trajectory and check for a
            # trajectory-level U-turn.
            tree_otherend_states = tf.nest.map_structure(
                lambda v: bu.where_left_justified_mask(direction, v[0], v[1]),
                initial_step_state)

            new_step_state = tf.nest.pack_sequence_as(
                initial_step_state,
                [
                    tf.stack(
                        [  # pylint: disable=g-complex-comprehension
                            bu.where_left_justified_mask(
                                direction, right, left),
                            bu.where_left_justified_mask(
                                direction, left, right),
                        ],
                        axis=0) for left, right in zip(
                            tf.nest.flatten(tree_final_states),
                            tf.nest.flatten(tree_otherend_states))
                ])

            momentum_tree_cumsum = []
            for p0, p1 in zip(initial_step_metastate.momentum_sum,
                              momentum_subtree_cumsum):
                momentum_part_temp = p0 + p1
                tensorshape_util.set_shape(momentum_part_temp, p0.shape)
                momentum_tree_cumsum.append(momentum_part_temp)

            for new_state_temp, old_state_temp in zip(
                    tf.nest.flatten(new_step_state),
                    tf.nest.flatten(initial_step_state)):
                tensorshape_util.set_shape(new_state_temp,
                                           old_state_temp.shape)

            if GENERALIZED_UTURN:
                state_diff = momentum_tree_cumsum
            else:
                state_diff = [s[1] - s[0] for s in new_step_state.state]

            no_u_turns_trajectory = has_not_u_turn(
                state_diff, [m[0] for m in new_step_state.momentum],
                [m[1] for m in new_step_state.momentum],
                log_prob_rank=ps.rank_from_shape(batch_shape),
                shard_axis_names=self.experimental_shard_axis_names)

            new_step_metastate = TreeDoublingMetaState(
                candidate_state=new_candidate_state,
                is_accepted=choose_new_state
                | initial_step_metastate.is_accepted,
                momentum_sum=momentum_tree_cumsum,
                energy_diff_sum=energy_diff_sum,
                continue_tree=continue_tree_final & no_u_turns_trajectory,
                not_divergence=final_not_divergence,
                leapfrog_count=(initial_step_metastate.leapfrog_count +
                                leapfrogs_taken))

            return iter_ + 1, next_seed, new_step_state, new_step_metastate
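The accept/reject step above compares a log-uniform draw against
`log_accept_thresh`. A minimal sketch of that step in isolation, with
`tf.random.uniform` standing in for the stateless `samplers.uniform` call used
by the kernel:

```python
import tensorflow as tf

# log1p(-U) with U ~ Uniform[0, 1) is the log of a Uniform(0, 1] draw, so
# `u <= log_accept_thresh` accepts with probability min(1, exp(log_accept_thresh)).
log_accept_thresh = tf.math.log(tf.constant([0.25, 0.9]))
u = tf.math.log1p(-tf.random.uniform([2], dtype=log_accept_thresh.dtype))
is_sample_accepted = u <= log_accept_thresh
```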
Example #29
def observation_jacobian_fn_3dim(x):
    return tf.reshape(
        tf.stack([1., 0., 1., 1., x[..., 1], x[..., 0]], axis=-1),
        [3, 2])
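A usage sketch (assuming a 2-dimensional state vector `x`): the function packs
two constant observation rows and one state-dependent row into a 3x2 Jacobian.

```python
import tensorflow as tf

x = tf.constant([2., 3.])
jac = observation_jacobian_fn_3dim(x)
# [[1., 0.],
#  [1., 1.],
#  [3., 2.]]
```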
Example #30
def convert_sharded_tensor_to_eager_tensor(value, *args, **kwargs):
  del args, kwargs
  # TODO(nareshmodi): Consider a collective op to gather the tensors from the
  # various devices for performance reasons.
  return tf.stack(value.tensors)
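A minimal usage sketch, with a plain Python list standing in for the
`value.tensors` attribute (one tensor per device shard):

```python
import tensorflow as tf

per_device_shards = [tf.constant([1., 2.]), tf.constant([3., 4.])]
stacked = tf.stack(per_device_shards)  # shape [2, 2], one row per shard
```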