def lu_solve(lower_upper, perm, rhs, validate_args=False, name=None):
  """Solves systems of linear equations `A X = RHS` from an LU factorization.

  Note: this function does not verify that the implied matrix is actually
  invertible; this is not checked even when `validate_args=True`.

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`.
    rhs: Matrix-shaped float `Tensor` representing targets for which to solve;
      `A X = RHS`. To handle vector cases, use:
      `lu_solve(..., rhs[..., tf.newaxis])[..., 0]`.
    validate_args: Python `bool` indicating whether arguments should be checked
      for correctness. Invertibility of the implied matrix is never checked.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_solve').

  Returns:
    x: The `X` in `A @ X = RHS`.

  #### Examples

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp

  x = [[[1., 2],
        [3, 4]],
       [[7, 8],
        [3, 4]]]
  inv_x = tfp.math.lu_solve(*tf.linalg.lu(x), rhs=tf.eye(2))
  tf.debugging.assert_near(tf.linalg.inv(x), inv_x)
  # ==> True
  ```
  """
  with tf.name_scope(name or 'lu_solve'):
    lower_upper = tf.convert_to_tensor(
        lower_upper, dtype_hint=tf.float32, name='lower_upper')
    perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm')
    rhs = tf.convert_to_tensor(rhs, dtype_hint=lower_upper.dtype, name='rhs')

    # When there are checks to run, route every input through an identity op
    # that depends on them so the checks execute before the solve.
    checks = _lu_solve_assertions(lower_upper, perm, rhs, validate_args)
    if checks:
      with tf.control_dependencies(checks):
        lower_upper = tf.identity(lower_upper)
        perm = tf.identity(perm)
        rhs = tf.identity(rhs)

    no_batch_dims = (tensorshape_util.rank(rhs.shape) == 2 and
                     tensorshape_util.rank(perm.shape) == 1)
    if no_batch_dims:
      # Statically known: neither `rhs` nor `perm` has batch dimensions, so a
      # plain row gather applies the permutation.
      permuted_rhs = tf.gather(rhs, perm, axis=-2)
    else:
      # At least one of `rhs`/`perm` is batched, or the ranks are not known
      # statically: broadcast both to a common batch shape, flatten the batch
      # and gather rows with explicit (batch index, row index) pairs.
      dynamic_rhs_shape = tf.shape(rhs)
      batch_shape = tf.broadcast_dynamic_shape(
          dynamic_rhs_shape[:-2], tf.shape(perm)[:-1])
      d, m = dynamic_rhs_shape[-2], dynamic_rhs_shape[-1]
      full_rhs_shape = tf.concat([batch_shape, [d, m]], axis=0)

      # Tile out rhs and collapse its batch dimensions.
      flat_rhs = tf.broadcast_to(rhs, full_rhs_shape)
      flat_rhs = tf.reshape(flat_rhs, [-1, d, m])

      # Tile out perm, collapse its batch dimensions and pair every row index
      # with the index of the batch member it belongs to.
      flat_perm = tf.broadcast_to(perm, full_rhs_shape[:-1])
      flat_perm = tf.reshape(flat_perm, [-1, d])
      batch_size = tf.reduce_prod(batch_shape)
      batch_ids = tf.broadcast_to(
          tf.range(batch_size)[:, tf.newaxis], [batch_size, d])
      gather_indices = tf.stack([batch_ids, flat_perm], axis=-1)

      permuted_rhs = tf.gather_nd(flat_rhs, gather_indices)
      permuted_rhs = tf.reshape(permuted_rhs, full_rhs_shape)

    # `L` is the lower triangle of `lower_upper` with its diagonal replaced by
    # ones.
    lower_triangle = tf.linalg.band_part(
        lower_upper, num_lower=-1, num_upper=0)
    unit_diagonal = tf.ones(
        tf.shape(lower_upper)[:-1], dtype=lower_upper.dtype)
    lower = tf.linalg.set_diag(lower_triangle, unit_diagonal)

    # Forward substitution with `L`, then back substitution with `U`.
    forward_solution = tf.linalg.triangular_solve(lower, permuted_rhs)
    return tf.linalg.triangular_solve(
        lower_upper,  # Only upper is accessed.
        forward_solution,
        lower=False)
def _coord_grid_to_mesh_grid(coord_grid):
  """Turns a sequence of per-axis coordinates into a single mesh tensor.

  Args:
    coord_grid: Sequence of coordinate tensors, one per axis.

  Returns:
    For a single axis, that axis' coordinates with a trailing dimension of
    size one appended. Otherwise the "ij"-indexed `tf.meshgrid` of all axes,
    stacked along a new last dimension.
  """
  if len(coord_grid) != 1:
    mesh_components = tf.meshgrid(*coord_grid, indexing="ij")
    return tf.stack(values=mesh_components, axis=-1)
  # Single axis: no meshgrid needed, just add the trailing coordinate axis.
  return tf.expand_dims(coord_grid[0], -1)
def call(self, x, training=False):
  """Quantizes `x` against the layer's centroids.

  The last axis of `x` is reshaped to `self.depth` and split into
  `self.num_heads` segments; each segment is replaced by its nearest centroid.
  Centroids are `self.sums / self.counts`. When `training` is true, centroids
  whose count falls below the restart threshold are re-initialised before the
  lookup, and `self.counts` / `self.sums` are updated afterwards with an
  exponential moving average governed by `self.gamma`.

  Args:
    x: Input tensor; it is reshaped to `(-1, self.depth)`.
    training: Whether to restart centroids and update the running statistics.

  Returns:
    A tuple `(z, c)`: `z` has the shape of `x` and holds the quantized values
    (with a straight-through gradient to `x`); `c` holds the selected centroid
    indices with shape `x.shape[:-1] + [self.num_heads]`.
  """
  flat = tf.reshape(x, shape=(-1, self.depth))
  # Split each input vector into one segment per head and move the heads into
  # the leading axis, so every row is one segment.
  head_segments = tf.split(flat, self.num_heads, axis=1)
  flat = tf.concat(head_segments, axis=0)

  if not training:
    # If not training, just use the centroids as is with no restarts.
    centroids = tf.math.divide_no_nan(
        self.sums, tf.expand_dims(self.counts, 1))
  else:
    # Decide which centroids to keep and which to restart.
    num_rows = flat.shape[0]
    keep_mask = self.counts * self.k > self.restart_threshold * num_rows
    restart_mask = tf.math.logical_not(keep_mask)

    # Restarted centroids take elements from the (shuffled) batch; uniform
    # samples are the fallback when more centroids need restarting than the
    # batch has elements.
    restart_rows = tf.squeeze(tf.where(restart_mask), -1)
    num_from_batch = tf.minimum(tf.shape(restart_rows)[0], flat.shape[0])
    restarted = tf.tensor_scatter_nd_update(
        tf.random.uniform([self.k, self.depth // self.num_heads]),
        tf.expand_dims(restart_rows[:num_from_batch], 1),
        tf.random.shuffle(flat)[:num_from_batch])

    # Kept centroids are the summed vectors divided by their counts.
    centroids = tf.where(
        tf.expand_dims(keep_mask, 1),
        tf.math.divide_no_nan(self.sums, tf.expand_dims(self.counts, 1)),
        restarted)

  # Squared distance between each row and each centroid, expanded as
  # |x|^2 - 2 x.e + |e|^2.
  sq_distances = (
      tf.expand_dims(tf.reduce_sum(flat**2, axis=1), 1)
      - 2 * tf.matmul(flat, tf.transpose(centroids))
      + tf.expand_dims(tf.reduce_sum(centroids**2, axis=1), 0))

  # Nearest centroid for each row.
  nearest = tf.argmin(sq_distances, axis=1)

  # Look up the quantized rows and undo the head/flatten transformation.
  quantized = tf.nn.embedding_lookup(centroids, nearest)
  quantized = tf.concat(tf.split(quantized, self.num_heads, axis=0), axis=1)
  quantized = tf.reshape(quantized, tf.shape(x))
  # Straight-through estimator: forward value is `quantized`, gradient flows
  # to `x`.
  quantized = x + tf.stop_gradient(quantized - x)

  if training:
    # Per-centroid assignment counts and vector sums over the batch.
    assignment = tf.one_hot(indices=nearest, depth=self.k)
    batch_counts = tf.reduce_sum(assignment, axis=0)
    batch_sums = tf.matmul(assignment, flat, transpose_a=True)

    # Exponential moving average update of the running statistics.
    self.counts.assign_sub((1 - self.gamma) * (self.counts - batch_counts))
    self.sums.assign_sub((1 - self.gamma) * (self.sums - batch_sums))

  # Bring the indices back to one column per head.
  nearest = tf.stack(tf.split(nearest, self.num_heads, axis=0), axis=1)
  index_shape = tf.concat([tf.shape(x)[:-1], [self.num_heads]], axis=0)
  nearest = tf.reshape(nearest, index_shape)
  return quantized, nearest
def set_negative_scores(scores, indices):
  """Returns `scores` with one entry per row overwritten by -1.0.

  Row `i` gets its column `indices[i]` replaced. The number of rows comes from
  `bsz`, which is not a parameter; it is read from an enclosing scope that is
  not visible in this block.

  Args:
    scores: Tensor to update; addressed with (row, column) index pairs.
    indices: Column index to overwrite for each row.

  Returns:
    A tensor like `scores` with the addressed entries set to -1.0.
  """
  row_ids = tf.range(bsz, dtype=indices.dtype)
  row_col_pairs = tf.stack([row_ids, indices], axis=1)
  replacement = tf.fill(tf.shape(indices), -1.0)
  return tf.tensor_scatter_nd_update(scores, row_col_pairs, replacement)
def update_confusion_matrix_variables(variables_to_update,
                                      y_true,
                                      y_pred,
                                      thresholds,
                                      top_k=None,
                                      class_id=None,
                                      sample_weight=None,
                                      multi_label=False,
                                      label_weights=None):
  """Returns op to update the given confusion matrix variables.

  For every pair of values in y_true and y_pred:

  true_positive: y_true == True and y_pred > thresholds
  false_negatives: y_true == True and y_pred <= thresholds
  true_negatives: y_true == False and y_pred <= thresholds
  false_positive: y_true == False and y_pred > thresholds

  The results will be weighted and added together. When multiple thresholds are
  provided, we will repeat the same for every threshold.

  For estimation of these metrics over a stream of data, the function creates an
  `update_op` operation that updates the given variables.

  If `sample_weight` is `None`, weights default to 1.
  Use weights of 0 to mask values.

  Args:
    variables_to_update: Dictionary with 'tp', 'fn', 'tn', 'fp' as valid keys
      and corresponding variables to update as values.
    y_true: A `Tensor` whose shape matches `y_pred`. Will be cast to `bool`.
    y_pred: A floating point `Tensor` of arbitrary shape and whose values are in
      the range `[0, 1]`.
    thresholds: A float value, float tensor, python list, or tuple of float
      thresholds in `[0, 1]`, or NEG_INF (used when top_k is set).
    top_k: Optional int, indicates that the positive labels should be limited to
      the top k predictions.
    class_id: Optional int, limits the prediction and labels to the class
      specified by this argument.
    sample_weight: Optional `Tensor` whose rank is either 0, or the same rank as
      `y_true`, and must be broadcastable to `y_true` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `y_true` dimension).
    multi_label: Optional boolean indicating whether multidimensional
      prediction/labels should be treated as multilabel responses, or flattened
      into a single label. When True, the values of `variables_to_update` must
      have a second dimension equal to the number of labels in y_true and
      y_pred, and those tensors must not be RaggedTensors.
    label_weights: (optional) tensor of non-negative weights for multilabel
      data. The weights are applied when calculating TP, FP, FN, and TN without
      explicit multilabel handling (i.e. when the data is to be flattened).

  Returns:
    Update op, or `None` when `variables_to_update` is `None`.

  Raises:
    ValueError: If `y_pred` and `y_true` have mismatched shapes, or if
      `sample_weight` is not `None` and its shape doesn't match `y_pred`, or if
      `variables_to_update` contains invalid keys, or if `label_weights` is
      given together with `multi_label=True`.
  """
  if multi_label and label_weights is not None:
    raise ValueError(
        '`label_weights` for multilabel data should be handled '
        'outside of `update_confusion_matrix_variables` when '
        '`multi_label` is True.')
  if variables_to_update is None:
    return
  # At least one key must name a confusion matrix entry.
  if not any(key for key in variables_to_update
             if key in list(ConfusionMatrix)):
    raise ValueError(
        'Please provide at least one valid confusion matrix '
        'variable to update. Valid variable key options are: "{}". '
        'Received: "{}"'.format(list(ConfusionMatrix),
                                variables_to_update.keys()))

  # All computation is carried out in the dtype of the first variable.
  variable_dtype = list(variables_to_update.values())[0].dtype

  y_true = tf.cast(y_true, dtype=variable_dtype)
  y_pred = tf.cast(y_pred, dtype=variable_dtype)
  thresholds = tf.convert_to_tensor(thresholds, dtype=variable_dtype)
  # Static size of the leading thresholds dimension.
  num_thresholds = thresholds.shape[0]
  if multi_label:
    # True when `thresholds` has rank 1, i.e. a single set of thresholds
    # rather than one set per label.
    one_thresh = tf.equal(tf.cast(1, dtype=tf.int32),
                          tf.rank(thresholds),
                          name='one_set_of_thresholds_cond')
  else:
    [y_pred, y_true
    ], _ = ragged_assert_compatible_and_get_flat_values([y_pred, y_true],
                                                        sample_weight)
    one_thresh = tf.cast(True, dtype=tf.bool)

  # Every key must name a confusion matrix entry.
  invalid_keys = [
      key for key in variables_to_update if key not in list(ConfusionMatrix)
  ]
  if invalid_keys:
    raise ValueError(
        'Invalid keys: {}. Valid variable key options are: "{}"'.format(
            invalid_keys, list(ConfusionMatrix)))

  # Range-check the predictions before they are used any further.
  with tf.control_dependencies([
      tf.compat.v1.assert_greater_equal(
          y_pred,
          tf.cast(0.0, dtype=y_pred.dtype),
          message='predictions must be >= 0'),
      tf.compat.v1.assert_less_equal(y_pred,
                                     tf.cast(1.0, dtype=y_pred.dtype),
                                     message='predictions must be <= 1')
  ]):
    if sample_weight is None:
      y_pred, y_true = losses_utils.squeeze_or_expand_dimensions(
          y_pred, y_true)
    else:
      sample_weight = tf.cast(sample_weight, dtype=variable_dtype)
      y_pred, y_true, sample_weight = (
          losses_utils.squeeze_or_expand_dimensions(
              y_pred, y_true, sample_weight=sample_weight))
  y_pred.shape.assert_is_compatible_with(y_true.shape)

  if top_k is not None:
    y_pred = _filter_top_k(y_pred, top_k)
  if class_id is not None:
    # Keep only the requested class along the last axis.
    y_true = y_true[..., class_id]
    y_pred = y_pred[..., class_id]

  pred_shape = tf.compat.v1.shape(y_pred)
  num_predictions = pred_shape[0]
  if y_pred.shape.ndims == 1:
    num_labels = 1
  else:
    # Product of all dimensions after the first.
    num_labels = tf.raw_ops.Prod(input=pred_shape[1:], axis=0)
  # How often the thresholds are tiled along the label axis: `num_labels` for
  # a single set of thresholds, otherwise once.
  thresh_label_tile = tf.compat.v1.cond(one_thresh, lambda: num_labels,
                                        lambda: tf.cast(1, dtype=tf.int32))

  # Reshape predictions and labels, adding a dim for thresholding.
  if multi_label:
    predictions_extra_dim = tf.compat.v1.expand_dims(y_pred, 0)
    labels_extra_dim = tf.compat.v1.expand_dims(
        tf.cast(y_true, dtype=tf.bool), 0)
  else:
    # Flatten predictions and labels when not multilabel.
    predictions_extra_dim = tf.reshape(y_pred, [1, -1])
    labels_extra_dim = tf.reshape(tf.cast(y_true, dtype=tf.bool), [1, -1])

  # Tile the thresholds for every prediction.
  if multi_label:
    thresh_pretile_shape = [num_thresholds, 1, -1]
    thresh_tiles = [1, num_predictions, thresh_label_tile]
    data_tiles = [num_thresholds, 1, 1]
  else:
    thresh_pretile_shape = [num_thresholds, -1]
    thresh_tiles = [1, num_predictions * num_labels]
    data_tiles = [num_thresholds, 1]

  thresh_tiled = tf.tile(tf.reshape(thresholds, thresh_pretile_shape),
                         tf.stack(thresh_tiles))

  # Tile the predictions for every threshold.
  preds_tiled = tf.tile(predictions_extra_dim, data_tiles)

  # Compare predictions and threshold.
  pred_is_pos = tf.greater(preds_tiled, thresh_tiled)

  # Tile labels by number of thresholds
  label_is_pos = tf.tile(labels_extra_dim, data_tiles)

  if sample_weight is not None:
    # Broadcast the weights to `y_pred`, then lay them out like the tiled
    # predictions.
    sample_weight = tf.__internal__.ops.broadcast_weights(
        tf.cast(sample_weight, dtype=variable_dtype), y_pred)
    weights_tiled = tf.tile(tf.reshape(sample_weight, thresh_tiles),
                            data_tiles)
  else:
    weights_tiled = None

  if label_weights is not None and not multi_label:
    label_weights = tf.compat.v1.expand_dims(label_weights, 0)
    label_weights = tf.__internal__.ops.broadcast_weights(
        label_weights, y_pred)
    label_weights_tiled = tf.tile(tf.reshape(label_weights, thresh_tiles),
                                  data_tiles)
    # Label weights combine multiplicatively with any sample weights.
    if weights_tiled is None:
      weights_tiled = label_weights_tiled
    else:
      weights_tiled = tf.multiply(weights_tiled, label_weights_tiled)

  update_ops = []

  def weighted_assign_add(label, pred, weights, var):
    # Adds to `var` the (optionally weighted) count of positions where both
    # `label` and `pred` hold, summed over axis 1.
    label_and_pred = tf.cast(tf.logical_and(label, pred), dtype=var.dtype)
    if weights is not None:
      label_and_pred *= tf.cast(weights, dtype=var.dtype)
    return var.assign_add(tf.reduce_sum(label_and_pred, 1))

  # Maps each confusion matrix entry to its (label condition, prediction
  # condition) pair. The negated masks are only built when an entry that
  # needs them was requested.
  loop_vars = {
      ConfusionMatrix.TRUE_POSITIVES: (label_is_pos, pred_is_pos),
  }
  update_tn = ConfusionMatrix.TRUE_NEGATIVES in variables_to_update
  update_fp = ConfusionMatrix.FALSE_POSITIVES in variables_to_update
  update_fn = ConfusionMatrix.FALSE_NEGATIVES in variables_to_update

  if update_fn or update_tn:
    pred_is_neg = tf.logical_not(pred_is_pos)
    loop_vars[ConfusionMatrix.FALSE_NEGATIVES] = (label_is_pos, pred_is_neg)

  if update_fp or update_tn:
    label_is_neg = tf.logical_not(label_is_pos)
    loop_vars[ConfusionMatrix.FALSE_POSITIVES] = (label_is_neg, pred_is_pos)
    if update_tn:
      loop_vars[ConfusionMatrix.TRUE_NEGATIVES] = (label_is_neg, pred_is_neg)

  # Only entries the caller actually supplied a variable for are updated.
  for matrix_cond, (label, pred) in loop_vars.items():
    if matrix_cond in variables_to_update:
      update_ops.append(
          weighted_assign_add(label, pred, weights_tiled,
                              variables_to_update[matrix_cond]))

  return tf.group(update_ops)
def call(self, inputs):
  """Builds the Gaussian process random variable evaluated at `inputs`.

  Without conditioning data (both `self.conditional_inputs` and
  `self.conditional_outputs` are `None`), the mean is `self.mean_fn(inputs)`
  and the covariance is `self.covariance_fn(inputs, inputs)`. Otherwise the
  mean and covariance are conditioned on the stored inputs/outputs through a
  Cholesky factorization of the conditional inputs' covariance.

  Args:
    inputs: Tensor of input locations passed to `self.mean_fn` and
      `self.covariance_fn`.

  Returns:
    A transformed-distribution random variable whose dimensions are ordered
    `[batch_size, units]`.
  """
  jitter = tf.keras.backend.epsilon
  unconditioned = (self.conditional_inputs is None and
                   self.conditional_outputs is None)

  if unconditioned:
    covariance_matrix = self.covariance_fn(inputs, inputs)
    # Tile locations so output has shape [units, batch_size]. Covariance will
    # broadcast to [units, batch_size, batch_size], and we perform
    # shape manipulations to get a random variable over [batch_size, units].
    loc = self.mean_fn(inputs)
    loc = tf.tile(loc[tf.newaxis], [self.units] + [1] * len(loc.shape))
  else:
    knn = self.covariance_fn(inputs, inputs)
    knm = self.covariance_fn(inputs, self.conditional_inputs)
    kmm = self.covariance_fn(self.conditional_inputs,
                             self.conditional_inputs)
    # Add epsilon to the diagonal before taking the Cholesky factor.
    kmm = tf.linalg.set_diag(kmm, tf.linalg.diag_part(kmm) + jitter())
    kmm_tril = tf.linalg.cholesky(kmm)
    kmm_tril_operator = tf.linalg.LinearOperatorLowerTriangular(kmm_tril)
    knm_operator = tf.linalg.LinearOperatorFullMatrix(knm)

    # TODO(trandustin): Vectorize linear algebra for multiple outputs. For
    # now, we do each separately and stack to obtain a locations Tensor of
    # shape [units, batch_size].
    per_unit_loc = [
        knm_operator.matvec(
            kmm_tril_operator.solvevec(
                kmm_tril_operator.solvevec(
                    unit_outputs - self.mean_fn(self.conditional_inputs)),
                adjoint=True))
        for unit_outputs in tf.unstack(self.conditional_outputs, axis=-1)
    ]
    loc = tf.stack(per_unit_loc) + self.mean_fn(inputs)[tf.newaxis]

    covariance_matrix = knn
    covariance_matrix -= knm_operator.matmul(
        kmm_tril_operator.solve(
            kmm_tril_operator.solve(knm, adjoint_arg=True), adjoint=True))

  # Add epsilon to the diagonal of the final covariance as well.
  covariance_matrix = tf.linalg.set_diag(
      covariance_matrix, tf.linalg.diag_part(covariance_matrix) + jitter())

  def _swap_axes(t):
    return tf.transpose(t, perm=[1, 0])

  def _reverse_shape(shape):
    return shape[::-1]

  def _zero_log_det(y):
    return tf.cast(0, y.dtype)

  # Form a multivariate normal random variable with batch_shape units and
  # event_shape batch_size. Then make it be independent across the units
  # dimension. Then transpose its dimensions so it is [batch_size, units].
  random_variable = (
      generated_random_variables.MultivariateNormalFullCovariance(
          loc=loc, covariance_matrix=covariance_matrix))
  random_variable = generated_random_variables.Independent(
      random_variable.distribution, reinterpreted_batch_ndims=1)
  bijector = tfp.bijectors.Inline(
      forward_fn=_swap_axes,
      inverse_fn=_swap_axes,
      forward_event_shape_fn=_reverse_shape,
      forward_event_shape_tensor_fn=_reverse_shape,
      inverse_log_det_jacobian_fn=_zero_log_det,
      forward_min_event_ndims=2)
  random_variable = generated_random_variables.TransformedDistribution(
      random_variable.distribution, bijector=bijector)
  return random_variable
def compute_logits(self, batch_features: tf.Tensor,
                   target_token_seq: tf.Tensor,
                   training: bool) -> tf.Tensor:
  """Encodes the batch and decodes logits over the target vocabulary.

  The encoder is selected by `self.hyperparameters["encoder_type"]`
  ("seq", "graph", "graph+seq" or "seq+graph"). Decoding either uses teacher
  forcing (only while training, chosen at random with probability 0.5 per
  call) or feeds the arg-max prediction of each step back in as the next
  input.

  Args:
    batch_features: Batch inputs, indexed by string key in this method
      ("num_graphs_in_batch", "source_seq", "graph_to_num_nodes",
      "source_len", depending on the encoder type) and also passed whole to
      the graph encoder. NOTE(review): annotated as `tf.Tensor` but used like
      a mapping - confirm the intended type.
    target_token_seq: Target token ids; only read on the teacher-forcing
      path, where all but the last position are fed to the decoder.
    training: Flag indicating if we are currently training. Forwarded to the
      graph encoder and required for teacher forcing.

  Returns:
    The decoder's predictions. With teacher forcing this is the decoder
    output for the whole shifted target sequence; otherwise it is the
    concatenation along axis 1 of the per-step logits for
    `max_seq_length - 1` steps.

  Raises:
    ValueError: If the encoder type is not one of the supported values, or if
      greedy decoding is requested with `max_seq_length` < 2 (no step would
      be decoded).
  """
  hyper = self.hyperparameters
  encoder_type = hyper["encoder_type"]
  supported = ("seq", "graph", "graph+seq", "seq+graph")
  if encoder_type not in supported:
    # Previously this fell through and failed later with an unbound-variable
    # error at the decoder call.
    raise ValueError(
        "Unsupported encoder_type {!r}; expected one of {}.".format(
            encoder_type, supported))

  use_lstm = hyper["rnn_cell"] == "LSTM"
  start_id = self.vocab_target.get_id_or_unk("%START%")

  def _pad_rows(rows, target_len):
    # Zero-pads `rows` along axis 0 up to `target_len` rows.
    padding = tf.zeros(
        [target_len - rows.shape[0], hyper["token_embedding_size"]])
    return tf.concat([rows, padding], 0)

  def _start_tokens(batch_size):
    # One %START% id per batch element, as a column.
    return tf.expand_dims([start_id] * batch_size, 1)

  if encoder_type == "seq":
    batch_size = batch_features["num_graphs_in_batch"]
    enc_hidden = self.seq_encoder.initialize_hidden_state(batch_size)
    enc_output, enc_hidden = self.seq_encoder(
        batch_features["source_seq"], enc_hidden)
    dec_hidden = enc_hidden
    dec_input = _start_tokens(batch_size)
  elif encoder_type == "graph":
    enc_hidden, enc_output = self.graph_encoder(batch_features,
                                                training=training)
    per_graph = tf.split(enc_output, batch_features["graph_to_num_nodes"])
    enc_output = tf.stack(
        [_pad_rows(out, hyper["max_node_num"]) for out in per_graph])
    # Both projections are always evaluated, as before, even though only the
    # LSTM cell consumes the second one.
    dec_hidden_h = self.linear_h(enc_hidden)
    dec_hidden_c = self.linear_c(enc_hidden)
    dec_hidden = [dec_hidden_h, dec_hidden_c] if use_lstm else dec_hidden_h
    dec_input = _start_tokens(enc_hidden.shape[0])
  elif encoder_type == "graph+seq":
    enc_hidden, enc_output = self.graph_encoder(batch_features,
                                                training=training)
    per_graph = tf.split(enc_output, batch_features["graph_to_num_nodes"])
    # NOTE(review): the padded length is hard-coded to 200 here, unlike the
    # other branches which use hyperparameters["max_node_num"] - confirm.
    enc_output = tf.stack([
        _pad_rows(out[:source_len, :], 200)
        for out, source_len in zip(per_graph, batch_features["source_len"])
    ])
    hidden_h = self.linear_h(enc_hidden)
    hidden_c = self.linear_c(enc_hidden)
    hidden = [hidden_h, hidden_c] if use_lstm else hidden_h
    enc_output, dec_hidden = self.seq_encoder(enc_output, hidden,
                                              embedd=False)
    dec_input = _start_tokens(enc_hidden.shape[0])
  else:  # encoder_type == "seq+graph"
    enc_hidden = self.seq_encoder.initialize_hidden_state(
        batch_features["num_graphs_in_batch"])
    seq_enc_output, seq_enc_hidden = self.seq_encoder(
        batch_features["source_seq"], enc_hidden)
    enc_hidden, enc_output = self.graph_encoder(
        batch_features, training=training, seq_enc_output=seq_enc_output)
    per_graph = tf.split(enc_output, batch_features["graph_to_num_nodes"])
    enc_output = tf.stack(
        [_pad_rows(out, hyper["max_node_num"]) for out in per_graph])
    hidden_h = self.linear_h(enc_hidden)
    dec_hidden = [hidden_h, seq_enc_hidden[1]] if use_lstm else hidden_h
    dec_input = _start_tokens(enc_hidden.shape[0])

  if training and random.random() > 0.5:
    # Use teacher forcing
    predictions, dec_hidden = self.decoder(target_token_seq[:, :-1],
                                           dec_hidden, enc_output)
    return predictions

  # The predicted ID is fed back into the model
  step_logits = []
  for _ in range(1, hyper["max_seq_length"]):
    predictions, dec_hidden = self.decoder(dec_input, dec_hidden, enc_output)
    first_step = predictions[:, 0, :]
    predicted_ids = tf.argmax(first_step, 1)
    dec_input = tf.expand_dims(predicted_ids, 1)
    step_logits.append(tf.expand_dims(first_step, 1))

  if not step_logits:
    raise ValueError(
        "`max_seq_length` must be at least 2 to decode any step; got "
        "{}.".format(hyper["max_seq_length"]))
  # Single concatenation instead of growing the result inside the loop.
  return tf.concat(step_logits, 1)
def __init__(self,
             num_timesteps,
             coefficients,
             level_scale,
             initial_state_prior,
             observation_noise_scale=0.,
             name=None,
             **linear_gaussian_ssm_kwargs):
  """Build a state space model implementing an autoregressive process.

  Args:
    num_timesteps: Scalar `int` `Tensor` number of timesteps to model
      with this distribution.
    coefficients: `float` `Tensor` of shape `concat(batch_shape, [order])`
      defining the autoregressive coefficients. The coefficients are defined
      backwards in time:
      `coefficients[0] * level[t] + coefficients[1] * level[t-1] + ... +
      coefficients[order-1] * level[t-order+1]`.
    level_scale: Scalar (any additional dimensions are treated as batch
      dimensions) `float` `Tensor` indicating the standard deviation of the
      transition noise at each step.
    initial_state_prior: instance of `tfd.MultivariateNormal`
      representing the prior distribution on latent states. Must have
      event shape `[order]`.
    observation_noise_scale: Scalar (any additional dimensions are
      treated as batch dimensions) `float` `Tensor` indicating the standard
      deviation of the observation noise.
      Default value: 0.
    name: Python `str` name prefixed to ops created by this class.
      Default value: "AutoregressiveStateSpaceModel".
    **linear_gaussian_ssm_kwargs: Optional additional keyword arguments to
      the base `tfd.LinearGaussianStateSpaceModel` constructor.

  Raises:
    ValueError: If the last dimension of `coefficients` is not statically
      known.
  """
  # Must stay the first statement: it snapshots the constructor arguments
  # before any further local is defined.
  parameters = dict(locals())
  parameters.update(linear_gaussian_ssm_kwargs)
  parameters.pop('linear_gaussian_ssm_kwargs')

  with tf.name_scope(name or 'AutoregressiveStateSpaceModel') as name:
    # The initial state prior determines the dtype of sampled values.
    # Other model parameters must have the same dtype.
    dtype = initial_state_prior.dtype

    coefficients = tf.convert_to_tensor(
        value=coefficients, name='coefficients', dtype=dtype)
    level_scale = tf.convert_to_tensor(
        value=level_scale, name='level_scale', dtype=dtype)
    observation_noise_scale = tf.convert_to_tensor(
        value=observation_noise_scale,
        name='observation_noise_scale',
        dtype=dtype)

    # The order is the static size of the coefficients' last dimension.
    static_order = tf.compat.dimension_value(coefficients.shape[-1])
    if static_order is None:
      raise ValueError(
          'Autoregressive coefficients must have static shape.')

    self._order = static_order
    self._coefficients = coefficients
    self._level_scale = level_scale

    transition_matrix = make_ar_transition_matrix(coefficients)

    # Transition noise has scale `level_scale` in the first state component
    # and zero in the remaining `order - 1` components.
    noise_scale_components = (
        [level_scale] + [tf.zeros_like(level_scale)] * (self.order - 1))
    transition_noise = tfd.MultivariateNormalDiag(
        scale_diag=tf.stack(noise_scale_components, axis=-1))

    # The observation matrix is a one followed by `order - 1` zeros, i.e. it
    # reads out the first state component.
    observation_matrix = tf.concat(
        [tf.ones([1, 1], dtype=dtype),
         tf.zeros([1, self.order - 1], dtype=dtype)],
        axis=-1)
    observation_noise = tfd.MultivariateNormalDiag(
        scale_diag=observation_noise_scale[..., tf.newaxis])

    super(AutoregressiveStateSpaceModel, self).__init__(
        num_timesteps=num_timesteps,
        transition_matrix=transition_matrix,
        transition_noise=transition_noise,
        observation_matrix=observation_matrix,
        observation_noise=observation_noise,
        initial_state_prior=initial_state_prior,
        name=name,
        **linear_gaussian_ssm_kwargs)
    self._parameters = parameters
def softquantiles(x,
                  quantiles,
                  quantile_width=None,
                  axis=-1,
                  may_squeeze=True,
                  **kwargs):
  """Computes soft quantiles via optimal transport.

  An exhaustive softsort is not needed to recover a single quantile: all
  values of x can instead be transported onto only 3 weighted targets, with
  the target weights chosen so that the values sent to the middle target are
  those concentrating around the quantile of interest. For several quantiles
  the same idea is applied by interleaving small weights at the quantile
  positions with larger "filler" weights covering the gaps between
  consecutive quantiles.

  Args:
    x: Tensor<float> of any shape.
    quantiles: list<float> the quantiles to be returned. It can also be a
      single float. Only values strictly between 0 and 1 are used.
    quantile_width: (float) mass given to the bucket supposed to attract
      points whose value concentrate around the desired quantile value. Bigger
      width means that we allow the soft quantile to be a mixture of more
      points further away from the quantile. If None, the width is the
      smallest of 1/n, where n is the number of values considered (the size
      along the 'axis'), and the gaps between consecutive quantiles
      (including the ends 0 and 1).
    axis: (int) the axis along which to compute the quantile.
    may_squeeze: (bool) should we squeeze the output tensor in case of a
      single quantile.
    **kwargs: forwarded to `softsort`; see SoftQuantilizer for possible extra
      parameters.

  Returns:
    A Tensor<float> similar to the input tensor, but the axis dimension is
    replaced by the number of quantiles specified in the quantiles list.
    Hence, if only a quantile is requested (quantiles is a float) only one
    value in that axis is returned. When several quantiles are requested, the
    tensor will have that many values in that axis.

  Raises:
    tf.errors.InvalidArgumentError when the quantiles and quantile width are
    not correct, namely quantiles are either not in sorted order or the
    quantile_width is too large.
  """
  if isinstance(quantiles, float):
    quantiles = [quantiles]
  quantiles = tf.constant(quantiles, tf.float32)

  # Keeps only the quantiles lying strictly inside (0, 1).
  inside_unit_interval = tf.logical_and(quantiles > 0.0, quantiles < 1.0)
  valid_quantiles = tf.boolean_mask(quantiles, inside_unit_interval)
  num_quantiles = tf.shape(valid_quantiles)[0]

  # Adds both ends of [0, 1]; the gaps between consecutive entries are the
  # filler weights sitting in between the target quantiles.
  extended_quantiles = tf.concat([[0.0], valid_quantiles, [1.0]], axis=0)
  filler_weights = extended_quantiles[1:] - extended_quantiles[:-1]

  if quantile_width is None:
    quantile_width = tf.reduce_min(
        tf.concat(
            [filler_weights,
             [1.0 / tf.cast(tf.shape(x)[axis], dtype=x.dtype)]],
            axis=0))

  # Each filler gives up one quantile_width in total, except the first and
  # the last one which only give up half of it.
  shift = -tf.ones(tf.shape(filler_weights), dtype=x.dtype)
  shift = shift + 0.5 * (
      tf.one_hot(0, num_quantiles + 1) +
      tf.one_hot(num_quantiles, num_quantiles + 1))
  filler_weights = filler_weights + quantile_width * shift

  non_negative_check = tf.Assert(
      tf.reduce_all(filler_weights >= 0.0), [filler_weights])
  with tf.control_dependencies([non_negative_check]):
    # One extra quantile weight so both tensors have the same length and can
    # be interleaved; the surplus trailing entry is dropped afterwards.
    quantile_weights = tf.ones(num_quantiles + 1) * quantile_width
    interleaved = tf.stack([filler_weights, quantile_weights], axis=1)
    weights = tf.reshape(interleaved, (-1,))[:-1]

    # Sends only the positive weights to the softsort operator.
    positive_weights = tf.boolean_mask(weights, weights > 0.0)
    all_quantiles = softsort(
        x,
        direction='ASCENDING',
        axis=axis,
        target_weights=positive_weights,
        **kwargs)

    # Quantile weights sit at the odd positions of `weights`; the running
    # count of positive weights gives their position among the softsort
    # targets.
    is_odd = tf.math.floormod(
        tf.range(weights.shape[0], dtype=tf.float32), 2)
    is_positive = tf.cast(weights > 0.0, tf.float32)
    indices = tf.cast(tf.math.cumsum(is_positive) * is_odd, dtype=tf.int32)
    indices = tf.boolean_mask(indices, indices > 0) - 1
    result = tf.gather(all_quantiles, indices, axis=axis)

    # In the specific case where we want a single quantile, squeezes the
    # quantile dimension.
    single_quantile = tf.equal(tf.shape(result)[axis], 1)
    if tf.math.logical_and(single_quantile, may_squeeze):
      result = tf.squeeze(result, axis=axis)
    return result
def plot_transfer(self, intervals, figsize=None, soft_round=None, **kwargs):
  """Plots the analysis and synthesis transforms of a 1D model.

  Draws latent value against source value for both transforms, marks every
  codebook point (unique quantized latent and its synthesis) and every
  quantization boundary (midpoint between neighbouring samples whose
  quantized latent differs), and uses the codebook as axis ticks.

  Args:
    intervals: Sequence with exactly one entry `(start, stop, num)` describing
      the source values to evaluate, as accepted by `tf.linspace`.
    figsize: Optional figure size; defaults to `(16, 14)`.
    soft_round: Passed to `self.encode_decode`; defaults to
      `self.soft_round[1]`.
    **kwargs: Forwarded to `self.encode_decode`.

  Raises:
    ValueError: If the model or `intervals` is not one-dimensional.
  """
  one_dimensional = (
      len(intervals) == self.ndim_source == self.ndim_latent == 1)
  if not one_dimensional:
    raise ValueError("This method is only defined for 1D models.")
  if soft_round is None:
    soft_round = self.soft_round[1]

  axes = [
      tf.linspace(float(spec[0]), float(spec[1]), int(spec[2]))
      for spec in intervals
  ]
  source = tf.stack(tf.meshgrid(*axes, indexing="ij"), axis=-1)

  quantized, _, _ = self.encode_decode(
      source, False, False, soft_round, **kwargs)
  latent = self.analysis(source)
  # We feed the unquantized latents here so we can visualize the full
  # behavior of the synthesis transform (not just at the quantized latent
  # values).
  reconstruction = self.synthesis(latent)

  src = np.squeeze(source.numpy(), -1)
  lat = np.squeeze(latent.numpy(), -1)
  rec = np.squeeze(reconstruction.numpy(), -1)
  lat_hat = np.squeeze(quantized.numpy(), -1)

  lat_min, lat_max = np.min(lat), np.max(lat)

  # Positions where the quantized latent changes between neighbours, and the
  # midpoints (in latent and source space) of each such pair.
  change_at = np.nonzero(lat_hat[1:] != lat_hat[:-1])[0]
  latent_boundaries = (lat_hat[change_at] + lat_hat[change_at + 1]) / 2
  source_boundaries = (src[change_at] + src[change_at + 1]) / 2

  # Codebook: unique quantized latents and their synthesized source values,
  # restricted to the latent range strictly covered by the analysis output.
  latent_codebook = np.unique(lat_hat)
  source_codebook = self.synthesis(latent_codebook[:, None]).numpy()
  source_codebook = np.squeeze(source_codebook, -1)
  in_range = np.logical_and(lat_min < latent_codebook,
                            latent_codebook < lat_max)
  latent_codebook = latent_codebook[in_range]
  source_codebook = source_codebook[in_range]

  plt.figure(figsize=figsize or (16, 14))
  plt.plot(src, lat, label="analysis transform")
  plt.plot(rec, lat, label="synthesis transform")
  plt.gca().set_aspect("equal", "box")
  # Flip y axis if latent space is reversed.
  if lat[0] > lat[-1]:
    plt.gca().invert_yaxis()
  plt.xticks(source_codebook)
  plt.yticks(latent_codebook)
  plt.grid(False)
  plt.xlabel("source space")
  plt.ylabel("latent space")

  xmin = plt.axis()[0]
  ymin = plt.axis()[2]

  # Guide lines from both axes to each codebook point; the legend label is
  # attached to points equal to the first codebook value.
  for px, py in zip(source_codebook, latent_codebook):
    plt.plot([xmin, px, px], [py, py, ymin], "black", lw=1)
    plt.plot([px], [py], "black", marker="o", ms=5, lw=1,
             label="codebook" if px == source_codebook[0] else None)

  # Same for the quantization boundaries, drawn dotted.
  for px, py in zip(source_boundaries, latent_boundaries):
    plt.plot([xmin, px, px], [py, py, ymin], "black", lw=1, ls=":")
    plt.plot([px], [py], "black", marker="o", ms=3, lw=1, ls=":",
             label="boundaries" if px == source_boundaries[0] else None)

  plt.legend(loc="upper left")
def plot_jacobians(self, which, intervals, arrow_intervals, scale=2,
                   figsize=None):
  """Plots Jacobian columns of a 2D transform as arrows over the source density.

  The background image is `self.source.prob` evaluated on the grid given by
  `intervals`. Arrows are anchored at source-space points: for "analysis" the
  grid from `arrow_intervals` is taken as source points and the inverse of
  the analysis Jacobian is drawn; for "synthesis" the grid is taken as latent
  points, the arrows are anchored at their synthesized values and the
  synthesis Jacobian is drawn.

  Args:
    which: Either "analysis" or "synthesis".
    intervals: Two `(start, stop, num)` entries describing the density grid.
    arrow_intervals: Two `(start, stop, num)` entries describing the grid of
      points at which Jacobians are evaluated.
    scale: `scale` argument passed to `plt.quiver`.
    figsize: Optional figure size; defaults to `(16, 14)`.

  Raises:
    ValueError: If the model or either interval list is not two-dimensional,
      or if `which` is not one of the two supported values.
  """
  two_dimensional = (
      len(intervals) == len(arrow_intervals) == self.ndim_source ==
      self.ndim_latent == 2)
  if not two_dimensional:
    raise ValueError("This method is only defined for 2D models.")
  if which not in ("analysis", "synthesis"):
    raise ValueError("`which` must be 'analysis' or 'synthesis'.")

  def _grid(specs):
    # "ij"-indexed mesh of the linspaces described by `specs`, with the
    # coordinates stacked along the last axis.
    axes = [
        tf.linspace(float(spec[0]), float(spec[1]), int(spec[2]))
        for spec in specs
    ]
    return tf.stack(tf.meshgrid(*axes, indexing="ij"), axis=-1)

  data = _grid(intervals)
  data_dist = self.source.prob(data).numpy()

  # Points at which the transform is differentiated, flattened to a batch.
  points = _grid(arrow_intervals)
  points = tf.reshape(points, (-1, points.shape[-1]))

  from_source = which == "analysis"
  with tf.GradientTape(watch_accessed_variables=False) as tape:
    tape.watch(points)
    mapped = self.analysis(points) if from_source else self.synthesis(points)
  # First dimension is batch, second is output dim, third is input dim.
  jacobian = tape.batch_jacobian(mapped, points)
  if from_source:
    # Invert the analysis Jacobian; arrows sit at the source points.
    jacobian = tf.linalg.inv(jacobian)
    arrow_data = points
  else:
    # Arrows sit at the synthesized source points.
    arrow_data = mapped
  jacobian = tf.transpose(jacobian, (0, 2, 1))

  google_pink = (0xf4 / 255, 0x39 / 255, 0xa0 / 255)
  google_purple = (0xa1 / 255, 0x42 / 255, 0xf4 / 255)

  plt.figure(figsize=figsize or (16, 14))
  plt.imshow(
      data_dist,
      vmin=0,
      vmax=data_dist.max(),
      origin="lower",
      extent=(data[0, 0, 1], data[0, -1, 1], data[0, 0, 0], data[-1, 0, 0]))

  # One quiver layer per row of the transposed Jacobian, in its own color.
  for row, color in ((0, google_pink), (1, google_purple)):
    plt.quiver(
        arrow_data[:, 1], arrow_data[:, 0],
        jacobian[:, row, 1], jacobian[:, row, 0],
        pivot="tail", angles="xy", headlength=4, headaxislength=4,
        units="dots", color=color, scale_units="xy", scale=scale,
    )

  plt.axis("image")
  plt.grid(False)
  plt.xlim(data[0, 0, 1], data[0, -1, 1])
  plt.ylim(data[0, 0, 0], data[-1, 0, 0])
  plt.xlabel("source dimension 1")
  plt.ylabel("source dimension 2")
def _stack(*ts):
  """Passes the arguments through `_conform` and stacks them on a new last axis."""
  conformed = _conform(ts)
  return tf.stack(conformed, axis=-1)
def loop_tree_doubling(self, step_size, momentum_state_memory,
                       current_step_meta_info, iter_, initial_step_state,
                       initial_step_metastate):
  """Main loop for tree doubling.

  Performs one doubling step: picks a random direction per batch member,
  builds a sub-tree of `1 << iter_` steps from the corresponding end of the
  current trajectory, decides whether the sub-tree's candidate replaces the
  current candidate, and updates the trajectory end points and bookkeeping.

  Args:
    step_size: Iterable of per-state-part step sizes; each is negated for batch
      members whose sampled direction is `False`.
    momentum_state_memory: Forwarded unchanged to `self._build_sub_tree`.
    current_step_meta_info: Object with an `init_energy` attribute (its shape
      is used as the batch shape); also forwarded to `self._build_sub_tree`.
    iter_: Current doubling iteration.
    initial_step_state: Nested structure with `state` and `momentum` fields
      whose leaves are indexed with `[0]` and `[1]` for the two trajectory ends.
    initial_step_metastate: Object providing `candidate_state`, `is_accepted`,
      `energy_diff_sum`, `continue_tree`, `not_divergence` and `leapfrog_count`.

  Returns:
    A tuple `(iter_ + 1, new_step_state, new_step_metastate)`.
  """
  with tf.name_scope('loop_tree_doubling'):
    batch_shape = prefer_static.shape(
        current_step_meta_info.init_energy)
    # Draw 0 or 1 per batch member and interpret it as a boolean direction.
    direction = tf.cast(
        tf.random.uniform(
            shape=batch_shape,
            minval=0,
            maxval=2,
            dtype=tf.int32,
            seed=self._seed_stream()),
        dtype=tf.bool)

    # Start the sub-tree from end `[1]` where direction is True, else from
    # end `[0]`.
    tree_start_states = tf.nest.map_structure(
        lambda v: tf.where(  # pylint: disable=g-long-lambda
            _rightmost_expand_to_rank(
                direction, prefer_static.rank(v[1])),
            v[1], v[0]),
        initial_step_state)

    directions_expanded = [
        _rightmost_expand_to_rank(direction, prefer_static.rank(state))
        for state in tree_start_states.state
    ]

    # The sign of each step size follows the sampled direction.
    integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
        self.target_log_prob_fn,
        step_sizes=[
            tf.where(d, ss, -ss)
            for d, ss in zip(directions_expanded, step_size)
        ],
        num_steps=self.unrolled_leapfrog_steps)

    [
        candidate_tree_state,
        tree_final_states,
        final_not_divergence,
        continue_tree_final,
        energy_diff_tree_sum,
        momentum_tree_cumsum,
        leapfrogs_taken
    ] = self._build_sub_tree(
        directions_expanded,
        integrator,
        current_step_meta_info,
        # num_steps_at_this_depth = 2**iter_ = 1 << iter_
        tf.bitwise.left_shift(1, iter_),
        tree_start_states,
        initial_step_metastate.continue_tree,
        initial_step_metastate.not_divergence,
        momentum_state_memory)

    last_candidate_state = initial_step_metastate.candidate_state
    tree_weight = candidate_tree_state.weight
    if MULTINOMIAL_SAMPLE:
      # Weights are combined with `log_add_exp`, so the threshold is a
      # difference of weights.
      weight_sum = log_add_exp(tree_weight, last_candidate_state.weight)
      log_accept_thresh = tree_weight - last_candidate_state.weight
    else:
      # Weights are added directly, so the threshold is the log of their
      # ratio.
      weight_sum = tree_weight + last_candidate_state.weight
      log_accept_thresh = tf.math.log(
          tf.cast(tree_weight, tf.float32) /
          tf.cast(last_candidate_state.weight, tf.float32))
    # A NaN threshold is replaced with 0.
    log_accept_thresh = tf.where(tf.math.is_nan(log_accept_thresh),
                                 tf.zeros([], log_accept_thresh.dtype),
                                 log_accept_thresh)
    # log1p(-U) with U uniform: compared against the log threshold below.
    u = tf.math.log1p(-tf.random.uniform(
        shape=batch_shape,
        dtype=log_accept_thresh.dtype,
        seed=self._seed_stream()))
    is_sample_accepted = u <= log_accept_thresh

    # The new candidate is only taken if the sub-tree itself may continue.
    choose_new_state = is_sample_accepted & continue_tree_final

    new_candidate_state = TreeDoublingStateCandidate(
        state=[
            tf.where(  # pylint: disable=g-complex-comprehension
                _rightmost_expand_to_rank(choose_new_state,
                                          prefer_static.rank(s0)), s0, s1)
            for s0, s1 in zip(candidate_tree_state.state,
                              last_candidate_state.state)
        ],
        target=tf.where(
            _rightmost_expand_to_rank(
                choose_new_state,
                prefer_static.rank(candidate_tree_state.target)),
            candidate_tree_state.target, last_candidate_state.target),
        target_grad_parts=[
            tf.where(  # pylint: disable=g-complex-comprehension
                _rightmost_expand_to_rank(choose_new_state,
                                          prefer_static.rank(grad0)),
                grad0, grad1)
            for grad0, grad1 in zip(
                candidate_tree_state.target_grad_parts,
                last_candidate_state.target_grad_parts)
        ],
        # NOTE(review): the mask for `energy` is expanded to the rank of
        # `target`, not of `energy` -- presumably both have the same rank;
        # confirm.
        energy=tf.where(
            _rightmost_expand_to_rank(
                choose_new_state,
                prefer_static.rank(candidate_tree_state.target)),
            candidate_tree_state.energy, last_candidate_state.energy),
        weight=weight_sum)

    # Update left right information of the trajectory, and check trajectory
    # level U turn.
    # The "other end" is the end the sub-tree did NOT start from: `[0]` where
    # direction is True, else `[1]`.
    tree_otherend_states = tf.nest.map_structure(
        lambda v: tf.where(  # pylint: disable=g-long-lambda
            _rightmost_expand_to_rank(
                direction, prefer_static.rank(v[1])), v[0], v[1]),
        initial_step_state)

    # Re-stack the two ends along axis 0. Where direction is True the
    # sub-tree's final state becomes end `[1]` and the untouched end stays at
    # `[0]`; otherwise the roles are swapped.
    new_step_state = tf.nest.pack_sequence_as(
        initial_step_state, [
            tf.stack(
                [  # pylint: disable=g-complex-comprehension
                    tf.where(
                        _rightmost_expand_to_rank(
                            direction, prefer_static.rank(l)), r, l),
                    tf.where(
                        _rightmost_expand_to_rank(
                            direction, prefer_static.rank(l)), l, r),
                ],
                axis=0)
            for l, r in zip(tf.nest.flatten(tree_final_states),
                            tf.nest.flatten(tree_otherend_states))
        ])

    if GENERALIZED_UTURN:
      state_diff = momentum_tree_cumsum
    else:
      # Difference between the two trajectory ends, per state part.
      state_diff = [s[1] - s[0] for s in new_step_state.state]

    no_u_turns_trajectory = has_not_u_turn(
        state_diff,
        [m[0] for m in new_step_state.momentum],
        [m[1] for m in new_step_state.momentum],
        log_prob_rank=len(batch_shape))

    new_step_metastate = TreeDoublingMetaState(
        candidate_state=new_candidate_state,
        is_accepted=choose_new_state | initial_step_metastate.is_accepted,
        energy_diff_sum=(
            energy_diff_tree_sum + initial_step_metastate.energy_diff_sum),
        continue_tree=continue_tree_final & no_u_turns_trajectory,
        not_divergence=final_not_divergence,
        leapfrog_count=(initial_step_metastate.leapfrog_count +
                        leapfrogs_taken))

    return iter_ + 1, new_step_state, new_step_metastate
def lu_reconstruct(lower_upper, perm, validate_args=False, name=None):
  """Rebuilds a matrix from its LU factors: `X == lu_reconstruct(*tf.linalg.lu(X))`.

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`.
    validate_args: Python `bool` indicating whether arguments should be checked
      for correctness.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_reconstruct').

  Returns:
    x: The original input to `tf.linalg.lu`, i.e., `x` as in,
      `lu_reconstruct(*tf.linalg.lu(x))`.

  #### Examples

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp

  x = [[[3., 4], [1, 2]],
       [[7., 8], [3, 4]]]
  x_reconstructed = tfp.math.lu_reconstruct(*tf.linalg.lu(x))
  tf.assert_near(x, x_reconstructed)
  # ==> True
  ```
  """
  with tf.name_scope(name or 'lu_reconstruct'):
    lower_upper = tf.convert_to_tensor(
        lower_upper, dtype_hint=tf.float32, name='lower_upper')
    perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm')

    checks = lu_reconstruct_assertions(lower_upper, perm, validate_args)
    if checks:
      with tf.control_dependencies(checks):
        lower_upper = tf.identity(lower_upper)
        perm = tf.identity(perm)

    lu_shape = tf.shape(lower_upper)

    # L is the lower triangle with its diagonal replaced by ones; U is the
    # upper triangle including the diagonal.
    lower_triangle = tf.linalg.band_part(
        lower_upper, num_lower=-1, num_upper=0)
    unit_diagonal = tf.ones(lu_shape[:-1], dtype=lower_upper.dtype)
    lower = tf.linalg.set_diag(lower_triangle, unit_diagonal)
    upper = tf.linalg.band_part(lower_upper, num_lower=0, num_upper=-1)
    x = tf.matmul(lower, upper)

    static_rank = tensorshape_util.rank(lower_upper.shape)
    if static_rank == 2:
      # A single matrix: undo the row permutation with a plain gather.
      x = tf.gather(x, tf.math.invert_permutation(perm))
    else:
      # We either don't know the batch rank or there are >0 batch dims, so
      # flatten the batch and gather rows with explicit (batch, row) indices.
      num_matrices = tf.reduce_prod(lu_shape[:-2])
      dim = lu_shape[-1]
      x = tf.reshape(x, [num_matrices, dim, dim])
      flat_perm = tf.reshape(perm, [num_matrices, dim])
      inverse_perm = tf.map_fn(tf.math.invert_permutation, flat_perm)
      matrix_ids = tf.broadcast_to(
          tf.range(num_matrices)[:, tf.newaxis], [num_matrices, dim])
      row_lookup = tf.stack([matrix_ids, inverse_perm], axis=-1)
      x = tf.gather_nd(x, row_lookup)
      x = tf.reshape(x, lu_shape)

    tensorshape_util.set_shape(x, lower_upper.shape)
    return x
def _sample_n(self, n, seed=None):
  """Draws `n` samples for every batch member.

  Outside of an XLA context (or in eager mode) the stateless parameterized
  truncated-normal sampler is called directly. Otherwise a standard truncated
  normal is sampled with standardized bounds, a custom gradient with respect
  to those bounds is attached, and the result is scaled and shifted.

  Args:
    n: Number of samples; used as the leading dimension of the result.
    seed: Seed passed to the sampler (sanitized via `samplers.sanitize_seed`
      on the stateless path, passed through unchanged otherwise).

  Returns:
    A tensor of samples reshaped to `[n] + batch_shape`.
  """
  loc, scale, low, high = self._loc_scale_low_high()
  batch_shape = self._batch_shape_tensor(
      loc=loc, scale=scale, low=low, high=high)
  sample_and_batch_shape = tf.concat([[n], batch_shape], 0)
  # Only used by the fallback sampler below, which puts the sample dimension
  # last.
  flat_batch_and_sample_shape = tf.stack([tf.reduce_prod(batch_shape), n])

  # TODO(b/162522020): Use this behavior unconditionally.
  if (tf.executing_eagerly() or
      not control_flow_util.GraphOrParentsInXlaContext(
          tf1.get_default_graph())):
    return tf.random.stateless_parameterized_truncated_normal(
        shape=sample_and_batch_shape,
        means=loc,
        stddevs=scale,
        minvals=low,
        maxvals=high,
        seed=samplers.sanitize_seed(seed))

  # In order to be reparameterizable we sample a truncated normal with zero
  # mean and unit standard deviation (using the standardized truncation
  # bounds), and apply `scale` and `loc` to the samples afterwards.

  @tf.custom_gradient
  def _std_samples_with_gradients(lower, upper):
    """Standard truncated Normal with gradient support for low, high."""
    # Note: Unlike the convention in TFP, parameterized_truncated_normal
    # returns a tensor with the final dimension being the sample dimension.
    std_samples = random_ops.parameterized_truncated_normal(
        shape=flat_batch_and_sample_shape,
        means=0.0,
        stddevs=1.0,
        minvals=lower,
        maxvals=upper,
        dtype=self.dtype,
        seed=seed)

    def grad(dy):
      """Computes a derivative for the min and max parameters.

      This function implements the derivative wrt the truncation bounds, which
      get blocked by the sampler. We use a custom expression for numerical
      stability instead of automatic differentiation on CDF for implicit
      gradients.

      Args:
        dy: output gradients

      Returns:
        A two-element list: the gradient wrt the lower bound followed by the
        gradient wrt the upper bound.
      """
      # std_samples has an extra dimension (the sample dimension), expand
      # lower and upper so they broadcast along this dimension.
      # See note above regarding parameterized_truncated_normal, the sample
      # dimension is the final dimension.
      lower_broadcast = lower[..., tf.newaxis]
      upper_broadcast = upper[..., tf.newaxis]

      # CDF of each sample relative to the truncation interval.
      cdf_samples = ((special_math.ndtr(std_samples) -
                      special_math.ndtr(lower_broadcast)) /
                     (special_math.ndtr(upper_broadcast) -
                      special_math.ndtr(lower_broadcast)))

      # tiny, eps are tolerance parameters to ensure we stay away from giving
      # a zero arg to the log CDF expression.
      tiny = np.finfo(dtype_util.as_numpy_dtype(self.dtype)).tiny
      eps = np.finfo(dtype_util.as_numpy_dtype(self.dtype)).eps
      cdf_samples = tf.clip_by_value(cdf_samples, tiny, 1 - eps)

      du = tf.exp(0.5 * (std_samples**2 - upper_broadcast**2) +
                  tf.math.log(cdf_samples))
      dl = tf.exp(0.5 * (std_samples**2 - lower_broadcast**2) +
                  tf.math.log1p(-cdf_samples))

      # Reduce the gradient across the samples
      grad_u = tf.reduce_sum(dy * du, axis=-1)
      grad_l = tf.reduce_sum(dy * dl, axis=-1)
      return [grad_l, grad_u]

    return std_samples, grad

  std_low, std_high = self._standardized_low_and_high(
      low=low, high=high, loc=loc, scale=scale)
  # Bring both bounds to a common shape before flattening them.
  low_high_shp = tf.broadcast_dynamic_shape(
      tf.shape(std_low), tf.shape(std_high))
  std_low = tf.broadcast_to(std_low, low_high_shp)
  std_high = tf.broadcast_to(std_high, low_high_shp)

  std_samples = _std_samples_with_gradients(
      tf.reshape(std_low, [-1]), tf.reshape(std_high, [-1]))

  # The returned shape is [flat_batch x n]; move the sample dimension first.
  std_samples = tf.transpose(std_samples, perm=[1, 0])

  std_samples = tf.reshape(std_samples, sample_and_batch_shape)
  return std_samples * scale[tf.newaxis] + loc[tf.newaxis]
def __init__(self,
             num_timesteps,
             level_scale,
             slope_scale,
             initial_state_prior,
             observation_noise_scale=0.,
             initial_step=0,
             validate_args=False,
             allow_nan_stats=True,
             name=None):
  """Build a state space model implementing a local linear trend.

  Args:
    num_timesteps: Scalar `int` `Tensor` number of timesteps to model
      with this distribution.
    level_scale: Scalar (any additional dimensions are treated as batch
      dimensions) `float` `Tensor` indicating the standard deviation of the
      level transitions.
    slope_scale: Scalar (any additional dimensions are treated as batch
      dimensions) `float` `Tensor` indicating the standard deviation of the
      slope transitions.
    initial_state_prior: instance of `tfd.MultivariateNormal`
      representing the prior distribution on latent states; must
      have event shape `[2]`.
    observation_noise_scale: Scalar (any additional dimensions are
      treated as batch dimensions) `float` `Tensor` indicating the standard
      deviation of the observation noise.
    initial_step: Optional scalar `int` `Tensor` specifying the starting
      timestep.
      Default value: 0.
    validate_args: Python `bool`. Whether to validate input
      with asserts. If `validate_args` is `False`, and the inputs are
      invalid, correct behavior is not guaranteed.
      Default value: `False`.
    allow_nan_stats: Python `bool`. If `False`, raise an
      exception if a statistic (e.g. mean/mode/etc...) is undefined for any
      batch member. If `True`, batch members with valid parameters leading to
      undefined statistics will return NaN for this statistic.
      Default value: `True`.
    name: Python `str` name prefixed to ops created by this class.
      Default value: "LocalLinearTrendStateSpaceModel".
  """
  with tf1.name_scope(name, 'LocalLinearTrendStateSpaceModel',
                      [level_scale, slope_scale]) as name:
    # All parameters are converted to the dtype of the initial state prior,
    # since that prior fixes the dtype of sampled values.
    dtype = initial_state_prior.dtype

    def _to_param(value, param_name):
      return tf.convert_to_tensor(value=value, name=param_name, dtype=dtype)

    level_scale = _to_param(level_scale, 'level_scale')
    slope_scale = _to_param(slope_scale, 'slope_scale')
    observation_noise_scale = _to_param(
        observation_noise_scale, 'observation_noise_scale')

    # Multiplying by ones of the joint batch shape makes both scales the same
    # shape, so they can be combined with `tf.stack`.
    batch_ones = tf.ones(
        dist_util.get_broadcast_shape(level_scale, slope_scale), dtype=dtype)

    self._level_scale = level_scale
    self._slope_scale = slope_scale
    self._observation_noise_scale = observation_noise_scale

    # Pieces of the linear Gaussian state space model. The latent state is
    # [level, slope]: level advances by slope each step, and only level is
    # observed.
    transition_matrix = tf.constant(
        [[1., 1.], [0., 1.]], dtype=dtype, name='transition_matrix')
    transition_scale_diag = tf.stack(
        [level_scale * batch_ones, slope_scale * batch_ones], axis=-1)
    transition_noise = tfd.MultivariateNormalDiag(
        scale_diag=transition_scale_diag, name='transition_noise')
    observation_matrix = tf.constant(
        [[1., 0.]], dtype=dtype, name='observation_matrix')
    observation_noise = tfd.MultivariateNormalDiag(
        scale_diag=observation_noise_scale[..., tf.newaxis],
        name='observation_noise')

    super(LocalLinearTrendStateSpaceModel, self).__init__(
        num_timesteps=num_timesteps,
        transition_matrix=transition_matrix,
        transition_noise=transition_noise,
        observation_matrix=observation_matrix,
        observation_noise=observation_noise,
        initial_state_prior=initial_state_prior,
        initial_step=initial_step,
        allow_nan_stats=allow_nan_stats,
        validate_args=validate_args,
        name=name)
def evaluate(
    env,
    policy,
    num_episodes=10,
    ctx_length=None,
    embed_training_window=None,
    state_mask_fn=None,  # pylint: disable=g-bare-generic
):
  """Rolls out the policy for several episodes and averages the results.

  Args:
    env: Environment to evaluate the policy on.
    policy: Policy to evaluate.
    num_episodes: Number of episodes to average over.
    ctx_length: Number of previous steps to compute context from. When set,
      the policy is called with stacked states, actions and rewards.
    embed_training_window: Window size used during embed training. When set
      (and `ctx_length` is not), the policy also receives and returns a latent
      action, which is cleared every `embed_training_window` steps, or on
      every step if the window is at most 2.
    state_mask_fn: State masking function for partially observed envs; applied
      to `observation.numpy()`.

  Returns:
    A pair: the summed first-component reward divided by `num_episodes`, and
    the total step count divided by `num_episodes`.
  """

  def masked(observation):
    if not state_mask_fn:
      return observation
    return tf.convert_to_tensor(state_mask_fn(observation.numpy()))

  return_sum = 0.0
  step_count = 0

  for _ in range(num_episodes):
    timestep = env.reset()
    if ctx_length:
      # Seed the context with the first observation, zero actions and zero
      # rewards.
      states = [masked(timestep.observation) for _ in range(ctx_length)]
      actions = [
          tf.zeros(policy.action_spec.shape)[None, :]
          for _ in range(ctx_length)
      ]
      rewards = [[0.] for _ in range(ctx_length)]

    latent_action = None
    episode_step = 0
    while not timestep.is_last():
      clear_latent = embed_training_window and (
          episode_step % embed_training_window == 0 or
          embed_training_window <= 2)
      if clear_latent:
        latent_action = None

      if ctx_length:
        states.append(masked(timestep.observation))
        if len(states) > ctx_length:
          # Slide the window: drop the oldest entry of every history.
          del states[0]
          del actions[0]
          del rewards[0]
        action = policy.act(
            tf.stack(states, axis=1),
            actions=tf.stack(actions, axis=1),
            rewards=tf.stack(rewards, axis=1))
        actions.append(action)
      elif embed_training_window:
        action, latent_action = policy.act(
            masked(timestep.observation), latent_action=latent_action)
      else:
        action = policy.act(masked(timestep.observation))

      timestep = env.step(action)
      if ctx_length:
        rewards.append(timestep.reward)
      return_sum += timestep.reward[0]
      step_count += 1
      episode_step += 1

  return return_sum / num_episodes, step_count / num_episodes
def __init__(self,
             level_scale_prior=None,
             slope_scale_prior=None,
             initial_level_prior=None,
             initial_slope_prior=None,
             observed_time_series=None,
             name=None):
  """Specify a local linear trend model.

  Args:
    level_scale_prior: optional `tfd.Distribution` instance specifying a prior
      on the `level_scale` parameter. If `None`, a heuristic default prior is
      constructed based on the provided `observed_time_series`.
      Default value: `None`.
    slope_scale_prior: optional `tfd.Distribution` instance specifying a prior
      on the `slope_scale` parameter. If `None`, a heuristic default prior is
      constructed based on the provided `observed_time_series`.
      Default value: `None`.
    initial_level_prior: optional `tfd.Distribution` instance specifying a
      prior on the initial level. If `None`, a heuristic default prior is
      constructed based on the provided `observed_time_series`.
      Default value: `None`.
    initial_slope_prior: optional `tfd.Distribution` instance specifying a
      prior on the initial slope. If `None`, a heuristic default prior is
      constructed based on the provided `observed_time_series`.
      Default value: `None`.
    observed_time_series: optional `float` `Tensor` of shape
      `batch_shape + [T, 1]` (omitting the trailing unit dimension is also
      supported when `T > 1`), specifying an observed time series.
      Any priors not explicitly set will be given default values according to
      the scale of the observed time series (or batch of time series). May
      optionally be an instance of `tfp.sts.MaskedTimeSeries`, which includes
      a mask `Tensor` to specify timesteps with missing observations.
      Default value: `None`.
    name: the name of this model component.
      Default value: 'LocalLinearTrend'.
  """
  with tf1.name_scope(
      name, 'LocalLinearTrend', values=[observed_time_series]) as name:

    # Without data, fall back to unit stddev and a zero initial value.
    if observed_time_series is None:
      observed_stddev, observed_initial = 1., 0.
    else:
      _, observed_stddev, observed_initial = sts_util.empirical_statistics(
          observed_time_series)

    # The defaults below are heuristics scaled by the observed series;
    # replacing them can change inference performance and results
    # substantially.
    if level_scale_prior is None:
      level_scale_prior = tfd.LogNormal(
          loc=tf.math.log(.05 * observed_stddev),
          scale=3.,
          name='level_scale_prior')
    if slope_scale_prior is None:
      slope_scale_prior = tfd.LogNormal(
          loc=tf.math.log(.05 * observed_stddev),
          scale=3.,
          name='slope_scale_prior')
    if initial_level_prior is None:
      initial_level_prior = tfd.Normal(
          loc=observed_initial,
          scale=tf.abs(observed_initial) + observed_stddev,
          name='initial_level_prior')
    if initial_slope_prior is None:
      initial_slope_prior = tfd.Normal(
          loc=0., scale=observed_stddev, name='initial_slope_prior')

    all_priors = [
        level_scale_prior, slope_scale_prior, initial_level_prior,
        initial_slope_prior
    ]
    tf.debugging.assert_same_float_dtype(all_priors)

    # The joint prior on [level, slope] is built from the mean and stddev of
    # the two scalar priors.
    initial_loc = tf.stack(
        [initial_level_prior.mean(), initial_slope_prior.mean()], axis=-1)
    initial_scale_diag = tf.stack(
        [initial_level_prior.stddev(), initial_slope_prior.stddev()],
        axis=-1)
    self._initial_state_prior = tfd.MultivariateNormalDiag(
        loc=initial_loc, scale_diag=initial_scale_diag)

    # Both scale parameters share one bijector: softplus followed by scaling
    # with the observed stddev.
    scaled_softplus = tfb.Chain(
        [tfb.AffineScalar(scale=observed_stddev), tfb.Softplus()])
    scale_parameters = [
        Parameter('level_scale', level_scale_prior, scaled_softplus),
        Parameter('slope_scale', slope_scale_prior, scaled_softplus),
    ]
    super(LocalLinearTrend, self).__init__(
        parameters=scale_parameters, latent_size=2, name=name)
def test_forward_gradient(self):
  """Checks `math.fwd_gradient` on stacked powers of a length-2 vector."""

  def stacked_powers(v):
    return tf.stack([v, v**2, v**3], axis=0)  # Shape [3, 2]

  t = tf.range(1, 3, dtype=tf.float32)  # Shape [2]
  # Row k holds d(t**(k+1))/dt evaluated at t = [1, 2].
  expected = [[1., 1.], [2., 4.], [3., 12.]]

  fwd_grad = self.evaluate(math.fwd_gradient(stacked_powers, t))

  self.assertEqual(fwd_grad.shape, (3, 2))
  np.testing.assert_allclose(fwd_grad, expected)
def _batch_interp_with_gather_nd(x, x_ref_min, x_ref_max, y_ref, nd, fill_value,
                                 batch_dims):
  """N-D interpolation that works with leading batch dims.

  Each point in `x` is mapped to a fractional index into the `y_ref` table.
  The result is the sum, over the `2**nd` corners of the enclosing grid cell,
  of the reference value at that corner times the volume of the opposite
  sub-cell.

  Args:
    x: Points to interpolate at; the last axis has size `nd`.
    x_ref_min: Lower edge of the reference grid, per interpolation dim.
    x_ref_max: Upper edge of the reference grid, per interpolation dim.
    y_ref: Interpolation table; the `nd` axes after the first `batch_dims`
      axes index the grid.
    nd: Python integer, number of interpolation dimensions.
    fill_value: If a numeric tensor, the value written wherever `x` falls
      outside the reference grid in any dim; otherwise no filling is done here.
    batch_dims: Number of leading batch dims shared by `x` and `y_ref`.

  Returns:
    The interpolated values `y`.
  """
  dtype = x.dtype

  # In this function,
  # x.shape = [A1, ..., An, D, nd], where n = batch_dims
  # and
  # y_ref.shape = [A1, ..., An, C1, C2,..., Cnd, B1,...,BM]
  # y_ref[A1, ..., An, i1,...,ind] is a shape [B1,...,BM] Tensor with the value
  # at index [i1,...,ind] in the interpolation table.
  # x_ref_min and x_ref_max have shapes [A1, ..., An, nd].

  # ny[k] is number of y reference points in interp dim k.
  ny = tf.cast(tf.shape(y_ref)[batch_dims:batch_dims + nd], dtype)

  # Map [x_ref_min, x_ref_max] to [0, ny - 1].
  # This is the (fractional) index of x.
  # x_idx_unclipped[A1, ..., An, d, k] is the fractional index into dim k of
  # interpolation table for the dth x value.
  x_ref_min_expanded = tf.expand_dims(x_ref_min, axis=-2)
  x_ref_max_expanded = tf.expand_dims(x_ref_max, axis=-2)
  x_idx_unclipped = (ny - 1) * (x - x_ref_min_expanded) / (
      x_ref_max_expanded - x_ref_min_expanded)

  # Wherever x is NaN, x_idx_unclipped will be NaN as well.
  # Keep track of the nan indices here (so we can impute NaN later).
  # Also replace any NaN indices with 0, since the indices are cast to int32
  # below and int32 cannot represent NaN.
  nan_idx = tf.math.is_nan(x_idx_unclipped)
  x_idx_unclipped = tf.where(nan_idx, 0., x_idx_unclipped)

  # x_idx.shape = [A1, ..., An, D, nd]
  x_idx = tf.clip_by_value(x_idx_unclipped, tf.zeros((), dtype=dtype), ny - 1)

  # Get the index above and below x_idx.
  # Naively we could set idx_below = floor(x_idx), idx_above = ceil(x_idx),
  # however, this results in idx_below == idx_above whenever x is on a grid.
  # This in turn results in y_ref_below == y_ref_above, and then the gradient
  # at this point is zero.  So here we 'jitter' one of idx_below, idx_above,
  # so that they are at different values.  This jittering does not affect the
  # interpolated value, but does make the gradient nonzero (unless of course
  # the y_ref values are the same).
  idx_below = tf.floor(x_idx)
  idx_above = tf.minimum(idx_below + 1, ny - 1)
  idx_below = tf.maximum(idx_above - 1, 0)

  # These are the values of y_ref corresponding to above/below indices.
  # idx_below_int32.shape = x.shape[:-1] + [nd]
  idx_below_int32 = tf.cast(idx_below, dtype=tf.int32)
  idx_above_int32 = tf.cast(idx_above, dtype=tf.int32)

  # idx_below_list is a length nd list of shape x.shape[:-1] int32 tensors.
  idx_below_list = tf.unstack(idx_below_int32, axis=-1)
  idx_above_list = tf.unstack(idx_above_int32, axis=-1)

  # Use t to get a convex combination of the below/above values.
  # t.shape = [A1, ..., An, D, nd]
  t = x_idx - idx_below

  # x, and tensors shaped like x, need to be added to, and selected with
  # (using tf.where) the output y.  This requires appending singletons.
  def _expand_x_fn(tensor):
    # Reshape tensor to tensor.shape + [1] * M.
    extended_shape = tf.concat([
        tf.shape(tensor),
        tf.ones_like(tf.shape(y_ref)[batch_dims + nd:])
    ], axis=0)
    return tf.reshape(tensor, extended_shape)

  # Now, t.shape = [A1, ..., An, D, nd] + [1] * (rank(y_ref) - nd - batch_dims)
  t = _expand_x_fn(t)
  s = 1 - t

  # Re-insert NaN wherever x was NaN. Note that s was computed before this
  # step, so only t carries the NaN.
  nan_idx = _expand_x_fn(nan_idx)
  t = tf.where(nan_idx, tf.constant(np.nan, dtype), t)

  terms = []
  # Our work above has located x's fractional index inside a cube of above/below
  # indices. The distance to the below indices is t, and to the above indices
  # is s.
  # Drawing lines from x to the cube walls, we get 2**nd smaller cubes. Each
  # term in the result is a product of a reference point, gathered from y_ref,
  # multiplied by a volume.  The volume is that of the cube opposite to the
  # reference point.  E.g. if the reference point is below x in every axis, the
  # volume is that of the cube with corner above x in every axis, s[0]*...*s[nd]
  # We could probably do this with one massive gather, but that would be very
  # unreadable and un-debuggable.  It also would create a large Tensor.
  for zero_ones_list in _binary_count(nd):
    gather_from_y_ref_idx = []
    opposite_volume_t_idx = []
    opposite_volume_s_idx = []
    for k, zero_or_one in enumerate(zero_ones_list):
      if zero_or_one == 0:
        # If the kth iterate has zero_or_one = 0,
        # Will gather from the 'below' reference point along axis k.
        gather_from_y_ref_idx.append(idx_below_list[k])
        # Now append the index to gather for computing opposite_volume.
        # This could be done by initializing opposite_volume to 1, then here:
        #   opposite_volume *= tf.gather(s, indices=k, axis=tf.rank(x) - 1)
        # but that puts a gather in the 'inner loop.'  Better to append the
        # index and do one larger gather down below.
        opposite_volume_s_idx.append(k)
      else:
        gather_from_y_ref_idx.append(idx_above_list[k])
        # Append an index to gather, having the same effect as
        #   opposite_volume *= tf.gather(t, indices=k, axis=tf.rank(x) - 1)
        opposite_volume_t_idx.append(k)

    # Compute opposite_volume (volume of cube opposite the ref point):
    # Recall t.shape = s.shape = [D, nd] + [1, ..., 1]
    # Gather from t and s along the 'nd' axis, which is rank(x) - 1.
    ov_axis = tf.rank(x) - 1
    opposite_volume = (
        tf.reduce_prod(
            tf.gather(
                t, indices=tf.cast(opposite_volume_t_idx, dtype=tf.int32),
                axis=ov_axis),
            axis=ov_axis) *
        tf.reduce_prod(
            tf.gather(
                s, indices=tf.cast(opposite_volume_s_idx, dtype=tf.int32),
                axis=ov_axis),
            axis=ov_axis)
    )  # pyformat: disable

    # Reference values at this corner, one per x point.
    y_ref_pt = tf.gather_nd(
        y_ref, tf.stack(gather_from_y_ref_idx, axis=-1), batch_dims=batch_dims)

    terms.append(y_ref_pt * opposite_volume)

  y = tf.math.add_n(terms)

  if tf.debugging.is_numeric_tensor(fill_value):
    # Recall x_idx_unclipped.shape = [D, nd],
    # so here we check if it was out of bounds in any of the nd dims.
    # Thus, oob_idx.shape = [D].
    # NaN entries were replaced by 0 above, so they do not count as out of
    # bounds here.
    oob_idx = tf.reduce_any(
        (x_idx_unclipped < 0) | (x_idx_unclipped > ny - 1), axis=-1)

    # Now, y.shape = [D, B1,...,BM], so we'll have to broadcast oob_idx.

    oob_idx = _expand_x_fn(oob_idx)  # Shape [D, 1,...,1]
    # OR-ing with an all-False tensor of y's shape broadcasts the mask to y.
    oob_idx |= tf.fill(tf.shape(y), False)
    y = tf.where(oob_idx, fill_value, y)
  return y
def collater_fn(batch: Dict[Text, tf.Tensor]) -> Dict[Text, tf.Tensor]:
  """Collater function for relation classification task. See BaseTask.

  Selects up to `max_batch_mentions` mentions from the whole batch, always
  placing the subject / object mentions of every sample first, and builds the
  model inputs derived from that selection.

  Args:
    batch: Dictionary of batched features. Keys read here: 'text_ids',
      'text_mask', 'target', 'subject_mention_indices',
      'object_mention_indices', 'mention_mask', 'mention_start_positions' and
      'mention_end_positions'.

  Returns:
    A new dictionary with the text features, the sampled mention features,
    the target mention features, 'segment_ids' and 'position_ids'.
  """

  def flatten_bsz(tensor):
    # Reshape to a vector of length `bsz`.
    return tf.reshape(tensor, [bsz])

  new_batch = {
      'text_ids': batch['text_ids'],
      'text_mask': batch['text_mask'],
      'classifier_target': flatten_bsz(batch['target']),
  }

  # Sample mentions across batch

  # We want to make sure that the subject / object mentions always have
  # priority when we sample `max_batch_mentions` out of all available
  # mentions. Additionally, we want these subject / object mentions to be
  # in the same order as their samples. In other words, we want the first
  # sampled mention to be subject mention from the first sample, the second
  # sampled mention to be object mention from the first sample, the third
  # sampled mention to be subject mention from the second sample, etc.
  subj_index = flatten_bsz(batch['subject_mention_indices'])
  obj_index = flatten_bsz(batch['object_mention_indices'])

  # Adjust subject / object mention positions in individual samples to their
  # positions in flattened mentions.
  shift = tf.range(
      bsz, dtype=obj_index.dtype) * config.max_mentions_per_sample
  # Interleave as [subj_0, obj_0, subj_1, obj_1, ...].
  mention_target_indices = tf.reshape(
      tf.stack([subj_index + shift, obj_index + shift], axis=1), [-1])

  # Sample the rest of the mentions uniformly across batch
  scores = tf.random.uniform(shape=tf.shape(batch['mention_mask']))
  scores = scores * tf.cast(batch['mention_mask'], tf.float32)

  # We want to adjust scores for target mentions so they don't get sampled
  # for the second time. We achieve this by making their scores negative.
  def set_negative_scores(scores, indices):
    indices_2d = tf.stack(
        [tf.range(bsz, dtype=indices.dtype), indices], axis=1)
    return tf.tensor_scatter_nd_update(
        scores, indices_2d, tf.fill(tf.shape(indices), -1.0))

  # Note that since we're using 2D scores (not yet flattened for simplicity)
  # we use unadjusted `subj_index` and `obj_index`.
  scores = set_negative_scores(scores, subj_index)
  scores = set_negative_scores(scores, obj_index)

  # There are `2 * bsz` target mentions which were already chosen
  num_to_sample = tf.maximum(max_batch_mentions - 2 * bsz, 0)
  sampled_scores, sampled_indices = tf.math.top_k(
      tf.reshape(scores, [-1]), num_to_sample, sorted=True)

  # Note that negative scores indicate that we have double-sampled some of
  # the target mentions (we set their scores to negative right above).
  # In this case, we remove them. Since the scores are sorted in descending
  # order, those entries are at the end and can be sliced off.
  num_not_double_sampled = tf.reduce_sum(
      tf.cast(tf.not_equal(sampled_scores, -1), tf.int32))
  sampled_indices = sampled_indices[:num_not_double_sampled]

  # Combine target mentions (subject / object) with sampled mentions
  mention_target_indices = tf.cast(mention_target_indices,
                                   sampled_indices.dtype)
  sampled_indices = tf.concat(
      [mention_target_indices, sampled_indices], axis=0)

  sampled_indices = mention_preprocess_utils.dynamic_padding_1d(
      sampled_indices, max_batch_mentions)

  dtype = batch['mention_start_positions'].dtype
  mention_mask = tf.reshape(batch['mention_mask'], [n_candidate_mentions])
  new_batch['mention_mask'] = tf.gather(mention_mask, sampled_indices)
  new_batch['mention_start_positions'] = tf.gather(
      tf.reshape(batch['mention_start_positions'], [n_candidate_mentions]),
      sampled_indices)
  new_batch['mention_end_positions'] = tf.gather(
      tf.reshape(batch['mention_end_positions'], [n_candidate_mentions]),
      sampled_indices)
  # Index of the sample each flattened mention slot belongs to.
  new_batch['mention_batch_positions'] = tf.gather(
      tf.repeat(tf.range(bsz, dtype=dtype), config.max_mentions_per_sample),
      sampled_indices)

  # The targets occupy the first `2 * bsz` sampled slots: subjects at even
  # positions, objects at odd positions.
  new_batch['mention_target_indices'] = tf.range(2 * bsz, dtype=dtype)
  new_batch['mention_subject_indices'] = tf.range(bsz, dtype=dtype) * 2
  new_batch['mention_object_indices'] = tf.range(bsz, dtype=dtype) * 2 + 1

  if config.get('max_length_with_entity_tokens') is not None:
    batch_with_entity_tokens = mention_preprocess_utils.add_entity_tokens(
        text_ids=new_batch['text_ids'],
        text_mask=new_batch['text_mask'],
        mention_mask=new_batch['mention_mask'],
        mention_batch_positions=new_batch['mention_batch_positions'],
        mention_start_positions=new_batch['mention_start_positions'],
        mention_end_positions=new_batch['mention_end_positions'],
        new_length=config.max_length_with_entity_tokens,
    )
    # Update `text_ids`, `text_mask`, `mention_mask`, `mention_*_positions`
    new_batch.update(batch_with_entity_tokens)
    # Update `max_length`
    max_length = config.max_length_with_entity_tokens
  else:
    max_length = encoder_config.max_length

  new_batch['mention_target_batch_positions'] = tf.gather(
      new_batch['mention_batch_positions'],
      new_batch['mention_target_indices'])
  new_batch['mention_target_start_positions'] = tf.gather(
      new_batch['mention_start_positions'],
      new_batch['mention_target_indices'])
  new_batch['mention_target_end_positions'] = tf.gather(
      new_batch['mention_end_positions'],
      new_batch['mention_target_indices'])
  new_batch['mention_target_weights'] = tf.ones(2 * bsz)

  # Fake IDs -- some encoders (ReadTwice) need them
  new_batch['mention_target_ids'] = tf.zeros(2 * bsz)

  new_batch['segment_ids'] = tf.zeros_like(new_batch['text_ids'])

  # Every sample gets positions 0..max_length-1.
  position_ids = tf.expand_dims(tf.range(max_length), axis=0)
  new_batch['position_ids'] = tf.tile(position_ids, (bsz, 1))

  return new_batch
def _event_shape_tensor(self):
  """Returns a square event shape built from the scale operator's domain dimension."""
  dim = self.scale_operator.domain_dimension_tensor()
  return tf.stack([dim] * 2)
def concatenate_batch_into_sample(batch):
  """Flattens every feature of `batch` into a single row of shape [1, -1].

  The dictionary is modified in place and also returned.

  Args:
    batch: Dictionary mapping feature names to tensors.

  Returns:
    The same dictionary, with every value reshaped to `[1, -1]`.
  """
  for feature in batch.keys():
    batch[feature] = tf.reshape(batch[feature], [1, -1])
  return batch


# Turn every batch of the dataset into one long example.
for batch in dataset:
  concatenated_examples.append(concatenate_batch_into_sample(batch))

feature_dict = {}
for feature in concatenated_examples[0].keys():
  feature_list = [example[feature] for example in concatenated_examples]
  # Each entry is [1, length], so the stack is [num_examples, 1, length].
  # Squeeze only the unit axis added by the reshape: a bare `tf.squeeze` would
  # also drop the example axis when there is exactly one example, which breaks
  # the `[di, :]` indexing and `from_tensor_slices` below.
  feature_dict[feature] = tf.squeeze(tf.stack(feature_list, axis=0), axis=1)

# Convert to NumPy so rows can be overwritten in place.
feature_dict["f0_hz"] = feature_dict["f0_hz"].numpy()
if INTONATION:
  for di in range(feature_dict["f0_hz"].shape[0]):
    feature_dict["f0_hz"][di, :] = intonate(feature_dict["f0_hz"][di, :])

dataset = tf.data.Dataset.from_tensor_slices(feature_dict)
ex = next(iter(dataset))
# Sanity check: every concatenated example holds 16000 * 16 audio values.
assert ex["audio"].shape[0] == 16000*16
def call(self, inputs):
  """Merges a list of inputs, aligning their ranks first when required.

  Args:
    inputs: A list or tuple of tensors to merge.

  Returns:
    The result of `self._merge_function`. When inputs had to be moved to a
    batch-last layout for the merge, the result is moved back to batch-first.

  Raises:
    ValueError: If `inputs` is not a list or tuple.
  """
  if not isinstance(inputs, (list, tuple)):
    raise ValueError(
        'A merge layer should be called on a list of inputs. '
        f'Received: inputs={inputs} (not a list of tensors)')

  # Nothing to align: merge the inputs exactly as given.
  if not self._reshape_required:
    return self._merge_function(inputs)

  aligned = []
  ranks = [backend.ndim(tensor) for tensor in inputs]

  if None not in ranks:
    # All ranks are known: insert size-1 axes at position 1 into the
    # lower-rank tensors until every tensor has the highest rank.
    target_rank = max(ranks)
    for tensor in inputs:
      missing_axes = target_rank - backend.ndim(tensor)
      for _ in range(missing_axes):
        tensor = tf.expand_dims(tensor, axis=1)
      aligned.append(tensor)
    return self._merge_function(aligned)

  # Some rank is unknown: move the batch axis to the end of every input,
  # (batch_size, dim1, dim2, ...) -> (dim1, dim2, ..., batch_size).
  any_transposed = False
  for tensor in inputs:
    rank = backend.ndim(tensor)
    if rank is None:
      shape = tf.shape(tensor)
      batch = shape[0]
      batch_last_shape = backend.concatenate(
          [shape[1:], tf.expand_dims(batch, axis=-1)])
      # Collapse to 2-D, swap the two axes, then restore the other dims.
      flat = tf.reshape(
          tensor, tf.stack([batch, tf.reduce_prod(shape[1:])], axis=0))
      flat = tf.transpose(flat, perm=(1, 0))
      aligned.append(tf.reshape(flat, batch_last_shape))
      any_transposed = True
    elif rank > 1:
      aligned.append(tf.transpose(tensor, perm=[*range(1, rank), 0]))
      any_transposed = True
    else:
      # 1D vectors and scalars are passed through without a transpose.
      aligned.append(tensor)

  merged = self._merge_function(aligned)
  merged_rank = backend.ndim(merged)
  if not any_transposed:
    return merged

  # The inputs were transposed, so the batch axis of the output has to be
  # moved back to the front.
  if merged_rank is None:
    merged_shape = tf.shape(merged)
    merged_rank = tf.shape(merged_shape)[0]
    batch = merged_shape[merged_rank - 1]
    batch_first_shape = backend.concatenate(
        [tf.expand_dims(batch, axis=-1), merged_shape[:merged_rank - 1]])
    merged = tf.reshape(merged, (-1, batch))
    merged = tf.transpose(merged, perm=(1, 0))
    merged = tf.reshape(merged, batch_first_shape)
  elif merged_rank > 1:
    merged = tf.transpose(
        merged, perm=[merged_rank - 1, *range(merged_rank - 1)])
  return merged
def solve_nu_zeta(self,
                  dataset: dataset_lib.OffpolicyDataset,
                  target_policy: tf_policy.TFPolicy,
                  regularizer: float = 1e-6):
  """Runs one update of the two `nu`/`zeta` solutions and of `alpha`.

  On the first call (detected by the absence of `self._td_mat`) the whole
  dataset is read, the tabular feature vectors, TD matrix and bias are built,
  and the results are cached on `self`. Every call then:

    1. evaluates a per-sample saddle loss from the current `nu`/`zeta` pairs,
    2. bisects `self._alpha` for 16 iterations against
       `self._two_sided_limit`,
    3. takes a damped (factor 0.1) Newton-style step on `self._nu` and
       `self._nu2`, and a damped step on `self._zeta` and `self._zeta2`.

  Side effects: reassigns `self._nu`, `self._nu2`, `self._zeta`,
  `self._zeta2`, assigns `self._alpha`, writes `tf.summary` scalars and
  prints progress to stdout.

  Args:
    dataset: The dataset to sample experience from. Only read on the first
      call.
    target_policy: The policy whose value we want to estimate. Only used on
      the first call.
    regularizer: A small constant; `regularizer * I` is added to each Hessian
      before the linear solve.

  Returns:
    A list `[avg_saddle_loss * sign, avg_loss * sign, divergence]`, where
    `sign` is `self._algae_alpha_sign` and `divergence` comes from
    `self._compute_divergence`.
  """

  if not hasattr(self, '_td_mat'):
    # Set up env_steps.
    episodes, valid_steps = dataset.get_all_episodes(
        limit=self._limit_episodes)
    total_num_steps_per_episode = tf.shape(valid_steps)[1] - 1
    num_episodes = tf.shape(valid_steps)[0]
    num_samples = num_episodes * total_num_steps_per_episode
    valid_and_not_last = tf.logical_and(valid_steps, episodes.discount > 0)
    # Flat indices of the (episode, step) pairs that are kept; the last step
    # of every episode is dropped by the `[:, :-1]` slice.
    valid_indices = tf.squeeze(
        tf.where(tf.reshape(valid_and_not_last[:, :-1], [-1])))

    # First step of each episode, repeated once per step and flattened so it
    # lines up with `env_step` below.
    initial_env_step = tf.nest.map_structure(
        lambda t: tf.squeeze(
            tf.reshape(
                tf.repeat(
                    t[:, 0:1, ...],
                    axis=1,
                    repeats=total_num_steps_per_episode), [num_samples, -1])),
        episodes)
    initial_env_step = tf.nest.map_structure(
        lambda t: tf.gather(t, valid_indices), initial_env_step)
    tfagents_initial_env_step = dataset_lib.convert_to_tfagents_timestep(
        initial_env_step)

    # Steps 0..T-1 of every episode, flattened and filtered.
    env_step = tf.nest.map_structure(
        lambda t: tf.squeeze(
            tf.reshape(t[:, 0:total_num_steps_per_episode, ...],
                       [num_samples, -1])), episodes)
    env_step = tf.nest.map_structure(lambda t: tf.gather(t, valid_indices),
                                     env_step)
    tfagents_env_step = dataset_lib.convert_to_tfagents_timestep(env_step)

    # Steps 1..T of every episode (the successors of `env_step`).
    next_env_step = tf.nest.map_structure(
        lambda t: tf.squeeze(
            tf.reshape(t[:, 1:total_num_steps_per_episode + 1, ...],
                       [num_samples, -1])), episodes)
    next_env_step = tf.nest.map_structure(
        lambda t: tf.gather(t, valid_indices), next_env_step)
    tfagents_next_env_step = dataset_lib.convert_to_tfagents_timestep(
        next_env_step)

    # get probabilities
    initial_target_probs = target_policy.distribution(
        tfagents_initial_env_step).action.probs_parameter()
    next_target_probs = target_policy.distribution(
        tfagents_next_env_step).action.probs_parameter()

    # First, get the nu_loss and data weights
    #current_nu_loss = self._get_nu_loss(initial_env_step, env_step,
    #                                    next_env_step, target_policy)
    #data_weight, _ = self._get_weights(current_nu_loss)

    # # debug only and to reproduce dual dice result, DELETE
    # data_weight = tf.ones_like(data_weight)

    state_action_count = self._get_state_action_counts(env_step)
    # NOTE(review): `counts` is not read again in this function.
    counts = tf.reduce_sum(tf.one_hot(state_action_count, self._dimension), 0)
    gamma_sample = tf.pow(self._gamma, tf.cast(env_step.step_num, tf.float32))

    # # debug only and to reproduce dual dice result, DELETE
    # gamma_sample = tf.ones_like(gamma_sample)

    # now we need to expand_dims to include action space in extra dimensions
    #data_weights = tf.reshape(data_weight, [-1, self._num_limits])
    # both are data sample weights for L2 problem, needs to be normalized later
    #gamma_data_weights = tf.reshape(gamma_sample, [-1, 1]) * data_weights

    # Pair every initial observation with every action index.
    initial_states = tf.tile(
        tf.reshape(initial_env_step.observation, [-1, 1]),
        [1, self._num_actions])
    initial_actions = tf.tile(
        tf.reshape(tf.range(self._num_actions), [1, -1]),
        [initial_env_step.observation.shape[0], 1])
    initial_nu_indices = self._get_index(initial_states, initial_actions)

    # linear term w.r.t. initial distribution
    #b_vec_2 = tf.stack([
    #    tf.reduce_sum(
    #        tf.reshape(
    #            data_weights[:, itr] / tf.reduce_sum(data_weights[:, itr]),
    #            [-1, 1]) * tf.reduce_sum(
    #                tf.one_hot(initial_nu_indices, self._dimension) *
    #                (1 - self._gamma) *
    #                tf.expand_dims(initial_target_probs, axis=-1),
    #                axis=1),
    #        axis=0) for itr in range(self._num_limits)
    #],
    #                   axis=0)

    # Pair every next observation with every action index.
    next_states = tf.tile(
        tf.reshape(next_env_step.observation, [-1, 1]),
        [1, self._num_actions])
    next_actions = tf.tile(
        tf.reshape(tf.range(self._num_actions), [1, -1]),
        [next_env_step.observation.shape[0], 1])
    next_nu_indices = self._get_index(next_states, next_actions)
    # Absorbing next steps get index -1, for which `tf.one_hot` below yields
    # an all-zero row.
    next_nu_indices = tf.where(
        tf.expand_dims(next_env_step.is_absorbing(), -1),
        -1 * tf.ones_like(next_nu_indices), next_nu_indices)

    nu_indices = self._get_index(env_step.observation, env_step.action)

    target_log_probabilities = target_policy.distribution(
        tfagents_env_step).action.log_prob(env_step.action)
    if not self._solve_for_state_action_ratio:
      policy_ratio = tf.exp(target_log_probabilities -
                            env_step.get_log_probability())
    else:
      policy_ratio = tf.ones([
          target_log_probabilities.shape[0],
      ])
    policy_ratios = tf.tile(
        tf.reshape(policy_ratio, [-1, 1]), [1, self._num_actions])

    # the tabular feature vector
    a_vec = tf.one_hot(nu_indices, self._dimension) - tf.reduce_sum(
        self._gamma *
        tf.expand_dims(next_target_probs * policy_ratios, axis=-1) *
        tf.one_hot(next_nu_indices, self._dimension),
        axis=1)

    # linear term w.r.t. reward
    #b_vec_1 = tf.stack([
    #    tf.reduce_sum(
    #        tf.reshape(
    #            (gamma_data_weights[:, itr] /
    #             tf.reduce_sum(gamma_data_weights[:, itr])) * self._reward_fn(env_step), #/
    #            #tf.cast(state_action_count, tf.float32),
    #            [-1, 1]) * a_vec,
    #        axis=0) for itr in range(self._num_limits)
    #],
    #                   axis=0)

    # quadratic term of feature
    # Get weighted outer product by using einsum to save computing resource!
    #a_mat = tf.stack([
    #    tf.einsum(
    #        'ai, a, aj -> ij', a_vec,
    #        #1.0 / tf.cast(state_action_count, tf.float32),
    #        gamma_data_weights[:, itr] /
    #        tf.reduce_sum(gamma_data_weights[:, itr]),
    #        a_vec)
    #    for itr in range(self._num_limits)
    #],
    #                 axis=0)

    td_mat = tf.einsum('ai, a, aj -> ij',
                       tf.one_hot(nu_indices, self._dimension),
                       1.0 / tf.cast(state_action_count, tf.float32), a_vec)

    weighted_rewards = policy_ratio * self._reward_fn(env_step)

    bias = tf.reduce_sum(
        tf.one_hot(nu_indices, self._dimension) *
        tf.reshape(weighted_rewards, [-1, 1]) * 1.0 /
        tf.cast(state_action_count, tf.float32)[:, None],
        axis=0)

    # Initialize
    self._nu = np.ones_like(self._nu) * bias[:, None]
    self._nu2 = np.ones_like(self._nu2) * bias[:, None]

    # Cache everything the per-call update below needs.
    self._a_vec = a_vec
    self._td_mat = td_mat
    self._bias = bias
    self._weighted_rewards = weighted_rewards
    self._state_action_count = state_action_count
    self._nu_indices = nu_indices
    self._initial_nu_indices = initial_nu_indices
    self._initial_target_probs = initial_target_probs
    self._gamma_sample = gamma_sample
    # The next line overwrites the discount weights with ones, so the
    # assignment just above has no lasting effect.
    self._gamma_sample = tf.ones_like(gamma_sample)

  # Per-sample saddle loss for the first (nu, zeta) pair.
  saddle_bellman_residuals = (
      tf.matmul(self._a_vec, self._nu) - self._weighted_rewards[:, None])
  saddle_bellman_residuals *= -1 * self._algae_alpha_sign
  saddle_zetas = tf.gather(self._zeta, self._nu_indices)
  saddle_initial_nu_values = tf.reduce_sum(  # Average over actions.
      self._initial_target_probs[:, :, None] *
      tf.gather(self._nu, self._initial_nu_indices),
      axis=1)
  saddle_init_nu_loss = ((1 - self._gamma) * saddle_initial_nu_values *
                         self._algae_alpha_sign)

  # Same quantities for the second pair, with the opposite sign convention.
  saddle_bellman_residuals2 = (
      tf.matmul(self._a_vec, self._nu2) - self._weighted_rewards[:, None])
  saddle_bellman_residuals2 *= 1 * self._algae_alpha_sign
  saddle_zetas2 = tf.gather(self._zeta2, self._nu_indices)
  saddle_initial_nu_values2 = tf.reduce_sum(  # Average over actions.
      self._initial_target_probs[:, :, None] *
      tf.gather(self._nu2, self._initial_nu_indices),
      axis=1)
  saddle_init_nu_loss2 = ((1 - self._gamma) * saddle_initial_nu_values2 * -1 *
                          self._algae_alpha_sign)

  saddle_loss = 0.5 * (
      saddle_init_nu_loss + saddle_bellman_residuals * saddle_zetas +
      -tf.math.abs(self._algae_alpha) * 0.5 * tf.square(saddle_zetas) +
      -saddle_init_nu_loss2 + -saddle_bellman_residuals2 * saddle_zetas2 +
      tf.math.abs(self._algae_alpha) * 0.5 * tf.square(saddle_zetas2))

  # Binary search to find best alpha.
  # NOTE(review): the two-element bounds presumably assume
  # `self._num_limits == 2` -- confirm.
  left = tf.constant([-8., -8.])
  right = tf.constant([32., 32.])
  for _ in range(16):
    mid = 0.5 * (left + right)
    self._alpha.assign(mid)
    weights, log_weights = self._get_weights(saddle_loss *
                                             self._gamma_sample[:, None])

    divergence = self._compute_divergence(weights, log_weights)
    divergence_violation = divergence - self._two_sided_limit
    # Raise the lower bound where the divergence limit is exceeded, otherwise
    # lower the upper bound.
    left = tf.where(divergence_violation > 0., mid, left)
    right = tf.where(divergence_violation > 0., right, mid)
  self._alpha.assign(0.5 * (left + right))
  weights, log_weights = self._get_weights(saddle_loss *
                                           self._gamma_sample[:, None])

  gamma_data_weights = tf.stop_gradient(weights * self._gamma_sample[:, None])
  #print(tf.concat([gamma_data_weights, saddle_loss], axis=-1))
  avg_saddle_loss = (
      tf.reduce_sum(gamma_data_weights * saddle_loss, axis=0) /
      tf.reduce_sum(gamma_data_weights, axis=0))

  # Sum of sample weights per nu index, gathered back to one row per sample.
  weighted_state_action_count = tf.reduce_sum(
      tf.one_hot(self._nu_indices, self._dimension)[:, :, None] *
      weights[:, None, :],
      axis=0)
  weighted_state_action_count = tf.gather(weighted_state_action_count,
                                          self._nu_indices)
  my_td_mat = tf.einsum(
      'ai, ab, ab, aj -> bij',
      tf.one_hot(self._nu_indices, self._dimension),
      #1.0 / tf.cast(self._state_action_count, tf.float32),
      1.0 / weighted_state_action_count,
      weights,
      self._a_vec)
  my_bias = tf.reduce_sum(
      tf.transpose(weights)[:, :, None] *
      tf.one_hot(self._nu_indices, self._dimension)[None, :, :] *
      tf.reshape(self._weighted_rewards, [1, -1, 1]) *
      #1.0 / tf.cast(self._state_action_count, tf.float32)[None, :, None],
      1.0 / tf.transpose(weighted_state_action_count)[:, :, None],
      axis=1)

  #print('hello', saddle_initial_nu_values[:1], saddle_zetas[:3],
  #      self._nu[:2], my_bias[:, :2], saddle_loss[:4])

  # `self._nu` / `self._nu2` are plain tensors here, hence the explicit
  # `tape.watch`. The first-order gradients are taken inside the persistent
  # tape so that `tape.jacobian` can differentiate them again afterwards.
  with tf.GradientTape(
      watch_accessed_variables=False, persistent=True) as tape:
    tape.watch([self._nu, self._nu2, self._alpha])
    bellman_residuals = tf.matmul(
        my_td_mat,
        tf.transpose(self._nu)[:, :, None]) - my_bias[:, :, None]
    bellman_residuals = tf.transpose(tf.squeeze(bellman_residuals, -1))
    bellman_residuals = tf.gather(bellman_residuals, self._nu_indices)
    initial_nu_values = tf.reduce_sum(  # Average over actions.
        self._initial_target_probs[:, :, None] *
        tf.gather(self._nu, self._initial_nu_indices),
        axis=1)

    bellman_residuals *= self._algae_alpha_sign

    init_nu_loss = ((1 - self._gamma) * initial_nu_values *
                    self._algae_alpha_sign)

    nu_loss = (
        tf.math.square(bellman_residuals) / 2.0 +
        tf.math.abs(self._algae_alpha) * init_nu_loss)

    loss = (
        gamma_data_weights * nu_loss /
        tf.reduce_sum(gamma_data_weights, axis=0, keepdims=True))

    bellman_residuals2 = tf.matmul(
        my_td_mat,
        tf.transpose(self._nu2)[:, :, None]) - my_bias[:, :, None]
    bellman_residuals2 = tf.transpose(tf.squeeze(bellman_residuals2, -1))
    bellman_residuals2 = tf.gather(bellman_residuals2, self._nu_indices)
    initial_nu_values2 = tf.reduce_sum(  # Average over actions.
        self._initial_target_probs[:, :, None] *
        tf.gather(self._nu2, self._initial_nu_indices),
        axis=1)

    bellman_residuals2 *= -1 * self._algae_alpha_sign

    init_nu_loss2 = ((1 - self._gamma) * initial_nu_values2 * -1 *
                     self._algae_alpha_sign)

    nu_loss2 = (
        tf.math.square(bellman_residuals2) / 2.0 +
        tf.math.abs(self._algae_alpha) * init_nu_loss2)

    loss2 = (
        gamma_data_weights * nu_loss2 /
        tf.reduce_sum(gamma_data_weights, axis=0, keepdims=True))

    divergence = self._compute_divergence(weights, log_weights)
    divergence_violation = divergence - self._two_sided_limit

    # NOTE(review): `alpha_loss` is only referenced by the commented-out
    # optimizer code further down.
    alpha_loss = (-tf.exp(self._alpha) *
                  tf.stop_gradient(divergence_violation))

    # Extra quadratic penalty on the last row of each nu table.
    extra_loss = tf.reduce_sum(tf.math.square(self._nu[-1, :]))
    extra_loss2 = tf.reduce_sum(tf.math.square(self._nu2[-1, :]))
    nu_grad = tape.gradient(loss + extra_loss, [self._nu])[0]
    nu_grad2 = tape.gradient(loss2 + extra_loss2, [self._nu2])[0]
  avg_loss = tf.reduce_sum(
      0.5 * (loss - loss2) / tf.math.abs(self._algae_alpha), axis=0)
  # Keep, for each limit index i, only the block of the Jacobian that couples
  # column i of the gradient with column i of nu.
  nu_jacob = tape.jacobian(nu_grad, [self._nu])[0]
  nu_hess = tf.stack([nu_jacob[:, i, :, i] for i in range(self._num_limits)],
                     axis=0)

  nu_jacob2 = tape.jacobian(nu_grad2, [self._nu2])[0]
  nu_hess2 = tf.stack(
      [nu_jacob2[:, i, :, i] for i in range(self._num_limits)], axis=0)

  for idx, div in enumerate(divergence):
    tf.summary.scalar('divergence%d' % idx, div)

  #alpha_grads = tape.gradient(alpha_loss, [self._alpha])
  #alpha_grad_op = self._alpha_optimizer.apply_gradients(
  #    zip(alpha_grads, [self._alpha]))
  #self._alpha.assign(tf.minimum(8., tf.maximum(-8., self._alpha)))

  #print(self._alpha, tf.concat([weights, nu_loss], -1))
  #regularizer = 0.1
  # Damped Newton-style step: solve (H + regularizer * I) d = -grad, then
  # move by 0.1 * d.
  nu_transformed = tf.transpose(
      tf.squeeze(
          tf.linalg.solve(nu_hess + regularizer * tf.eye(self._dimension),
                          tf.expand_dims(-tf.transpose(nu_grad), axis=-1))))
  self._nu = self._nu + 0.1 * nu_transformed
  nu_transformed2 = tf.transpose(
      tf.squeeze(
          tf.linalg.solve(nu_hess2 + regularizer * tf.eye(self._dimension),
                          tf.expand_dims(-tf.transpose(nu_grad2), axis=-1))))
  self._nu2 = self._nu2 + 0.1 * nu_transformed2

  print(avg_loss * self._algae_alpha_sign,
        avg_saddle_loss * self._algae_alpha_sign, self._nu[:2], divergence)
  #print(init_nu_loss[:8], init_nu_loss[-8:])
  #print(bellman_residuals[:8])
  #print(self._nu[:3], self._zeta[:3])

  # Recompute the zeta targets from the updated nu tables and move each zeta
  # table 10% of the way towards its target.
  zetas = tf.matmul(my_td_mat,
                    tf.transpose(self._nu)[:, :, None]) - my_bias[:, :, None]
  zetas = tf.transpose(tf.squeeze(zetas, -1))
  zetas *= -self._algae_alpha_sign
  zetas /= tf.math.abs(self._algae_alpha)
  self._zeta = self._zeta + 0.1 * (zetas - self._zeta)

  zetas2 = tf.matmul(my_td_mat,
                     tf.transpose(self._nu2)[:, :, None]) - my_bias[:, :, None]
  zetas2 = tf.transpose(tf.squeeze(zetas2, -1))
  zetas2 *= 1 * self._algae_alpha_sign
  zetas2 /= tf.math.abs(self._algae_alpha)
  self._zeta2 = self._zeta2 + 0.1 * (zetas2 - self._zeta2)

  #self._zeta = (
  #    tf.einsum('ij,ja-> ia', self._td_mat, self._nu) -
  #    tf.transpose(my_bias))
  #self._zeta *= -tf.reshape(self._algae_alpha_sign, [1, self._num_limits])
  #self._zeta /= tf.math.abs(self._algae_alpha)
  return [
      avg_saddle_loss * self._algae_alpha_sign,
      avg_loss * self._algae_alpha_sign, divergence
  ]
def per_replica_to_tensor(value):
  """Stacks the `.values` of every leaf of `value` along a new axis 0.

  Args:
    value: A (possibly nested) structure whose leaves expose a `.values`
      sequence of tensors.

  Returns:
    A structure with the same nesting as `value`, each leaf replaced by
    `tf.stack(leaf.values, axis=0)`.
  """

  def _stack_leaf(per_replica):
    return tf.stack(per_replica.values, axis=0)

  return tf.nest.map_structure(_stack_leaf, value)
def barrier_price(*,
                  volatilities: types.RealTensor,
                  strikes: types.RealTensor,
                  expiries: types.RealTensor,
                  spots: types.RealTensor,
                  barriers: types.RealTensor,
                  rebates: types.RealTensor = None,
                  discount_rates: types.RealTensor = None,
                  dividend_rates: types.RealTensor = None,
                  is_barrier_down: types.BoolTensor = None,
                  is_knock_out: types.BoolTensor = None,
                  is_call_options: types.BoolTensor = None,
                  dtype: tf.DType = None,
                  name: str = None) -> types.RealTensor:
  """Prices barrier options in a Black-Scholes Model.

  Computes the prices of options with a single barrier in Black-Scholes world as
  described in Ref. [1]. Note that the barrier is applied continuously.

  All arguments are keyword-only.

  #### Example

  This example is taken from Ref. [2], Page 154.

  ```python
  import numpy as np
  import tf_quant_finance as tff

  dtype = np.float32
  discount_rates = np.array([.08, .08])
  dividend_rates = np.array([.04, .04])
  spots = np.array([100., 100.])
  strikes = np.array([90., 90.])
  barriers = np.array([95., 95.])
  rebates = np.array([3., 3.])
  volatilities = np.array([.25, .25])
  expiries = np.array([.5, .5])
  # Mask-row indices the three flags below map to (informational only, not an
  # argument of `barrier_price`).
  barriers_type = np.array([5, 1])
  is_barrier_down = np.array([True, False])
  is_knock_out = np.array([False, False])
  is_call_options = np.array([True, True])

  price = tff.black_scholes.barrier_price(
      volatilities=volatilities,
      strikes=strikes,
      expiries=expiries,
      spots=spots,
      barriers=barriers,
      rebates=rebates,
      discount_rates=discount_rates,
      dividend_rates=dividend_rates,
      is_barrier_down=is_barrier_down,
      is_knock_out=is_knock_out,
      is_call_options=is_call_options)

  # Expected output, as quoted from Ref. [2]
  # (NOTE(review): confirm these values match the flag combination above)
  #  `Tensor` with values [9.024, 7.7627]
  ```

  #### References

  [1]: Lee Clewlow, Javier Llanos, Chris Strickland, Caracas Venezuela
    Pricing Exotic Options in a Black-Scholes World, 1994
    https://warwick.ac.uk/fac/soc/wbs/subjects/finance/research/wpaperseries/1994/94-54.pdf
  [2]: Espen Gaarder Haug, The Complete Guide to Option Pricing Formulas,
    2nd Edition, 1997

  Args:
    volatilities: Real `Tensor` of any shape and dtype. The volatilities to
      expiry of the options to price.
    strikes: A real `Tensor` of the same dtype and compatible shape as
      `volatilities`. The strikes of the options to be priced.
    expiries: A real `Tensor` of same dtype and compatible shape as
      `volatilities`. The expiry of each option. The units should be such that
      `expiry * volatility**2` is dimensionless.
    spots: A real `Tensor` of any shape that broadcasts to the shape of the
      `volatilities`. The current spot price of the underlying.
    barriers: A real `Tensor` of same dtype as the `volatilities` and of the
      shape that broadcasts with `volatilities`. The barriers of each option.
    rebates: A real `Tensor` of same dtype as the `volatilities` and of the
      shape that broadcasts with `volatilities`. For knockouts, this is a
      fixed cash payout in case the barrier is breached. For knockins, this is a
      fixed cash payout in case the barrier level is not breached. In the former
      case, the rebate is paid immediately on breach whereas in the latter, the
      rebate is paid at the expiry of the option.
      Default value: `None` which maps to no rebates.
    discount_rates: A real `Tensor` of same dtype as the
      `volatilities` and of the shape that broadcasts with `volatilities`.
      Discount rates, or risk free rates.
      Default value: `None`, equivalent to discount_rate = 0.
    dividend_rates: A real `Tensor` of same dtype as the
      `volatilities` and of the shape that broadcasts with `volatilities`. A
      continuous dividend rate paid by the underlier. If `None`, then
      defaults to zero dividends.
      Default value: `None`, equivalent to zero dividends.
    is_barrier_down: A boolean `Tensor` of the shape that broadcasts with
      `volatilities`. True if barrier is below asset price at expiration.
      Default value: `True`.
    is_knock_out: A boolean `Tensor` of the shape that broadcasts with
      `volatilities`. True if option is knock out else false.
      Default value: `True`.
    is_call_options: A boolean `Tensor` of the shape that broadcasts with
      `volatilities`. True if option is call else false.
      Default value: `True`.
    dtype: Optional `tf.DType`. If supplied, the dtype to be used for conversion
      of any supplied non-`Tensor` arguments to `Tensor`.
      Default value: `None` which maps to the default dtype inferred by
      TensorFlow.
    name: str. The name for the ops created by this function.
      Default value: `None` which is mapped to the default name `barrier_price`.

  Returns:
    option_prices: A `Tensor` of same shape as `spots`. The approximate price of
    the barriers option under black scholes.
  """
  # The computation is done as in Ref [2] where each integral is split into
  # two matrices. The first matrix contains the algebraic terms and the second
  # matrix contains the probability distribution terms. Masks are used to filter
  # appropriate terms for calculating the integral. Then a dot product of each
  # row in the matrices coupled with the masks work to calculate the prices of
  # the barriers option.
  with tf.name_scope(name or 'barrier_price'):
    # `spots` is converted first; its dtype is then used for everything else.
    spots = tf.convert_to_tensor(spots, dtype=dtype, name='spots')
    dtype = spots.dtype
    strikes = tf.convert_to_tensor(strikes, dtype=dtype, name='strikes')
    volatilities = tf.convert_to_tensor(
        volatilities, dtype=dtype, name='volatilities')
    expiries = tf.convert_to_tensor(expiries, dtype=dtype, name='expiries')
    barriers = tf.convert_to_tensor(barriers, dtype=dtype, name='barriers')
    if rebates is not None:
      rebates = tf.convert_to_tensor(rebates, dtype=dtype, name='rebates')
    else:
      rebates = tf.zeros_like(spots, dtype=dtype, name='rebates')

    # Convert all to tensor and enforce float dtype where required
    if discount_rates is not None:
      discount_rates = tf.convert_to_tensor(
          discount_rates, dtype=dtype, name='discount_rates')
    else:
      discount_rates = tf.convert_to_tensor(
          0.0, dtype=dtype, name='discount_rates')

    if dividend_rates is not None:
      dividend_rates = tf.convert_to_tensor(
          dividend_rates, dtype=dtype, name='dividend_rates')
    else:
      dividend_rates = tf.convert_to_tensor(
          0.0, dtype=dtype, name='dividend_rates')

    # The three boolean flags are turned into 0/1 integers so that they can be
    # combined into a row index for the mask tables below. A missing flag
    # defaults to 1 (i.e. True).
    if is_barrier_down is None:
      is_barrier_down = tf.constant(1, name='is_barrier_down')
    else:
      is_barrier_down = tf.convert_to_tensor(is_barrier_down, dtype=tf.bool,
                                             name='is_barrier_down')
      is_barrier_down = tf.where(is_barrier_down, 1, 0)
    if is_knock_out is None:
      is_knock_out = tf.constant(1, name='is_knock_out')
    else:
      is_knock_out = tf.convert_to_tensor(is_knock_out, dtype=tf.bool,
                                          name='is_knock_out')
      is_knock_out = tf.where(is_knock_out, 1, 0)
    if is_call_options is None:
      is_call_options = tf.constant(1, name='is_call_options')
    else:
      is_call_options = tf.convert_to_tensor(is_call_options, dtype=tf.bool,
                                             name='is_call_options')
      is_call_options = tf.where(is_call_options, 1, 0)

    # Indices which range from 0-7 are used to select the appropriate
    # mask for each barrier
    # (bit 2: barrier down, bit 1: knock out, bit 0: call).
    indices = tf.bitwise.left_shift(
        is_barrier_down, 2) + tf.bitwise.left_shift(
            is_knock_out, 1) + is_call_options

    # Masks select the appropriate terms for integral approximations
    # Integrals are separated by algebraic terms and probability
    # distribution terms. This gives 12 different terms per matrix
    # (6 integrals, 2 terms each)
    # shape = [8, 12]
    mask_matrix_greater_strike = tf.constant([
        [1, 1, -1, -1, 0, 0, 1, 1, 1, 1, 0, 0],  # up and in put
        [1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],  # up and in call
        [0, 0, 1, 1, 0, 0, -1, -1, 0, 0, 1, 1],  # up and out put
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1],  # up and out call
        [0, 0, 1, 1, -1, -1, 1, 1, 0, 0, 1, 1],  # down and in put
        [0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0],  # down and in call
        [1, 1, -1, -1, 1, 1, -1, -1, 0, 0, 1, 1],  # down and out put
        [1, 1, 0, 0, -1, -1, 0, 0, 0, 0, 1, 1]])  # down and out call

    mask_matrix_lower_strike = tf.constant([
        [0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0],  # up and in put
        [0, 0, 1, 1, -1, -1, 1, 1, 1, 1, 0, 0],  # up and in call
        [1, 1, 0, 0, -1, -1, 0, 0, 0, 0, 1, 1],  # up and out put
        [1, 1, -1, -1, 1, 1, -1, -1, 0, 0, 1, 1],  # up and out call
        [1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],  # down and in put
        [1, 1, -1, -1, 0, 0, 1, 1, 1, 1, 0, 0],  # down and in call
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1],  # down and out put
        [0, 0, 1, 1, 0, 0, -1, -1, 0, 0, 1, 1]])  # down and out call

    # Create masks
    # Masks are shape [strikes.shape, 12]
    masks_lower = tf.gather(mask_matrix_lower_strike, indices, axis=0)
    masks_greater = tf.gather(mask_matrix_greater_strike, indices, axis=0)
    strikes_greater = tf.expand_dims(strikes > barriers, axis=-1)
    masks = tf.where(strikes_greater, masks_greater, masks_lower)
    masks = tf.cast(masks, dtype=dtype)
    # +1 / -1 signs: call vs put, and barrier below vs above.
    one = tf.constant(1, dtype=dtype)
    call_or_put = tf.cast(tf.where(tf.equal(is_call_options, 0), -one, one),
                          dtype=dtype)
    below_or_above = tf.cast(tf.where(tf.equal(is_barrier_down, 0), -one, one),
                             dtype=dtype)

    # Calculate params for integrals
    sqrt_var = volatilities * tf.math.sqrt(expiries)
    mu = (discount_rates - dividend_rates) - ((volatilities**2) / 2)
    lamda = 1 + (mu / (volatilities**2))
    x = (tf.math.log(spots / strikes) / (sqrt_var)) + (lamda * sqrt_var)
    x1 = (tf.math.log(spots / barriers) / (sqrt_var)) + (lamda * sqrt_var)
    y = (tf.math.log((barriers**2) / (spots * strikes)) / (
        sqrt_var)) + (lamda * sqrt_var)
    y1 = (tf.math.log(barriers / spots) / (sqrt_var)) + (lamda * sqrt_var)
    # `b` is the rebate-at-hit exponent, sqrt(mu^2 + 2 sigma^2 r) / sigma^2
    # (Haug's lambda = sqrt((mu / sigma^2)^2 + 2 r / sigma^2)). The square
    # root was previously missing, which mispriced the rebate terms of
    # knock-out options. The expression is NaN when mu^2 + 2 sigma^2 r < 0
    # (strongly negative rates), where the closed form does not apply.
    b = tf.math.sqrt(
        (mu**2) + (2 * (volatilities**2) * discount_rates)) / (volatilities**2)
    z = (tf.math.log(barriers / spots) / (sqrt_var)) + (b * sqrt_var)
    a = mu / (volatilities**2)

    # Other params used for integrals
    discount_factors = tf.math.exp(
        -discount_rates * expiries, name='discount_factors')
    barriers_ratio = tf.math.divide(barriers, spots, name='barriers_ratio')
    spots_term = call_or_put * spots * tf.math.exp(-dividend_rates * expiries)
    strikes_term = call_or_put * strikes * discount_factors

    # rank is used to stack elements and reduce_sum
    # NOTE(review): `tf.stack` needs every stacked term to have the same
    # shape, so the inputs presumably have to share the shape of `strikes`
    # after broadcasting within each term -- confirm.
    strike_rank = strikes.shape.rank

    # Constructing Matrix with first and second algebraic terms for each
    # integral [strike.shape, 12]
    terms_mat = tf.stack(
        (spots_term, -strikes_term,
         spots_term, -strikes_term,
         spots_term * (barriers_ratio**(2 * lamda)),
         -strikes_term * (barriers_ratio**((2 * lamda) - 2)),
         spots_term * (barriers_ratio**(2 * lamda)),
         -strikes_term * (barriers_ratio**((2 * lamda) - 2)),
         rebates * discount_factors,
         -rebates * discount_factors * (  # pylint: disable=invalid-unary-operand-type
             barriers_ratio**((2 * lamda) - 2)),
         rebates * (barriers_ratio**(a + b)),
         rebates * (barriers_ratio**(a - b))),
        name='term_matrix',
        axis=strike_rank)

    # Constructing Matrix with first and second norm for each integral
    # [strikes.shape, 12]
    cdf_mat = tf.stack(
        (call_or_put * x,
         call_or_put * (x - sqrt_var),
         call_or_put * x1,
         call_or_put * (x1 - sqrt_var),
         below_or_above * y,
         below_or_above * (y - sqrt_var),
         below_or_above * y1,
         below_or_above * (y1 - sqrt_var),
         below_or_above * (x1 - sqrt_var),
         below_or_above * (y1 - sqrt_var),
         below_or_above * z,
         below_or_above * (z - (2 * b * sqrt_var))),
        name='cdf_matrix',
        axis=strike_rank)
    cdf_mat = _ncdf(cdf_mat)
    # Calculating and returning price for each option
    return tf.reduce_sum(masks * terms_mat * cdf_mat, axis=strike_rank)
def _loop_tree_doubling(self, step_size, momentum_state_memory,
                        current_step_meta_info, iter_, initial_step_state,
                        initial_step_metastate, seed):
  """Main loop for tree doubling.

  Performs one doubling: draws a direction per batch member, builds a
  sub-tree of `1 << iter_` steps from the matching end of the trajectory,
  decides whether the sub-tree's candidate replaces the current one, and
  updates the trajectory ends and bookkeeping.

  Args:
    step_size: Sequence of step sizes, zipped with the state parts; each is
      negated where the drawn direction is False.
    momentum_state_memory: Passed through unchanged to
      `self._build_sub_tree`.
    current_step_meta_info: Structure whose `init_energy` defines the batch
      shape; also forwarded to `self._build_sub_tree`.
    iter_: Current doubling iteration; the sub-tree takes `1 << iter_` steps.
    initial_step_state: Structure whose leaves hold the two trajectory ends
      at index 0 and index 1 of the leading axis.
      NOTE(review): index 0 is presumably the left end and index 1 the right
      end -- confirm.
    initial_step_metastate: Structure with `candidate_state`, `is_accepted`,
      `momentum_sum`, `energy_diff_sum`, `continue_tree`, `not_divergence`
      and `leapfrog_count` from before this doubling.
    seed: Random seed; split into four seeds (direction, sub-tree,
      acceptance, and the seed returned for the next iteration).

  Returns:
    A tuple `(iter_ + 1, next_seed, new_step_state, new_step_metastate)`.
  """
  with tf.name_scope('loop_tree_doubling'):
    (direction_seed,
     subtree_seed,
     acceptance_seed,
     next_seed) = samplers.split_seed(seed, n=4)
    batch_shape = ps.shape(current_step_meta_info.init_energy)
    # One uniformly drawn boolean per batch member.
    direction = tf.cast(
        samplers.uniform(
            shape=batch_shape,
            minval=0,
            maxval=2,
            dtype=tf.int32,
            seed=direction_seed),
        dtype=tf.bool)

    # Start the sub-tree from end 1 where `direction` is True, else end 0.
    tree_start_states = tf.nest.map_structure(
        lambda v: bu.where_left_justified_mask(direction, v[1], v[0]),
        initial_step_state)

    directions_expanded = [
        bu.left_justified_expand_dims_like(direction, state)
        for state in tree_start_states.state
    ]

    # The step size is negated where `direction` is False.
    integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
        self.target_log_prob_fn,
        step_sizes=[
            tf.where(d, ss, -ss)
            for d, ss in zip(directions_expanded, step_size)
        ],
        num_steps=self.unrolled_leapfrog_steps)

    [
        candidate_tree_state,
        tree_final_states,
        final_not_divergence,
        continue_tree_final,
        energy_diff_tree_sum,
        momentum_subtree_cumsum,
        leapfrogs_taken
    ] = self._build_sub_tree(
        directions_expanded,
        integrator,
        current_step_meta_info,
        # num_steps_at_this_depth = 2**iter_ = 1 << iter_
        tf.bitwise.left_shift(1, iter_),
        tree_start_states,
        initial_step_metastate.continue_tree,
        initial_step_metastate.not_divergence,
        momentum_state_memory,
        seed=subtree_seed)

    last_candidate_state = initial_step_metastate.candidate_state

    energy_diff_sum = (
        energy_diff_tree_sum + initial_step_metastate.energy_diff_sum)
    if MULTINOMIAL_SAMPLE:
      # Weights are combined with `log_add_exp`; a sub-tree that must not
      # continue gets weight -inf.
      tree_weight = tf.where(
          continue_tree_final,
          candidate_tree_state.weight,
          tf.constant(-np.inf, dtype=candidate_tree_state.weight.dtype))
      weight_sum = log_add_exp(tree_weight, last_candidate_state.weight)
      log_accept_thresh = tree_weight - last_candidate_state.weight
    else:
      # Weights are added directly; a sub-tree that must not continue gets
      # weight zero.
      tree_weight = tf.where(
          continue_tree_final,
          candidate_tree_state.weight,
          tf.zeros([], dtype=TREE_COUNT_DTYPE))
      weight_sum = tree_weight + last_candidate_state.weight
      log_accept_thresh = tf.math.log(
          tf.cast(tree_weight, tf.float32) /
          tf.cast(last_candidate_state.weight, tf.float32))
    # A NaN threshold is replaced by zero.
    log_accept_thresh = tf.where(
        tf.math.is_nan(log_accept_thresh),
        tf.zeros([], log_accept_thresh.dtype),
        log_accept_thresh)
    u = tf.math.log1p(-samplers.uniform(
        shape=batch_shape,
        dtype=log_accept_thresh.dtype,
        seed=acceptance_seed))
    is_sample_accepted = u <= log_accept_thresh

    # The new candidate is only taken when the sub-tree is also allowed to
    # continue.
    choose_new_state = is_sample_accepted & continue_tree_final

    new_candidate_state = TreeDoublingStateCandidate(
        state=[
            bu.where_left_justified_mask(choose_new_state, s0, s1)
            for s0, s1 in zip(candidate_tree_state.state,
                              last_candidate_state.state)
        ],
        target=bu.where_left_justified_mask(
            choose_new_state,
            candidate_tree_state.target,
            last_candidate_state.target),
        target_grad_parts=[
            bu.where_left_justified_mask(choose_new_state, grad0, grad1)
            for grad0, grad1 in zip(
                candidate_tree_state.target_grad_parts,
                last_candidate_state.target_grad_parts)
        ],
        energy=bu.where_left_justified_mask(
            choose_new_state,
            candidate_tree_state.energy,
            last_candidate_state.energy),
        weight=weight_sum)

    # Re-attach the static shapes of the previous candidate to the newly
    # selected tensors.
    for new_candidate_state_temp, old_candidate_state_temp in zip(
        new_candidate_state.state, last_candidate_state.state):
      tensorshape_util.set_shape(new_candidate_state_temp,
                                 old_candidate_state_temp.shape)

    for new_candidate_grad_temp, old_candidate_grad_temp in zip(
        new_candidate_state.target_grad_parts,
        last_candidate_state.target_grad_parts):
      tensorshape_util.set_shape(new_candidate_grad_temp,
                                 old_candidate_grad_temp.shape)

    # Update left right information of the trajectory, and check trajectory
    # level U turn
    # `tree_otherend_states` is the end the sub-tree did NOT start from.
    tree_otherend_states = tf.nest.map_structure(
        lambda v: bu.where_left_justified_mask(direction, v[0], v[1]),
        initial_step_state)

    # Rebuild the two-ended state: where `direction` is True the untouched end
    # goes to index 0 and the sub-tree's final state to index 1; otherwise the
    # other way round.
    new_step_state = tf.nest.pack_sequence_as(
        initial_step_state, [
            tf.stack(
                [  # pylint: disable=g-complex-comprehension
                    bu.where_left_justified_mask(direction, right, left),
                    bu.where_left_justified_mask(direction, left, right),
                ],
                axis=0)
            for left, right in zip(tf.nest.flatten(tree_final_states),
                                   tf.nest.flatten(tree_otherend_states))
        ])

    # Accumulate the sub-tree's momentum sum onto the running total.
    momentum_tree_cumsum = []
    for p0, p1 in zip(initial_step_metastate.momentum_sum,
                      momentum_subtree_cumsum):
      momentum_part_temp = p0 + p1
      tensorshape_util.set_shape(momentum_part_temp, p0.shape)
      momentum_tree_cumsum.append(momentum_part_temp)

    for new_state_temp, old_state_temp in zip(
        tf.nest.flatten(new_step_state),
        tf.nest.flatten(initial_step_state)):
      tensorshape_util.set_shape(new_state_temp, old_state_temp.shape)

    # The U-turn check uses either the summed momentum or the difference
    # between the two trajectory ends, depending on the module-level flag.
    if GENERALIZED_UTURN:
      state_diff = momentum_tree_cumsum
    else:
      state_diff = [s[1] - s[0] for s in new_step_state.state]

    no_u_turns_trajectory = has_not_u_turn(
        state_diff,
        [m[0] for m in new_step_state.momentum],
        [m[1] for m in new_step_state.momentum],
        log_prob_rank=ps.rank_from_shape(batch_shape),
        shard_axis_names=self.experimental_shard_axis_names)

    new_step_metastate = TreeDoublingMetaState(
        candidate_state=new_candidate_state,
        is_accepted=choose_new_state | initial_step_metastate.is_accepted,
        momentum_sum=momentum_tree_cumsum,
        energy_diff_sum=energy_diff_sum,
        continue_tree=continue_tree_final & no_u_turns_trajectory,
        not_divergence=final_not_divergence,
        leapfrog_count=(initial_step_metastate.leapfrog_count +
                        leapfrogs_taken))

    return iter_ + 1, next_seed, new_step_state, new_step_metastate
def observation_jacobian_fn_3dim(x):
  """Builds the `[3, 2]` matrix `[[1, 0], [1, 1], [x[..., 1], x[..., 0]]]`.

  The six entries are stacked along the last axis and then reshaped to
  `[3, 2]`.
  """
  entries = [1., 0., 1., 1., x[..., 1], x[..., 0]]
  flat = tf.stack(entries, axis=-1)
  return tf.reshape(flat, [3, 2])
def convert_sharded_tensor_to_eager_tensor(value, *args, **kwargs):
  """Stacks the shards in `value.tensors` into a single tensor.

  Any additional positional or keyword arguments are accepted and ignored.
  """
  # The extra arguments are deliberately unused.
  del args, kwargs
  # TODO(nareshmodi): Consider a collective op to gather the tensors from the
  # various devices for performance reasons.
  shards = value.tensors
  return tf.stack(shards)