Example #1
0
def regression_loss(logits, labels, num_steps, steps, seq_lens, loss_type,
                    normalize_indices, variance_lambda, huber_delta):
    """Loss function based on regressing to the correct indices.

    In the paper, this is called Cycle-back Regression. There are 3 variants
    of this loss:
    i) regression_mse: MSE of the predicted indices and ground truth indices.
    ii) regression_mse_var: MSE of the predicted indices that takes into
    account the variance of the similarities. This is important when the rate
    at which sequences go through different phases changes a lot. The variance
    scaling allows dynamic weighting of the MSE loss based on the similarities.
    iii) regression_huber: Huber loss between the predicted indices and ground
    truth indices.

    Args:
      logits: Tensor, Pre-softmax similarity scores after cycling back to the
        starting sequence.
      labels: Tensor, One hot labels containing the ground truth. The index
        where the cycle started is 1.
      num_steps: Integer, Number of steps in the sequence embeddings.
      steps: Tensor, step indices/frame indices of the embeddings of the shape
        [N, T] where N is the batch size, T is the number of the timesteps.
      seq_lens: Tensor, Lengths of the sequences from which the sampling was
        done. Only read when `normalize_indices` is True.
      loss_type: String, This specifies the kind of regression loss function.
        Currently supported loss functions: regression_mse,
        regression_mse_var, regression_huber.
      normalize_indices: Boolean, If True, normalizes indices by sequence
        lengths. Useful for ensuring numerical instabilities don't arise as
        sequence indices can be large numbers.
      variance_lambda: Float, Weight of the variance of the similarity
        predictions while cycling back. If this is high then the low variance
        similarities are preferred by the loss while making this term low
        results in high variance of the similarities (more uniform/random
        matching). Only used by regression_mse_var.
      huber_delta: float, Huber delta described in tf.keras.losses.huber_loss.
        Only used by regression_huber.

    Returns:
      loss: Tensor, A scalar loss calculated using a variant of regression.

    Raises:
      ValueError: If `loss_type` is not one of the supported losses. The check
        happens before any TensorFlow op is created.
    """
    supported_losses = ('regression_mse', 'regression_mse_var',
                        'regression_huber')
    # Validate up front so a bad configuration fails fast, before any ops are
    # built from the inputs.
    if loss_type not in supported_losses:
        raise ValueError(
            'Unsupported regression loss %s. Supported losses are: '
            'regression_mse, regression_mse_var and regression_huber.' %
            loss_type)

    # Just to be safe, we stop gradients from labels as we are generating labels.
    labels = tf.stop_gradient(labels)
    steps = tf.stop_gradient(steps)

    if normalize_indices:
        float_seq_lens = tf.cast(seq_lens, tf.float32)
        tile_seq_lens = tf.tile(tf.expand_dims(float_seq_lens, axis=1),
                                [1, num_steps])
        steps = tf.cast(steps, tf.float32) / tile_seq_lens
    else:
        steps = tf.cast(steps, tf.float32)

    # The softmax over the similarities weights the step indices, giving the
    # predicted index; the one-hot labels select the ground-truth index.
    beta = tf.nn.softmax(logits)
    true_time = tf.reduce_sum(steps * labels, axis=1)
    pred_time = tf.reduce_sum(steps * beta, axis=1)

    if loss_type == 'regression_huber':
        return tf.reduce_mean(
            tf.keras.losses.huber_loss(y_true=true_time,
                                       y_pred=pred_time,
                                       delta=huber_delta))

    if loss_type == 'regression_mse':
        return tf.reduce_mean(
            tf.keras.losses.mean_squared_error(y_true=true_time,
                                               y_pred=pred_time))

    # loss_type == 'regression_mse_var': variance aware regression.
    pred_time_tiled = tf.tile(tf.expand_dims(pred_time, axis=1),
                              [1, num_steps])

    pred_time_variance = tf.reduce_sum(
        tf.square(steps - pred_time_tiled) * beta, axis=1)

    # Using log of variance as it is numerically stabler.
    # NOTE(review): a variance of exactly zero makes this log -inf; confirm
    # whether callers can produce a one-hot `beta`.
    pred_time_log_var = tf.math.log(pred_time_variance)
    squared_error = tf.square(true_time - pred_time)
    return tf.reduce_mean(
        tf.math.exp(-pred_time_log_var) * squared_error +
        variance_lambda * pred_time_log_var)
Example #2
0
 def _variance(self):
     """Return `concentration / rate**2`, using `log_rate` when `rate` is None."""
     if self.rate is not None:
         squared_rate = tf.square(self.rate)
     else:
         # exp(2 * log_rate) gives the squared rate directly from its log.
         squared_rate = tf.math.exp(self.log_rate * 2)
     return self.concentration / squared_rate
Example #3
0
 def _variance(self):
     """Return `range()**2 / 12`."""
     squared_range = tf.square(self.range())
     return squared_range / 12.
Example #4
0
def _norm_sq(x):
  """Returns the squared L2 norm of `x`, summed over its last axis."""
  squared_entries = tf.square(x)
  return tf.reduce_sum(squared_entries, axis=-1)
Example #5
0
 def _stddev(self):
     """Return the root-mean-square deviation of `samples` from `mean()`.

     The mean is re-expanded along `_samples_axis` so it broadcasts against
     `samples`; the squared deviations are then averaged over that axis.
     """
     expanded_mean = tf.expand_dims(self.mean(), axis=self._samples_axis)
     deviations = self.samples - expanded_mean
     mean_sq_deviation = tf.reduce_mean(tf.square(deviations),
                                        axis=self._samples_axis)
     return tf.sqrt(mean_sq_deviation)
Example #6
0
 def call(self, inputs):
     """Return the element-wise square of `inputs`."""
     squared = tf.square(inputs)
     return squared
Example #7
0
def loss_fn(params, inputs, targets):
    """Mean squared error of the linear model `params[0] * inputs + params[1]`.

    The mean is computed with `tf.reduce_mean` and the result is passed
    through `tf_np.asarray` before being returned.
    """
    predicted = params[0] * inputs + params[1]
    squared_error = tf.square(predicted - targets)
    mean_squared_error = tf.reduce_mean(input_tensor=squared_error)
    return tf_np.asarray(mean_squared_error)
Example #8
0
 def _forward_log_det_jacobian(self, x):
     """Return `-0.5 * log(2 * pi) - x**2 / 2`."""
     log_normalizer = -0.5 * np.log(2 * np.pi)
     half_square = tf.square(x) / 2.
     return log_normalizer - half_square
Example #9
0
 def _variance(self):
   """Return `_moment(2) - _moment(1)**2`."""
   # TODO(b/72696533): Investigate a more numerically stable version.
   second_moment = self._moment(2)
   first_moment = self._moment(1)
   return second_moment - tf.square(first_moment)
Example #10
0
 def _log_prob(self, x):
   """Return `-log1p(z**2) - (log(pi) + log(scale))` with `z = _z(x, scale)`.

   `log(pi)` is cast to the numpy dtype matching `self.dtype` before it is
   added to `log(scale)`.
   """
   numpy_dtype = dtype_util.as_numpy_dtype(self.dtype)
   scale = tf.convert_to_tensor(self.scale)
   standardized = self._z(x, scale=scale)
   unnormalized_log_prob = -tf.math.log1p(tf.square(standardized))
   log_pi = numpy_dtype(np.log(np.pi))
   normalizer_log = log_pi + tf.math.log(scale)
   return unnormalized_log_prob - normalizer_log
def brier_decomposition(labels, logits, name=None):
    r"""Decompose the Brier score into uncertainty, resolution, and reliability.

    [Proper scoring rules][1] measure the quality of probabilistic
    predictions; any proper scoring rule admits a [unique decomposition][2] as
    `Score = Uncertainty - Resolution + Reliability`, where:

    * `Uncertainty` is a generalized entropy of the average predictive
      distribution; it can be either positive or negative.
    * `Resolution` is a generalized variance of individual predictive
      distributions; it is always non-negative. Differences in predictions
      reveal information, which is why a larger resolution improves the
      predictive score.
    * `Reliability` is a measure of calibration of predictions against the
      true frequency of events. It is always non-negative and a lower value
      indicates better calibration.

    This method estimates the decomposition for the Brier scoring rule on
    discrete outcomes. The space of probability distributions is discretized
    into `nlabels` events: a distribution `p` over `nlabels` outcomes falls
    in the set `M_k` where `k` is the index of its largest probability.

    The estimation error of each component is O(k/n), where n is the number
    of instances and k is the number of labels. There may be an error of this
    order when compared to `brier_score`.

    #### References
    [1]: Tilmann Gneiting, Adrian E. Raftery.
         Strictly Proper Scoring Rules, Prediction, and Estimation.
         Journal of the American Statistical Association, Vol. 102, 2007.
         https://www.stat.washington.edu/raftery/Research/PDF/Gneiting2007jasa.pdf
    [2]: Jochen Broecker.  Reliability, sufficiency, and the decomposition of
         proper scores.
         Quarterly Journal of the Royal Meteorological Society, Vol. 135, 2009.
         https://rmets.onlinelibrary.wiley.com/doi/epdf/10.1002/qj.456

    Args:
      labels: Tensor, (n,), with tf.int32 or tf.int64 elements containing
        ground truth class labels in the range [0, nlabels).
      logits: Tensor, (n, nlabels), with logits for n instances and nlabels.
        When `logits` has rank greater than 2, the leading dimensions are
        flattened and the confusion matrices are computed with `tf.map_fn`.
      name: Python `str` name prefixed to Ops created by this function.

    Returns:
      uncertainty: Tensor, the uncertainty component of the decomposition
        (scalar for rank-2 `logits`).
      resolution: Tensor, the resolution component of the decomposition.
      reliability: Tensor, the reliability component of the decomposition.
    """
    with tf.name_scope(name or 'brier_decomposition'):
        labels = tf.convert_to_tensor(labels)
        logits = tf.convert_to_tensor(logits)
        num_classes = logits.shape[-1]
        # Static check; extra leading dimensions are treated as batch axes.
        has_batch_dims = tensorshape_util.rank(logits.shape) > 2

        # Bin every instance by its arg-max class (the set M_k it falls in).
        pred_class = tf.argmax(logits, axis=-1, output_type=labels.dtype)

        if has_batch_dims:
            batch_dims = tensorshape_util.as_list(logits.shape)[:-2]
            flatten, unflatten = _make_flatten_unflatten_fns(batch_dims)

            def confusion_for_one_batch(pair):
                predicted, actual = pair
                return tf.math.confusion_matrix(predicted,
                                                actual,
                                                num_classes=num_classes,
                                                dtype=logits.dtype)

            flat_confusion = tf.map_fn(
                confusion_for_one_batch,
                [flatten(pred_class), flatten(labels)],
                fn_output_signature=logits.dtype)
            confusion_matrix = unflatten(flat_confusion)
        else:
            confusion_matrix = tf.math.confusion_matrix(
                pred_class,
                labels,
                num_classes=num_classes,
                dtype=logits.dtype)

        # Rows are indexed by predicted class, columns by true label.
        # dist_weights: fraction of instances falling in each set M_k.
        row_counts = tf.reduce_sum(confusion_matrix, axis=-1)
        dist_weights = row_counts / tf.reduce_sum(
            row_counts, axis=-1, keepdims=True)
        # pbar: the average (marginal) label distribution.
        column_counts = tf.reduce_sum(confusion_matrix, axis=-2)
        pbar = column_counts / tf.reduce_sum(
            column_counts, axis=-1, keepdims=True)

        eps = np.finfo(dtype_util.as_numpy_dtype(confusion_matrix.dtype)).eps
        # dist_mean[k,:] contains the empirical distribution for the set M_k
        # Some outcomes may not realize, corresponding to dist_weights[k] = 0;
        # eps keeps the division finite for those empty rows.
        row_totals = tf.reduce_sum(confusion_matrix, axis=-1, keepdims=True)
        dist_mean = confusion_matrix / (eps + row_totals)

        # Uncertainty: quadratic entropy of the average label distribution
        uncertainty = -tf.reduce_sum(tf.square(pbar), axis=-1)

        # Resolution: expected quadratic divergence of predictive to mean
        # NOTE(review): expand_dims(pbar, -1) broadcasts pbar along the row
        # (predicted-class) axis, so entry [k, j] is pbar[k] - dist_mean[k, j];
        # confirm this is intended rather than pbar[j] - dist_mean[k, j].
        squared_divergence = tf.square(tf.expand_dims(pbar, -1) - dist_mean)
        per_bin_divergence = tf.reduce_sum(squared_divergence, axis=-1)
        resolution = tf.reduce_sum(dist_weights * per_bin_divergence, axis=-1)

        # Reliability: expected quadratic divergence of predictive to true
        if has_batch_dims:
            # TODO(b/139094519): Avoid using tf.map_fn here.
            flat_prob_true = tf.map_fn(
                lambda pair: tf.gather(pair[0], pair[1]),
                [flatten(dist_mean), flatten(pred_class)],
                fn_output_signature=dist_mean.dtype)
            prob_true = unflatten(flat_prob_true)
        else:
            prob_true = tf.gather(dist_mean, pred_class, axis=0)
        log_prob_true = tf.math.log(prob_true)

        # Log-softmax of the logits.
        log_normalizer = tf.math.reduce_logsumexp(
            logits, axis=-1, keepdims=True)
        log_prob_pred = logits - log_normalizer

        per_instance_log_reliability = _reduce_log_l2_exp(log_prob_pred,
                                                          log_prob_true,
                                                          axis=-1)
        log_reliability = tf.math.reduce_logsumexp(
            per_instance_log_reliability, axis=-1)

        # Dividing by the number of instances, done in log space.
        num_samples = tf.cast(tf.shape(logits)[-2], logits.dtype)
        reliability = tf.exp(log_reliability - tf.math.log(num_samples))

        return uncertainty, resolution, reliability
Example #12
0
    def testPreconditionerComputedCorrectly(self):
        """Test that SGLD step is computed correctly for a 3D Gaussian energy.

        One `apply_gradients` step of the optimizer is compared against an
        update computed by hand from the optimizer's `_decay_tensor`,
        `_diagonal_bias` and `_learning_rate`. The test returns immediately
        when executing eagerly.
        """
        # The test relies on `tf.gradients` and a cached session, so it is
        # skipped in eager mode.
        if tf.executing_eagerly():
            return

        with self.cached_session():
            dtype = np.float32
            # Target function is the energy function of normal distribution
            true_mean = dtype([0, 0, 0])
            true_cov = dtype([[1, 0.25, 0.25], [0.25, 1, 0.25],
                              [0.25, 0.25, 1]])
            # Target distribution is defined through the Cholesky decomposition
            chol = tf.linalg.cholesky(true_cov)
            target = tfd.MultivariateNormalTriL(loc=true_mean, scale_tril=chol)
            # The 3D state is split across two variables (sizes 2 and 1).
            var_1 = tf.Variable(name='var_1', initial_value=[1., 1.])
            var_2 = tf.Variable(name='var_2', initial_value=[1.])

            var = [var_1, var_2]

            # Set up the learning rate and the optimizer
            learning_rate = .5
            optimizer_kernel = tfp.optimizer.StochasticGradientLangevinDynamics(
                learning_rate=learning_rate, burnin=1)

            # Target function
            def target_fn(x, y):
                # Stack the input tensors together
                z = tf.concat([x, y], axis=-1) - true_mean
                return -target.log_prob(z)

            grads = tf.gradients(ys=target_fn(*var), xs=var)

            # Update value of `var` with one iteration of the SGLD (without the
            # normal perturbation, since `burnin > 0`)
            step = optimizer_kernel.apply_gradients(zip(grads, var))

            # Read the optimizer's internal hyperparameters, cast to the
            # variables' dtype, to rebuild the expected update by hand.
            decay_tensor = tf.cast(optimizer_kernel._decay_tensor,
                                   var[0].dtype)
            diagonal_bias = tf.cast(optimizer_kernel._diagonal_bias,
                                    var[0].dtype)
            learning_rate = tf.cast(optimizer_kernel._learning_rate,
                                    var[0].dtype)
            # velocity = decay * 1 + (1 - decay) * grad**2, per variable.
            velocity = [(decay_tensor * tf.ones_like(v) +
                         (1 - decay_tensor) * tf.square(g))
                        for v, g in zip(var, grads)]
            # preconditioner = 1 / sqrt(velocity + diagonal_bias)
            preconditioner = [
                tf.math.rsqrt(vel + diagonal_bias) for vel in velocity
            ]
            # Compute second order gradients
            _, grad_grads = diag_jacobian(xs=var, ys=grads)
            # Compute gradient of the preconditioner (compute the gradient manually)
            preconditioner_grads = [
                -(g * g_g * (1. - decay_tensor) * p**3.)
                for g, g_g, p in zip(grads, grad_grads, preconditioner)
            ]

            # True theoretical value of `var` after one iteration
            var_true = [
                v - learning_rate * 0.5 * (p * g - p_g) for v, p, g, p_g in
                zip(var, preconditioner, grads, preconditioner_grads)
            ]
            self.evaluate(tf1.global_variables_initializer())
            # Evaluate the expected value BEFORE running the step: `var_true`
            # is built from the variables, which `step` then modifies.
            var_true_ = self.evaluate(var_true)
            self.evaluate(step)
            var_ = self.evaluate(var)  # new `var` after one SGLD step
            self.assertAllClose(var_true_, var_, atol=0.001, rtol=0.001)
Example #13
0
 def target_log_prob_fn(x):
   """Increment `counter[0]` and return the pair `(-x**2, [])`."""
   counter[0] += 1
   negative_square = -tf.square(x)
   return negative_square, []
Example #14
0
 def call(self, inputs):
     """Return `sqrt(inputs)` when the inputs sum to more than 0, else `inputs**2`."""
     total = tf.reduce_sum(inputs)
     if total > 0:
         outputs = tf.sqrt(inputs)
     else:
         outputs = tf.square(inputs)
     return outputs
Example #15
0
def _normal_pdf(x):
  """Returns `exp(-x**2 / 2) / sqrt(2 * pi)`, computed in `x`'s dtype."""
  two_pi = tf.convert_to_tensor(2 * np.pi, dtype=x.dtype)
  inverse_normalizer = tf.math.rsqrt(two_pi)
  unnormalized_density = tf.exp(-0.5 * tf.square(x))
  return inverse_normalizer * unnormalized_density
Example #16
0
    def _log_prob(self, x):
        """Computes the log-density at `x` from a Cholesky factor of `x`.

        When `self.input_output_cholesky` is true, `x` is used directly as the
        factor; otherwise `tf.linalg.cholesky(x)` is taken. The factor is
        rearranged so `self.scale_operator.solve` can be applied to all
        samples at once, and the result is

          (df - dimension - 1) * half_log_det_x
            - 0.5 * tr[inv(V) X] - log_normalization()

        where the trace is the sum of squares of the solved factor and
        `half_log_det_x` is the sum of the logs of the factor's diagonal.

        Args:
          x: Tensor whose last two axes hold the matrices to evaluate (or
            their Cholesky factors, see above).

        Returns:
          log_prob: Tensor of log-densities. When both the rank of `x` and
            the rank of `self.batch_shape` are statically known, its static
            shape is set to the broadcast of `x.shape[:-2]` and
            `self.batch_shape`.
        """
        if self.input_output_cholesky:
            x_sqrt = x
        else:
            # Complexity: O(nbk**3)
            x_sqrt = tf.linalg.cholesky(x)

        batch_shape = self.batch_shape_tensor()
        event_shape = self.event_shape_tensor()
        x_ndims = tf.rank(input=x_sqrt)
        # Left-pad the shape with ones so that `x_sqrt` has at least
        # `batch_ndims + 2` dimensions.
        num_singleton_axes_to_prepend = (
            tf.maximum(tf.size(input=batch_shape) + 2, x_ndims) - x_ndims)
        x_with_prepended_singletons_shape = tf.concat([
            tf.ones([num_singleton_axes_to_prepend], dtype=tf.int32),
            tf.shape(input=x_sqrt)
        ], 0)
        x_sqrt = tf.reshape(x_sqrt, x_with_prepended_singletons_shape)
        ndims = tf.rank(x_sqrt)
        # sample_ndims = ndims - batch_ndims - event_ndims
        sample_ndims = ndims - tf.size(input=batch_shape) - 2
        sample_shape = tf.shape(input=x_sqrt)[:sample_ndims]

        # We need to be able to pre-multiply each matrix by its corresponding
        # batch scale matrix. Since a Distribution Tensor supports multiple
        # samples per batch, this means we need to reshape the input matrix `x`
        # so that the first b dimensions are batch dimensions and the last two
        # are of shape [dimension, dimensions*number_of_samples]. Doing these
        # gymnastics allows us to do a batch_solve.
        #
        # After we're done with sqrt_solve (the batch operation) we need to undo
        # this reshaping so what we're left with is a Tensor partitionable by
        # sample, batch, event dimensions.

        # Complexity: O(nbk**2) since transpose must access every element.
        scale_sqrt_inv_x_sqrt = x_sqrt
        # Move the sample axes from the front to the back.
        perm = tf.concat(
            [tf.range(sample_ndims, ndims),
             tf.range(0, sample_ndims)], 0)
        scale_sqrt_inv_x_sqrt = tf.transpose(a=scale_sqrt_inv_x_sqrt,
                                             perm=perm)
        # Fold the last event axis together with all sample axes into a
        # single trailing axis.
        last_dim_size = (
            tf.cast(self.dimension, dtype=tf.int32) * tf.reduce_prod(
                input_tensor=x_with_prepended_singletons_shape[:sample_ndims]))
        shape = tf.concat([
            x_with_prepended_singletons_shape[sample_ndims:-2],
            [tf.cast(self.dimension, dtype=tf.int32), last_dim_size]
        ],
                          axis=0)
        scale_sqrt_inv_x_sqrt = tf.reshape(scale_sqrt_inv_x_sqrt, shape)

        # Complexity: O(nbM*k) where M is the complexity of the operator solving a
        # vector system. For LinearOperatorLowerTriangular, each solve is O(k**2) so
        # this step has complexity O(nbk^3).
        scale_sqrt_inv_x_sqrt = self.scale_operator.solve(
            scale_sqrt_inv_x_sqrt)

        # Undo make batch-op ready.
        # Complexity: O(nbk**2)
        shape = tf.concat([
            tf.shape(input=scale_sqrt_inv_x_sqrt)[:-2], event_shape,
            sample_shape
        ],
                          axis=0)
        scale_sqrt_inv_x_sqrt = tf.reshape(scale_sqrt_inv_x_sqrt, shape)
        # Move the sample axes back to the front.
        perm = tf.concat([
            tf.range(ndims - sample_ndims, ndims),
            tf.range(0, ndims - sample_ndims)
        ], 0)
        scale_sqrt_inv_x_sqrt = tf.transpose(a=scale_sqrt_inv_x_sqrt,
                                             perm=perm)

        # Write V = SS', X = LL'. Then:
        # tr[inv(V) X] = tr[inv(S)' inv(S) L L']
        #              = tr[inv(S) L L' inv(S)']
        #              = tr[(inv(S) L) (inv(S) L)']
        #              = sum_{ik} (inv(S) L)_{ik}**2
        # The second equality follows from the cyclic permutation property.
        # Complexity: O(nbk**2)
        trace_scale_inv_x = tf.reduce_sum(
            input_tensor=tf.square(scale_sqrt_inv_x_sqrt), axis=[-2, -1])

        # Sum of the logs of the Cholesky factor's diagonal.
        # Complexity: O(nbk)
        half_log_det_x = tf.reduce_sum(input_tensor=tf.math.log(
            tf.linalg.diag_part(x_sqrt)),
                                       axis=[-1])

        # Complexity: O(nbk**2)
        log_prob = ((self.df - self.dimension - 1.) * half_log_det_x -
                    0.5 * trace_scale_inv_x - self.log_normalization())

        # Set shape hints.
        # Try to merge what we know from the input x with what we know from the
        # parameters of this distribution.
        if tensorshape_util.rank(
                x.shape) is not None and tensorshape_util.rank(
                    self.batch_shape) is not None:
            tensorshape_util.set_shape(
                log_prob,
                tf.broadcast_static_shape(x.shape[:-2], self.batch_shape))

        return log_prob
Example #17
0
def _normal_log_pdf(x):
  """Returns `-0.5 * (log(2 * pi) + x**2)`, computed in `x`'s dtype."""
  two_pi = tf.convert_to_tensor(2 * np.pi, dtype=x.dtype)
  log_two_pi = tf.math.log(two_pi)
  squared = tf.square(x)
  return -0.5 * (log_two_pi + squared)
Example #18
0
    def sampler_loop_body(previous_sample, _):
        """Runs one sampler iteration, resampling all model variables.

        Names used here but not defined in this function (for example
        `regression_component`, `model_has_slope`, `design_matrix`,
        `observed_time_series`, `is_missing` and the `*_prior` objects) are
        expected to come from the enclosing scope.

        Args:
          previous_sample: State from the previous iteration. The fields read
            here are `seed`, `observation_noise_scale`, `level`, `weights`,
            `level_scale`, `slope_scale` and `slope`.
          _: Unused.

        Returns:
          A `GibbsSamplerState` holding the resampled values and a fresh seed
          for the next iteration. When the model has no slope, `slope_scale`
          and `slope` are carried over unchanged from `previous_sample`.
        """

        (weights_seed, level_seed, observation_noise_scale_seed,
         level_scale_seed,
         loop_seed) = samplers.split_seed(previous_sample.seed,
                                          n=5,
                                          salt='sampler_loop_body')
        # Preserve backward-compatible seed behavior by splitting slope separately.
        slope_scale_seed, = samplers.split_seed(previous_sample.seed,
                                                n=1,
                                                salt='sampler_loop_body_slope')

        if regression_component:
            # We encourage a reasonable initialization by sampling the weights first,
            # so at the first step they are regressed directly against the observed
            # time series. If we instead sampled the level first it might 'explain
            # away' some observed variation that we would ultimately prefer to explain
            # through the regression weights, because the level can represent
            # arbitrary variation, while the weights are limited to representing
            # variation in the subspace given by the design matrix.
            if model_has_spike_slab_regression:
                if experimental_use_weight_adjustment:
                    previous_observation_noise_variance = tf.square(
                        previous_sample.observation_noise_scale)
                else:
                    previous_observation_noise_variance = 1.
                # Missing observations get a target of zero.
                targets = tf.where(
                    is_missing, tf.zeros_like(observed_time_series),
                    observed_time_series - previous_sample.level)
                # The spike-and-slab sampler draws the noise variance jointly
                # with the weights, so the noise scale is NOT resampled again
                # at the end of this function.
                (observation_noise_variance, weights
                 ) = spike_and_slab_sampler.sample_noise_variance_and_weights(
                     initial_nonzeros=tf.math.logical_or(
                         tf.not_equal(previous_sample.weights, 0.),
                         pin_to_nonzero),
                     previous_observation_noise_variance=
                     previous_observation_noise_variance,
                     targets=targets,
                     seed=weights_seed)
                observation_noise_scale = tf.sqrt(observation_noise_variance)

            else:
                weights = _resample_weights(
                    design_matrix=design_matrix,
                    target_residuals=observed_time_series -
                    previous_sample.level,
                    observation_noise_scale=previous_sample.
                    observation_noise_scale,
                    weights_prior_scale=weights_prior_scale,
                    seed=weights_seed)
                # Noise scale will be resampled below.
                observation_noise_scale = previous_sample.observation_noise_scale

            regression_residuals = observed_time_series - tf.linalg.matvec(
                design_matrix, weights)
        else:
            # If there is no regression, then the entire timeseries is a residual.
            regression_residuals = observed_time_series
            # Noise scale will be resampled below.
            observation_noise_scale = previous_sample.observation_noise_scale
            weights = previous_sample.weights

        latents = resample_latents(
            observed_residuals=regression_residuals,
            level_scale=previous_sample.level_scale,
            slope_scale=previous_sample.slope_scale
            if model_has_slope else None,
            observation_noise_scale=observation_noise_scale,
            initial_state_prior=level_component.initial_state_prior,
            is_missing=is_missing,
            seed=level_seed)
        # Index 0 of the last latent axis is used as the level (index 1 as
        # the slope, when the model has one).
        level = latents[..., 0]
        # First differences of the level along its last axis.
        level_residuals = level[..., 1:] - level[..., :-1]
        if model_has_slope:
            slope = latents[..., 1]
            # Remove the slope's contribution from the level differences.
            level_residuals -= slope[..., :-1]
            slope_residuals = slope[..., 1:] - slope[..., :-1]

        # Estimate level scale from the empirical changes in level.
        level_scale = resample_scale(prior=level_scale_variance_prior,
                                     observed_residuals=level_residuals,
                                     is_missing=None,
                                     seed=level_scale_seed)
        if model_has_slope:
            slope_scale = resample_scale(prior=slope_scale_variance_prior,
                                         observed_residuals=slope_residuals,
                                         is_missing=None,
                                         seed=slope_scale_seed)
        if not (regression_component and model_has_spike_slab_regression):
            # Estimate noise scale from the residuals.
            observation_noise_scale = resample_scale(
                prior=observation_noise_variance_prior,
                observed_residuals=regression_residuals - level,
                is_missing=is_missing,
                seed=observation_noise_scale_seed)

        return GibbsSamplerState(
            observation_noise_scale=observation_noise_scale,
            level_scale=level_scale,
            slope_scale=(slope_scale
                         if model_has_slope else previous_sample.slope_scale),
            weights=weights,
            level=level,
            slope=(slope if model_has_slope else previous_sample.slope),
            seed=loop_seed)
Example #19
0
 def _variance(self):
   """Return `concentration / rate**2`."""
   concentration = self.concentration
   squared_rate = tf.square(self.rate)
   return concentration / squared_rate
Example #20
0
 def compute_brier(labels_, logits_):
   """Return the mean over rows of `sum_k p_k**2 - 2 * p_label`.

   `p` is the softmax of `logits_` along axis 1 and `p_label` is the
   probability assigned to the entry selected by `labels_`.
   """
   probabilities = tf.math.softmax(logits_, axis=1)
   # Unpacking doubles as a check that the probabilities are rank 2.
   _, num_labels = probabilities.shape
   one_hot_labels = tf.one_hot(labels_, num_labels)
   label_probability = tf.reduce_sum(one_hot_labels * probabilities, axis=1)
   sum_of_squares = tf.reduce_sum(tf.square(probabilities), axis=1)
   per_row_score = sum_of_squares - 2.0 * label_probability
   return tf.reduce_mean(per_row_score)
Example #21
0
def brier_decomposition(labels=None, logits=None, probabilities=None):
    r"""Decompose the Brier score into uncertainty, resolution, and reliability.

    [Proper scoring rules][1] measure the quality of probabilistic
    predictions; any proper scoring rule admits a [unique decomposition][2] as
    `Score = Uncertainty - Resolution + Reliability`, where:

    * `Uncertainty` is a generalized entropy of the average predictive
      distribution; it can be either positive or negative.
    * `Resolution` is a generalized variance of individual predictive
      distributions; it is always non-negative. Differences in predictions
      reveal information, which is why a larger resolution improves the
      predictive score.
    * `Reliability` is a measure of calibration of predictions against the
      true frequency of events. It is always non-negative and a lower value
      indicates better calibration.

    This method estimates the decomposition for the Brier scoring rule on
    discrete outcomes. The space of probability distributions is discretized
    into `nlabels` events: a distribution `p` over `nlabels` outcomes falls
    in the set `M_k` where `k` is the index of its largest probability.

    The estimation error of each component is O(k/n), where n is the number
    of instances and k is the number of labels. There may be an error of this
    order when compared to `brier_score`.

    #### References
    [1]: Tilmann Gneiting, Adrian E. Raftery.
         Strictly Proper Scoring Rules, Prediction, and Estimation.
         Journal of the American Statistical Association, Vol. 102, 2007.
         https://www.stat.washington.edu/raftery/Research/PDF/Gneiting2007jasa.pdf
    [2]: Jochen Broecker.  Reliability, sufficiency, and the decomposition of
         proper scores.
         Quarterly Journal of the Royal Meteorological Society, Vol. 135, 2009.
         https://rmets.onlinelibrary.wiley.com/doi/epdf/10.1002/qj.456

    Args:
      labels: Tensor, (n,), with tf.int32 or tf.int64 elements containing
        ground truth class labels in the range [0, nlabels).
      logits: Tensor, (n, nlabels), with logits for n instances and nlabels.
      probabilities: Tensor, (n, nlabels), with predictive probability
        distribution (alternative to logits argument).

    Returns:
      uncertainty: Tensor, scalar, the uncertainty component of the
        decomposition.
      resolution: Tensor, scalar, the resolution component of the
        decomposition.
      reliability: Tensor, scalar, the reliability component of the
        decomposition.

    Raises:
      ValueError: If both or neither of `logits` and `probabilities` are
        given.
    """
    logits_given = logits is not None
    probabilities_given = probabilities is not None
    if logits_given == probabilities_given:
        raise ValueError(
            'brier_decomposition expects exactly one of logits or probabilities.'
        )
    if not probabilities_given:
        probabilities = scipy.special.softmax(logits, axis=1)
    _, nlabels = probabilities.shape  # Implicit rank check.

    # Bin every instance by its arg-max class (the set M_k it falls in).
    pred_class = tf.argmax(probabilities, axis=1, output_type=tf.int32)
    # Rows are indexed by predicted class, columns by true label.
    confusion_matrix = tf.math.confusion_matrix(pred_class,
                                                labels,
                                                nlabels,
                                                dtype=tf.float32)
    # dist_weights: fraction of instances falling in each set M_k.
    row_counts = tf.reduce_sum(confusion_matrix, axis=1)
    dist_weights = row_counts / tf.reduce_sum(row_counts)
    # pbar: the average (marginal) label distribution.
    column_counts = tf.reduce_sum(confusion_matrix, axis=0)
    pbar = column_counts / tf.reduce_sum(column_counts)

    # dist_mean[k,:] contains the empirical distribution for the set M_k
    # Some outcomes may not realize, corresponding to dist_weights[k] = 0;
    # the small constant keeps the division finite for those empty rows.
    row_totals = tf.expand_dims(
        tf.reduce_sum(confusion_matrix, axis=1) + 1.0e-7, 1)
    dist_mean = confusion_matrix / row_totals

    # Uncertainty: quadratic entropy of the average label distribution
    uncertainty = -tf.reduce_sum(tf.square(pbar))

    # Resolution: expected quadratic divergence of predictive to mean
    # NOTE(review): expand_dims(pbar, 1) broadcasts pbar along the row
    # (predicted-class) axis, so entry [k, j] is pbar[k] - dist_mean[k, j];
    # confirm this is intended rather than pbar[j] - dist_mean[k, j].
    squared_divergence = tf.square(tf.expand_dims(pbar, 1) - dist_mean)
    per_bin_divergence = tf.reduce_sum(squared_divergence, axis=1)
    resolution = tf.reduce_sum(dist_weights * per_bin_divergence)

    # Reliability: expected quadratic divergence of predictive to true
    prob_true = tf.gather(dist_mean, pred_class, axis=0)
    per_instance_divergence = tf.reduce_sum(
        tf.square(prob_true - probabilities), axis=1)
    reliability = tf.reduce_mean(per_instance_divergence)

    return uncertainty, resolution, reliability
def lossfun(x, alpha, scale, approximate=False, epsilon=1e-6):
    r"""Evaluates the general robust loss rho(x, \alpha, c) elementwise.

    The loss is the one described in "A General and Adaptive Robust Loss
    Function", Jonathan T. Barron, https://arxiv.org/abs/1701.03077.

    Args:
        x: Residuals to penalize. A tensorflow tensor or numpy array of floats
            of any shape; `alpha` and `scale` are broadcast to this shape.
        alpha: Shape parameter (\alpha in the paper), with the same dtype as
            `x`. More negative values give a more robust loss (outliers "cost"
            less); more positive values penalize outliers more heavily. Any
            value in [-infinity, infinity] is accepted, and the gradient of
            the loss with respect to alpha is 0 at -infinity, infinity, 0
            and 2. Notable settings:
                alpha=-Infinity: Welsch/Leclerc loss.
                alpha=-2: Geman-McClure loss.
                alpha=0: Cauchy/Lorentzian loss.
                alpha=1: Charbonnier/pseudo-Huber loss.
                alpha=2: L2 loss.
        scale: Scale parameter (c in the paper), with the same dtype as `x`
            and strictly positive. For |x| < scale the loss is an L2-like
            quadratic bowl; for |x| > scale its shape is governed by `alpha`.
        approximate: If True, evaluate the faster approximate form from the
            appendix of the paper. The approximation holds well everywhere
            except as x and alpha approach zero.
        epsilon: Slack used only by the approximate form. Larger values are
            less accurate but more numerically stable. Must be greater than
            single-precision machine epsilon (checked with an `assert`).

    Returns:
        The loss for every element of `x`, with the same shape as `x`.
    """
    dtype = x.dtype
    # Both parameters have to share the dtype of the residual.
    tf.debugging.assert_type(scale, dtype)
    tf.debugging.assert_type(alpha, dtype)

    # Every op created below is gated on `scale` being strictly positive.
    scale_is_positive = tf.Assert(
        tf.reduce_all(tf.greater(scale, 0.)), [scale])
    with tf.control_dependencies([scale_is_positive]):
        target_shape = tf.shape(x)
        alpha = tf.broadcast_to(alpha, target_shape)
        scale = tf.broadcast_to(scale, target_shape)

        if approximate:
            # `epsilon` must be greater than single-precision machine epsilon.
            assert epsilon > np.finfo(np.float32).eps
            # Padding |alpha - 2| and |alpha| away from zero by `epsilon` keeps
            # both divisions finite, at the price of accuracy when x and alpha
            # are near zero.
            b = tf.abs(alpha - tf.cast(2., dtype)) + epsilon
            alpha_is_nonneg = tf.greater_equal(alpha, 0.)
            d = tf.where(alpha_is_nonneg, alpha + epsilon, alpha - epsilon)
            powered = tf.pow(tf.square(x / scale) / b + 1., 0.5 * d)
            return (b / d) * (powered - 1.)

        # Exact loss: evaluate every special case plus the generic formula,
        # then pick per element according to `alpha`.
        sq_ratio = tf.square(x / scale)

        at_two = 0.5 * sq_ratio  # alpha == 2
        at_zero = util.log1p_safe(0.5 * sq_ratio)  # alpha == 0
        at_neginf = -tf.math.expm1(-0.5 * sq_ratio)  # alpha == -infinity
        at_posinf = util.expm1_safe(0.5 * sq_ratio)  # alpha == +infinity

        # Generic case. |alpha - 2| and |alpha| are clamped to at least machine
        # epsilon so that both are safe to divide by; alpha keeps its sign.
        tiny = tf.cast(np.finfo(np.float32).eps, dtype)
        beta_safe = tf.maximum(tiny, tf.abs(alpha - 2.))
        unit = tf.ones_like(alpha)
        alpha_sign = tf.where(tf.greater_equal(alpha, 0.), unit, -unit)
        alpha_safe = alpha_sign * tf.maximum(tiny, tf.abs(alpha))
        generic = (beta_safe / alpha_safe) * (
            tf.pow(sq_ratio / beta_safe + 1., 0.5 * alpha) - 1.)

        # Apply the overrides from lowest to highest priority, which matches
        # testing for -inf first, then 0, then 2, then +inf.
        pos_inf = tf.cast(float('inf'), dtype)
        loss = tf.where(tf.equal(alpha, pos_inf), at_posinf, generic)
        loss = tf.where(tf.equal(alpha, 2.), at_two, loss)
        loss = tf.where(tf.equal(alpha, 0.), at_zero, loss)
        loss = tf.where(tf.equal(alpha, -pos_inf), at_neginf, loss)

        return loss
  def solve_nu_zeta(self,
                    dataset: dataset_lib.OffpolicyDataset,
                    target_policy: tf_policy.TFPolicy,
                    regularizer: float = 1e-6):
    """Performs one damped Newton update of nu/nu2 and a relaxed zeta update.

    On the first call (detected by `self` having no `_td_mat` attribute) the
    tabular quantities derived from the dataset and the target policy are
    built and cached on `self`; later calls reuse the cache, so `dataset` and
    `target_policy` are only read on that first call.

    Every call then:
      1. evaluates a per-sample saddle loss from the current nu/zeta values,
      2. bisects `self._alpha` for 16 steps so that the divergence of the
         resulting sample weights approaches `self._two_sided_limit`,
      3. takes a step of size 0.1 along the regularized Newton direction for
         `self._nu` and `self._nu2`,
      4. moves `self._zeta` and `self._zeta2` 10% of the way towards the
         rescaled Bellman residuals.

    Args:
      dataset: The dataset whose episodes are read to build the cached
        quantities.
      target_policy: The policy whose action distributions are evaluated at
        the initial, current and next steps.
      regularizer: Constant multiplied by the identity matrix and added to
        each Hessian before the linear solve.

    Returns:
      A list `[avg_saddle_loss * sign, avg_loss * sign, divergence]`, where
      `sign` is `self._algae_alpha_sign`.
    """

    # One-time setup: everything computed in this branch is stored on `self`
    # at the end of it and reused by all later calls.
    if not hasattr(self, '_td_mat'):
      # Set up env_steps.
      episodes, valid_steps = dataset.get_all_episodes(
          limit=self._limit_episodes)
      # One fewer transition than steps along the second axis of
      # `valid_steps`; the final step is only ever used as a "next" step.
      total_num_steps_per_episode = tf.shape(valid_steps)[1] - 1
      num_episodes = tf.shape(valid_steps)[0]
      num_samples = num_episodes * total_num_steps_per_episode
      # Flat indices of the transitions whose current step is valid and has a
      # positive discount.
      valid_and_not_last = tf.logical_and(valid_steps, episodes.discount > 0)
      valid_indices = tf.squeeze(
          tf.where(tf.reshape(valid_and_not_last[:, :-1], [-1])))

      # Repeat each episode's first step once per transition, flatten, and
      # keep only the valid transitions.
      initial_env_step = tf.nest.map_structure(
          lambda t: tf.squeeze(
              tf.reshape(
                  tf.repeat(
                      t[:, 0:1, ...],
                      axis=1,
                      repeats=total_num_steps_per_episode), [num_samples, -1])),
          episodes)
      initial_env_step = tf.nest.map_structure(
          lambda t: tf.gather(t, valid_indices), initial_env_step)
      tfagents_initial_env_step = dataset_lib.convert_to_tfagents_timestep(
          initial_env_step)

      # Current steps: all but the last step of every episode, flattened.
      env_step = tf.nest.map_structure(
          lambda t: tf.squeeze(
              tf.reshape(t[:, 0:total_num_steps_per_episode, ...],
                         [num_samples, -1])), episodes)
      env_step = tf.nest.map_structure(lambda t: tf.gather(t, valid_indices),
                                       env_step)
      tfagents_env_step = dataset_lib.convert_to_tfagents_timestep(env_step)

      # Next steps: the same slice shifted forward by one step.
      next_env_step = tf.nest.map_structure(
          lambda t: tf.squeeze(
              tf.reshape(t[:, 1:total_num_steps_per_episode + 1, ...],
                         [num_samples, -1])), episodes)
      next_env_step = tf.nest.map_structure(
          lambda t: tf.gather(t, valid_indices), next_env_step)
      tfagents_next_env_step = dataset_lib.convert_to_tfagents_timestep(
          next_env_step)

      # get probabilities
      initial_target_probs = target_policy.distribution(
          tfagents_initial_env_step).action.probs_parameter()
      next_target_probs = target_policy.distribution(
          tfagents_next_env_step).action.probs_parameter()

      # First, get the nu_loss and data weights
      #current_nu_loss = self._get_nu_loss(initial_env_step, env_step,
      #                                    next_env_step, target_policy)
      #data_weight, _ = self._get_weights(current_nu_loss)

      # # debug only and to reproduce dual dice result, DELETE
      # data_weight = tf.ones_like(data_weight)

      # NOTE(review): presumably the per-sample visit count of the sample's
      # (state, action) pair -- confirm against `_get_state_action_counts`.
      state_action_count = self._get_state_action_counts(env_step)
      # NOTE(review): `counts` is not read anywhere else in this method.
      counts = tf.reduce_sum(tf.one_hot(state_action_count, self._dimension), 0)
      # Per-sample discount weight gamma ** step_num.
      gamma_sample = tf.pow(self._gamma, tf.cast(env_step.step_num, tf.float32))

      # # debug only and to reproduce dual dice result, DELETE
      # gamma_sample = tf.ones_like(gamma_sample)

      # now we need to expand_dims to include action space in extra dimensions
      #data_weights = tf.reshape(data_weight, [-1, self._num_limits])
      # both are data sample weights for L2 problem, needs to be normalized later
      #gamma_data_weights = tf.reshape(gamma_sample, [-1, 1]) * data_weights

      # Pair every initial observation with every action id in
      # [0, self._num_actions) and map each pair to its nu index.
      initial_states = tf.tile(
          tf.reshape(initial_env_step.observation, [-1, 1]),
          [1, self._num_actions])
      initial_actions = tf.tile(
          tf.reshape(tf.range(self._num_actions), [1, -1]),
          [initial_env_step.observation.shape[0], 1])
      initial_nu_indices = self._get_index(initial_states, initial_actions)

      # linear term w.r.t. initial distribution
      #b_vec_2 = tf.stack([
      #    tf.reduce_sum(
      #        tf.reshape(
      #            data_weights[:, itr] / tf.reduce_sum(data_weights[:, itr]),
      #            [-1, 1]) * tf.reduce_sum(
      #                tf.one_hot(initial_nu_indices, self._dimension) *
      #                (1 - self._gamma) *
      #                tf.expand_dims(initial_target_probs, axis=-1),
      #                axis=1),
      #        axis=0) for itr in range(self._num_limits)
      #],
      #                   axis=0)

      # Same pairing for the next observations.
      next_states = tf.tile(
          tf.reshape(next_env_step.observation, [-1, 1]),
          [1, self._num_actions])
      next_actions = tf.tile(
          tf.reshape(tf.range(self._num_actions), [1, -1]),
          [next_env_step.observation.shape[0], 1])
      next_nu_indices = self._get_index(next_states, next_actions)
      # Absorbing next steps get index -1; tf.one_hot encodes an out-of-range
      # index as an all-zero row, so they drop out of `a_vec` below.
      next_nu_indices = tf.where(
          tf.expand_dims(next_env_step.is_absorbing(), -1),
          -1 * tf.ones_like(next_nu_indices), next_nu_indices)

      nu_indices = self._get_index(env_step.observation, env_step.action)

      # Importance ratio pi_target(a|s) / pi_behavior(a|s) for the logged
      # action, or all ones when solving for the state-action ratio.
      target_log_probabilities = target_policy.distribution(
          tfagents_env_step).action.log_prob(env_step.action)
      if not self._solve_for_state_action_ratio:
        policy_ratio = tf.exp(target_log_probabilities -
                              env_step.get_log_probability())
      else:
        policy_ratio = tf.ones([
            target_log_probabilities.shape[0],
        ])
      policy_ratios = tf.tile(
          tf.reshape(policy_ratio, [-1, 1]), [1, self._num_actions])

      # the tabular feature vector:
      # one_hot(current index) - gamma * sum over next actions of
      # (target prob * policy ratio) * one_hot(next index).
      a_vec = tf.one_hot(nu_indices, self._dimension) - tf.reduce_sum(
          self._gamma *
          tf.expand_dims(next_target_probs * policy_ratios, axis=-1) *
          tf.one_hot(next_nu_indices, self._dimension),
          axis=1)

      # linear term w.r.t. reward
      #b_vec_1 = tf.stack([
      #    tf.reduce_sum(
      #        tf.reshape(
      #            (gamma_data_weights[:, itr] /
      #             tf.reduce_sum(gamma_data_weights[:, itr])) * self._reward_fn(env_step), #/
      #            #tf.cast(state_action_count, tf.float32),
      #            [-1, 1]) * a_vec,
      #        axis=0) for itr in range(self._num_limits)
      #],
      #                   axis=0)
      # quadratic term of feature
      # Get weighted outer product by using einsum to save computing resource!
      #a_mat = tf.stack([
      #    tf.einsum(
      #        'ai, a, aj -> ij', a_vec,
      #        #1.0 / tf.cast(state_action_count, tf.float32),
      #        gamma_data_weights[:, itr] /
      #        tf.reduce_sum(gamma_data_weights[:, itr]),
      #        a_vec)
      #    for itr in range(self._num_limits)
      #],
      #                 axis=0)

      # td_mat[i, j] = sum over samples a of
      # one_hot(nu_indices)[a, i] * a_vec[a, j] / state_action_count[a].
      td_mat = tf.einsum('ai, a, aj -> ij',
                         tf.one_hot(nu_indices, self._dimension),
                         1.0 / tf.cast(state_action_count, tf.float32), a_vec)

      weighted_rewards = policy_ratio * self._reward_fn(env_step)

      # bias[i] = sum over samples with index i of
      # weighted reward / state_action_count.
      bias = tf.reduce_sum(
          tf.one_hot(nu_indices, self._dimension) *
          tf.reshape(weighted_rewards, [-1, 1]) * 1.0 /
          tf.cast(state_action_count, tf.float32)[:, None],
          axis=0)

      # Initialize: every column of nu and nu2 starts out equal to `bias`.
      self._nu = np.ones_like(self._nu) * bias[:, None]
      self._nu2 = np.ones_like(self._nu2) * bias[:, None]

      self._a_vec = a_vec
      self._td_mat = td_mat
      self._bias = bias
      self._weighted_rewards = weighted_rewards
      self._state_action_count = state_action_count
      self._nu_indices = nu_indices
      self._initial_nu_indices = initial_nu_indices
      self._initial_target_probs = initial_target_probs
      self._gamma_sample = gamma_sample
      # NOTE(review): the assignment above is overwritten immediately, so the
      # discount weights gamma ** step_num are never actually used -- confirm
      # this is intended.
      self._gamma_sample = tf.ones_like(gamma_sample)

    # Per-sample saddle-point loss evaluated at the current nu/zeta values.
    # The (nu, zeta) and (nu2, zeta2) terms enter with opposite signs.
    saddle_bellman_residuals = (
        tf.matmul(self._a_vec, self._nu) - self._weighted_rewards[:, None])
    saddle_bellman_residuals *= -1 * self._algae_alpha_sign
    saddle_zetas = tf.gather(self._zeta, self._nu_indices)
    saddle_initial_nu_values = tf.reduce_sum(  # Average over actions.
        self._initial_target_probs[:, :, None] *
        tf.gather(self._nu, self._initial_nu_indices),
        axis=1)
    saddle_init_nu_loss = ((1 - self._gamma) * saddle_initial_nu_values *
                           self._algae_alpha_sign)

    saddle_bellman_residuals2 = (
        tf.matmul(self._a_vec, self._nu2) - self._weighted_rewards[:, None])
    saddle_bellman_residuals2 *= 1 * self._algae_alpha_sign
    saddle_zetas2 = tf.gather(self._zeta2, self._nu_indices)
    saddle_initial_nu_values2 = tf.reduce_sum(  # Average over actions.
        self._initial_target_probs[:, :, None] *
        tf.gather(self._nu2, self._initial_nu_indices),
        axis=1)
    saddle_init_nu_loss2 = ((1 - self._gamma) * saddle_initial_nu_values2 * -1 *
                            self._algae_alpha_sign)

    saddle_loss = 0.5 * (
        saddle_init_nu_loss + saddle_bellman_residuals * saddle_zetas +
        -tf.math.abs(self._algae_alpha) * 0.5 * tf.square(saddle_zetas) +
        -saddle_init_nu_loss2 + -saddle_bellman_residuals2 * saddle_zetas2 +
        tf.math.abs(self._algae_alpha) * 0.5 * tf.square(saddle_zetas2))
    # Binary search to find best alpha: 16 bisection steps on [-8, 32]. The
    # bracket moves up wherever the divergence still exceeds the limit.
    # NOTE(review): the bounds are hard-coded with two entries, presumably one
    # per limit -- confirm against `self._num_limits`.
    left = tf.constant([-8., -8.])
    right = tf.constant([32., 32.])
    for _ in range(16):
      mid = 0.5 * (left + right)
      self._alpha.assign(mid)
      weights, log_weights = self._get_weights(saddle_loss *
                                               self._gamma_sample[:, None])

      divergence = self._compute_divergence(weights, log_weights)
      divergence_violation = divergence - self._two_sided_limit
      left = tf.where(divergence_violation > 0., mid, left)
      right = tf.where(divergence_violation > 0., right, mid)
    # Commit the midpoint of the final bracket and recompute the weights.
    self._alpha.assign(0.5 * (left + right))
    weights, log_weights = self._get_weights(saddle_loss *
                                             self._gamma_sample[:, None])

    # The weights are treated as constants when differentiating the nu losses.
    gamma_data_weights = tf.stop_gradient(weights * self._gamma_sample[:, None])
    #print(tf.concat([gamma_data_weights, saddle_loss], axis=-1))
    avg_saddle_loss = (
        tf.reduce_sum(gamma_data_weights * saddle_loss, axis=0) /
        tf.reduce_sum(gamma_data_weights, axis=0))

    # Sum of the sample weights per nu index (one column per weight column),
    # gathered back so that each sample sees the total for its own index.
    weighted_state_action_count = tf.reduce_sum(
        tf.one_hot(self._nu_indices, self._dimension)[:, :, None] *
        weights[:, None, :],
        axis=0)
    weighted_state_action_count = tf.gather(weighted_state_action_count,
                                            self._nu_indices)
    # Re-weighted versions of the cached `td_mat` / `bias`, with one matrix
    # and one bias vector per weight column (leading axis `b`).
    my_td_mat = tf.einsum(
        'ai, ab, ab, aj -> bij',
        tf.one_hot(self._nu_indices, self._dimension),
        #1.0 / tf.cast(self._state_action_count, tf.float32),
        1.0 / weighted_state_action_count,
        weights,
        self._a_vec)
    my_bias = tf.reduce_sum(
        tf.transpose(weights)[:, :, None] *
        tf.one_hot(self._nu_indices, self._dimension)[None, :, :] *
        tf.reshape(self._weighted_rewards, [1, -1, 1]) *
        #1.0 / tf.cast(self._state_action_count, tf.float32)[None, :, None],
        1.0 / tf.transpose(weighted_state_action_count)[:, :, None],
        axis=1)

    #print('hello', saddle_initial_nu_values[:1], saddle_zetas[:3],
    #      self._nu[:2], my_bias[:, :2], saddle_loss[:4])

    # The tape is persistent because it is queried several times: two
    # gradients inside the context and two Jacobians after it.
    with tf.GradientTape(
        watch_accessed_variables=False, persistent=True) as tape:
      tape.watch([self._nu, self._nu2, self._alpha])
      bellman_residuals = tf.matmul(
          my_td_mat,
          tf.transpose(self._nu)[:, :, None]) - my_bias[:, :, None]
      bellman_residuals = tf.transpose(tf.squeeze(bellman_residuals, -1))
      bellman_residuals = tf.gather(bellman_residuals, self._nu_indices)
      initial_nu_values = tf.reduce_sum(  # Average over actions.
          self._initial_target_probs[:, :, None] *
          tf.gather(self._nu, self._initial_nu_indices),
          axis=1)

      bellman_residuals *= self._algae_alpha_sign

      init_nu_loss = ((1 - self._gamma) * initial_nu_values *
                      self._algae_alpha_sign)

      nu_loss = (
          tf.math.square(bellman_residuals) / 2.0 +
          tf.math.abs(self._algae_alpha) * init_nu_loss)

      loss = (
          gamma_data_weights * nu_loss /
          tf.reduce_sum(gamma_data_weights, axis=0, keepdims=True))

      # Same construction for nu2, with the sign flipped.
      bellman_residuals2 = tf.matmul(
          my_td_mat,
          tf.transpose(self._nu2)[:, :, None]) - my_bias[:, :, None]
      bellman_residuals2 = tf.transpose(tf.squeeze(bellman_residuals2, -1))
      bellman_residuals2 = tf.gather(bellman_residuals2, self._nu_indices)
      initial_nu_values2 = tf.reduce_sum(  # Average over actions.
          self._initial_target_probs[:, :, None] *
          tf.gather(self._nu2, self._initial_nu_indices),
          axis=1)

      bellman_residuals2 *= -1 * self._algae_alpha_sign

      init_nu_loss2 = ((1 - self._gamma) * initial_nu_values2 * -1 *
                       self._algae_alpha_sign)

      nu_loss2 = (
          tf.math.square(bellman_residuals2) / 2.0 +
          tf.math.abs(self._algae_alpha) * init_nu_loss2)

      loss2 = (
          gamma_data_weights * nu_loss2 /
          tf.reduce_sum(gamma_data_weights, axis=0, keepdims=True))

      divergence = self._compute_divergence(weights, log_weights)
      divergence_violation = divergence - self._two_sided_limit

      # NOTE(review): `alpha_loss` is only referenced by the commented-out
      # optimizer code further down; it has no effect at present.
      alpha_loss = (-tf.exp(self._alpha) *
                    tf.stop_gradient(divergence_violation))

      # Quadratic penalty on the last row of nu / nu2.
      extra_loss = tf.reduce_sum(tf.math.square(self._nu[-1, :]))
      extra_loss2 = tf.reduce_sum(tf.math.square(self._nu2[-1, :]))
      # The gradients are taken inside the tape context so that the tape
      # records them and the Jacobians below can differentiate them again.
      nu_grad = tape.gradient(loss + extra_loss, [self._nu])[0]
      nu_grad2 = tape.gradient(loss2 + extra_loss2, [self._nu2])[0]
    avg_loss = tf.reduce_sum(
        0.5 * (loss - loss2) / tf.math.abs(self._algae_alpha), axis=0)
    # Hessian of each loss column w.r.t. the matching nu column: keep only the
    # [:, i, :, i] slices of the full Jacobian of the gradient.
    nu_jacob = tape.jacobian(nu_grad, [self._nu])[0]
    nu_hess = tf.stack([nu_jacob[:, i, :, i] for i in range(self._num_limits)],
                       axis=0)

    nu_jacob2 = tape.jacobian(nu_grad2, [self._nu2])[0]
    nu_hess2 = tf.stack(
        [nu_jacob2[:, i, :, i] for i in range(self._num_limits)], axis=0)

    for idx, div in enumerate(divergence):
      tf.summary.scalar('divergence%d' % idx, div)

    #alpha_grads = tape.gradient(alpha_loss, [self._alpha])
    #alpha_grad_op = self._alpha_optimizer.apply_gradients(
    #    zip(alpha_grads, [self._alpha]))
    #self._alpha.assign(tf.minimum(8., tf.maximum(-8., self._alpha)))

    #print(self._alpha, tf.concat([weights, nu_loss], -1))
    #regularizer = 0.1
    # Damped Newton step: solve (H + regularizer * I) d = -grad for each
    # column and move 0.1 * d.
    nu_transformed = tf.transpose(
        tf.squeeze(
            tf.linalg.solve(nu_hess + regularizer * tf.eye(self._dimension),
                            tf.expand_dims(-tf.transpose(nu_grad), axis=-1))))
    self._nu = self._nu + 0.1 * nu_transformed
    nu_transformed2 = tf.transpose(
        tf.squeeze(
            tf.linalg.solve(nu_hess2 + regularizer * tf.eye(self._dimension),
                            tf.expand_dims(-tf.transpose(nu_grad2), axis=-1))))
    self._nu2 = self._nu2 + 0.1 * nu_transformed2

    # Progress output, written to stdout on every call.
    print(avg_loss * self._algae_alpha_sign,
          avg_saddle_loss * self._algae_alpha_sign, self._nu[:2], divergence)
    #print(init_nu_loss[:8], init_nu_loss[-8:])
    #print(bellman_residuals[:8])
    #print(self._nu[:3], self._zeta[:3])

    # Zeta targets are the Bellman residuals under the freshly updated nu,
    # signed and divided by |algae_alpha|; zeta moves 10% of the way there.
    zetas = tf.matmul(my_td_mat,
                      tf.transpose(self._nu)[:, :, None]) - my_bias[:, :, None]
    zetas = tf.transpose(tf.squeeze(zetas, -1))
    zetas *= -self._algae_alpha_sign
    zetas /= tf.math.abs(self._algae_alpha)
    self._zeta = self._zeta + 0.1 * (zetas - self._zeta)

    zetas2 = tf.matmul(my_td_mat,
                       tf.transpose(self._nu2)[:, :, None]) - my_bias[:, :,
                                                                      None]
    zetas2 = tf.transpose(tf.squeeze(zetas2, -1))
    zetas2 *= 1 * self._algae_alpha_sign
    zetas2 /= tf.math.abs(self._algae_alpha)
    self._zeta2 = self._zeta2 + 0.1 * (zetas2 - self._zeta2)

    #self._zeta = (
    #    tf.einsum('ij,ja-> ia', self._td_mat, self._nu) -
    #    tf.transpose(my_bias))
    #self._zeta *= -tf.reshape(self._algae_alpha_sign, [1, self._num_limits])
    #self._zeta /= tf.math.abs(self._algae_alpha)
    return [
        avg_saddle_loss * self._algae_alpha_sign,
        avg_loss * self._algae_alpha_sign, divergence
    ]
Example #24
0
def chees_criterion(previous_state,
                    proposed_state,
                    accept_prob,
                    validate_args=False,
                    experimental_shard_axis_names=None,
                    experimental_chain_axis_names=None):
    """Evaluates the ChEES criterion from [1].

    ChEES is short for Change in the Estimator of the Expected Square:

    ```None
    ChEES = 1/4 E[(||x' - E[x]||**2 - ||x - E[x]||**2)**2],
    ```

    with `x` the previous chain state, `x'` the next chain state and `||.||`
    the L2 norm. Both expectations are taken under the chain's stationary
    distribution. The inner expectation is replaced here by an empirical mean
    across chains, which is why at least 2 chains are required. The outer
    expectation is left to the caller (e.g. the
    `GradientBasedTrajectoryLengthAdaptation` kernel).

    The criterion can be read as the usual expected squared jump distance
    (ESJD), with the jump measured in the space of centered squared L2 norms.
    Regular ESJD is maximized by perfectly anticorrelated proposals, which can
    give excellent mean estimates but terrible variance estimates; maximizing
    ChEES should give good estimates for a wider range of posterior
    expectations.

    Args:
        previous_state: (Possibly nested) floating point `Tensor`. The previous
            state of the HMC chain.
        proposed_state: (Possibly nested) floating point `Tensor`. The proposed
            state of the HMC chain.
        accept_prob: Floating `Tensor`. Probability of accepting the proposed
            state.
        validate_args: Whether to perform non-static argument validation.
        experimental_shard_axis_names: A structure of string names indicating
            how members of the state are sharded.
        experimental_chain_axis_names: A string or list of string names
            indicating how batches of chains are sharded.

    Returns:
        chees: The value of the ChEES criterion.

    Raises:
        ValueError: If `accept_prob` indicates that there are fewer than 2
            chains.

    #### References

    [1]: Hoffman, M., Radul, A., & Sountsov, P. (2020). An Adaptive MCMC Scheme
         for Setting Trajectory Lengths in Hamiltonian Monte Carlo. In
         preparation.

    """
    num_batch_dims = ps.rank(accept_prob)
    chain_axes = ps.range(num_batch_dims, dtype=tf.int32)
    experimental_chain_axis_names = distribute_lib.canonicalize_named_axis(
        experimental_chain_axis_names)

    # Total chain count = chains held locally * product of the sizes of the
    # named (distributed) chain axes.
    local_chains = ps.maximum(ps.size(accept_prob), 1)
    named_axis_sizes = [
        distribute_lib.get_axis_size(axis_name)
        for axis_name in experimental_chain_axis_names
    ]
    total_chains = local_chains * int(ps.reduce_prod(named_axis_sizes))

    static_total_chains = tf.get_static_value(total_chains)
    if static_total_chains is None:
        # Count only known at run time: optionally attach a runtime check to
        # the previous state so that it gates everything computed from it.
        if validate_args:
            enough_chains = assert_util.assert_greater_equal(
                total_chains, 2,
                'chees_criterion requires at least 2 chains.')
            with tf.control_dependencies([enough_chains]):
                previous_state = tf.nest.map_structure(
                    tf.identity, previous_state)
    elif static_total_chains < 2:
        raise ValueError(
            'chees_criterion requires at least 2 chains. Got: {}'.format(
                static_total_chains))

    def _subtract_chain_mean(x):
        # The empirical mean stands in for the true mean, so no gradient is
        # allowed to flow through it.
        chain_mean = _reduce_mean_with_axes(x, chain_axes,
                                            experimental_chain_axis_names)
        return x - tf.stop_gradient(chain_mean)

    def _subtract_accept_weighted_mean(x):
        # Centers on the acceptance-weighted mean across chains, again without
        # gradient. The aim is a reliable diagnostic of the underlying
        # dynamics rather than one that incorporates the effect of the
        # Metropolis-Hastings correction.
        # TODO(mhoffman): Needs more experimentation.
        chain_weights = bu.left_justified_expand_dims_like(accept_prob, x)

        # accept_prob is zero wherever x is NaN, but those entries still have
        # to be sanitized before they are multiplied.
        finite_x = tf.where(tf.math.is_finite(x), x, tf.zeros_like(x))
        weighted_total = _reduce_sum_with_axes(chain_weights * finite_x,
                                               chain_axes,
                                               experimental_chain_axis_names)
        weight_total = _reduce_sum_with_axes(chain_weights, chain_axes,
                                             experimental_chain_axis_names)
        # If every accept_prob is zero the center is a nonsense value, but the
        # resulting gradients are discarded later on, so that is fine.
        weighted_mean = weighted_total / (weight_total + 1e-20)
        return x - tf.stop_gradient(weighted_mean)

    def _reduce_event_dims(x, shard_axes=None):
        # Sums out every non-batch dimension, including across shards.
        event_axes = ps.range(num_batch_dims, ps.rank(x))
        local_total = tf.reduce_sum(x, axis=event_axes)
        return distribute_lib.psum(local_total, shard_axes)

    centered_previous = tf.nest.map_structure(_subtract_chain_mean,
                                              previous_state)
    centered_proposed = tf.nest.map_structure(_subtract_accept_weighted_mean,
                                              proposed_state)

    proposed_squares = tf.nest.map_structure(tf.square, centered_proposed)
    previous_squares = tf.nest.map_structure(tf.square, centered_previous)
    square_change = tf.nest.map_structure(
        lambda new, old: new - old, proposed_squares, previous_squares)

    # Change in squared centered norm per chain, summed over all state parts.
    per_part_totals = _map_structure_up_to_with_axes(
        square_change,
        _reduce_event_dims,
        square_change,
        experimental_shard_axis_names=experimental_shard_axis_names)
    norm_change = sum(tf.nest.flatten(per_part_totals))
    return 0.25 * tf.square(norm_change)
Example #25
0
    def _get_mean_and_variance(self, x):
        """Computes batch mean/variance of `x`, averaged across replicas.

        When called inside a replica context the per-replica moments are
        summed with `cross_replica_sum` and divided by the group size. With
        more than 8 replicas in sync, replicas are split into groups of
        `max(8, num_replicas_in_sync // 8)`; otherwise all replicas form one
        group. The moving mean and variance are updated as a side effect.

        Only the NHWC layout is supported. With `self.ensemble_size > 1` the
        leading axis is split into `[ensemble_size, -1]` and the moments are
        computed per ensemble member.

        Args:
            x: Input tensor, assumed to be laid out as NHWC.

        Returns:
            A `(mean, variance)` tuple, both cast to the dtype of `x`.
        """
        replica_ctx = tf.distribute.get_replica_context()
        cross_replica = replica_ctx is not None

        if cross_replica:
            total_replicas = replica_ctx.num_replicas_in_sync
            if total_replicas > 8:
                group_size = max(8, total_replicas // 8)
                replica_ids = np.arange(total_replicas, dtype=np.int32)
                group_assignment = replica_ids.reshape(
                    [-1, group_size]).tolist()
            else:
                # Few enough replicas: reduce over all of them at once.
                group_size = total_replicas
                group_assignment = None
            num_replicas_per_group = tf.cast(group_size, tf.float32)

        # This only supports NHWC format.
        if self.ensemble_size > 1:
            height = tf.shape(x)[1]
            width = tf.shape(x)[2]
            channels = tf.shape(x)[3]
            x = tf.reshape(
                x, [self.ensemble_size, -1, height, width, channels])
            # Keep the ensemble axis: result is [ensemble_size, channels].
            reduce_axes = [1, 2, 3]
        else:
            reduce_axes = [0, 1, 2]

        # Var[x] = E[x^2] - E[x]^2, with both expectations taken locally first.
        mean = tf.cast(tf.reduce_mean(x, axis=reduce_axes), tf.float32)
        mean_sq = tf.cast(
            tf.reduce_mean(tf.square(x), axis=reduce_axes), tf.float32)
        if cross_replica:
            mean = (tf1.tpu.cross_replica_sum(mean, group_assignment) /
                    num_replicas_per_group)
            mean_sq = (tf1.tpu.cross_replica_sum(mean_sq, group_assignment) /
                       num_replicas_per_group)
        variance = mean_sq - tf.square(mean)

        # moving <- moving - (1 - momentum) * (moving - batch_value)
        for moving, batch_value in ((self.moving_mean, mean),
                                    (self.moving_variance, variance)):
            step = tf.cast(1. - self.momentum, tf.float32)
            gap = (tf.cast(moving, tf.float32) -
                   tf.cast(batch_value, tf.float32))
            self.add_update(moving.assign_sub(step * gap))

        return tf.cast(mean, x.dtype), tf.cast(variance, x.dtype)
Example #26
0
 def testMALAIsCalibrated(self):
     """Checks that a MALA kernel reports `is_calibrated` as True."""

     def unnormalized_normal_log_prob(x):
         # -x**2 / 2: log-density of a standard normal up to a constant.
         return -tf.square(x) / 2.

     kernel = tfp.mcmc.MetropolisAdjustedLangevinAlgorithm(
         target_log_prob_fn=unnormalized_normal_log_prob,
         step_size=0.1,
     )
     self.assertTrue(kernel.is_calibrated)
Example #27
0
    def estimate_reward_ci(self,
                           dataset: dataset_lib.OffpolicyDataset,
                           target_policy: tf_policy.TFPolicy,
                           episode_limit: Optional[int] = None,
                           num_grid: Optional[int] = 100,
                           eps: Optional[float] = 1e-6,
                           num_bootstraps: Optional[int] = 10000,
                           num_bootstrap_samples: Optional[int] = 10000):
        """Estimate the confidence interval of reward.

        The interval type is selected by `self._ci_method`:
          'CH'   Chernoff-Hoeffding,
          'BE'   empirical Bernstein,
          'C-BE' clipped empirical Bernstein (clipping constant picked on a grid),
          'TT'   two-tailed Student-t,
          'BCa'  bias-corrected and accelerated bootstrap.
        Each tail is given probability `self._delta_tail / 2`.

        Args:
          dataset: Off-policy dataset; forwarded to
            `get_is_weighted_reward_samples`, `estimate_average_reward` and
            queried with `get_all_episodes`.
          target_policy: Policy forwarded to the same estimators.
          episode_limit: Passed as `limit` to `dataset.get_all_episodes` and to
            `get_is_weighted_reward_samples`.
          num_grid: Number of candidate clipping constants tried for 'C-BE'.
          eps: Lower end of the 'C-BE' clipping grid; also the floor applied
            to denominators in the 'BCa' computation.
          num_bootstraps: Number of bootstrap resamples drawn for 'BCa'.
          num_bootstrap_samples: Number of draws per bootstrap resample.

        Returns:
          A two-element list `[lb, ub]` holding the lower and upper bound.

        Raises:
          ValueError: If `self._ci_method` is not one of the supported methods
            (or an unknown objective is requested inside the 'C-BE' search).
        """
        is_weighted_reward_samples = self.get_is_weighted_reward_samples(
            dataset, target_policy, episode_limit)
        episodes, valid_steps = dataset.get_all_episodes(limit=episode_limit)
        num_episodes = tf.shape(valid_steps)[0]
        # Largest absolute reward over valid steps; invalid steps count as 0.
        max_abs_reward = tf.reduce_max(
            tf.where(valid_steps, tf.abs(self._reward_fn(episodes)), 0.))

        # mean estimate
        center = self.estimate_average_reward(dataset, target_policy)
        delta_tail_half = self._delta_tail / 2.0
        num_episodes_float = tf.cast(num_episodes, tf.float32)

        if self._ci_method == 'CH':  # Chernoff-Hoeffding
            width = max_abs_reward * tf.math.sqrt(
                tf.math.log(1.0 / delta_tail_half) / num_episodes_float)
            lb = center - width
            ub = center + width
        elif self._ci_method == 'BE':  # Empirical Bernstein
            constant_term = 7 * max_abs_reward * tf.math.log(
                2.0 / delta_tail_half) / (3 * (num_episodes_float - 1))
            is_weighted_reward_samples_2d = tf.reshape(
                is_weighted_reward_samples, [-1, 1])
            # Sum of squared pairwise differences between all samples.
            variance_term = tf.reduce_sum(
                tf.square(
                    tf.tile(is_weighted_reward_samples_2d, [1, num_episodes]) -
                    is_weighted_reward_samples_2d))

            variance_term *= tf.math.log(
                2.0 / delta_tail_half) / (num_episodes_float - 1)
            width = constant_term + tf.math.sqrt(
                variance_term) / num_episodes_float
            lb = center - width
            ub = center + width
        elif self._ci_method == 'C-BE':  # Clipped empirical Bernstein
            # need to learn c
            def compute_center_width(c_const):
                """Compute the center and width of CI."""
                c_vec = c_const * tf.ones_like(is_weighted_reward_samples)
                # Clip each sample at c, then rescale by c.
                c_is_weighted_reward_samples = tf.minimum(
                    is_weighted_reward_samples, c_vec) / c_vec
                c_is_weighted_reward_samples_2d = tf.reshape(
                    c_is_weighted_reward_samples, [-1, 1])
                constant_term = 7 * num_episodes_float * tf.math.log(
                    2.0 / delta_tail_half) / (3 * (num_episodes_float - 1))

                # Sum of squared pairwise differences of the clipped samples.
                variance_term = tf.reduce_sum(
                    tf.square(
                        tf.tile(c_is_weighted_reward_samples_2d,
                                [1, num_episodes]) -
                        c_is_weighted_reward_samples_2d))
                variance_term *= tf.math.log(
                    2.0 / delta_tail_half) / (num_episodes_float - 1)

                width = (constant_term + tf.math.sqrt(variance_term)
                         ) / tf.reduce_sum(1.0 / c_vec)
                center = tf.reduce_sum(
                    c_is_weighted_reward_samples) / tf.reduce_sum(1.0 / c_vec)
                return center, width

            def compute_bdd(c_const):
                """Return the (lower, upper) bounds for clipping constant c."""
                center, width = compute_center_width(c_const)
                return center - width, center + width

            def compute_obj(c_const, obj='width'):
                """Return the objective value used to pick the clipping constant."""
                center, width = compute_center_width(c_const)
                if obj == 'lb':
                    return center - width
                elif obj == 'ub':  # minimize ub
                    return -(center + width)
                elif obj == 'width':
                    return width
                elif obj == 'lb_ub':
                    return -2 * width
                else:
                    # The exception used to be built but never raised, so an
                    # unknown objective silently produced None.
                    raise ValueError('Objective is not implemented')

            c_grid = tf.linspace(eps, max_abs_reward, num_grid)
            objs = tf.map_fn(compute_obj, c_grid, dtype=tf.float32)

            # NOTE(review): with the default 'width' objective, argmax selects
            # the grid point with the largest width -- confirm this is intended.
            star_index = tf.argmax(objs)
            c_star = tf.gather(c_grid, star_index)

            lb, ub = compute_bdd(c_star)

        elif self._ci_method == 'TT':  # Student-t test
            # Two-tailed confidence intervals
            t_statistic_quantile = stats.t.ppf(1 - delta_tail_half,
                                               num_episodes_float - 1)
            std_term = tf.math.sqrt(
                tf.reduce_sum(tf.square(is_weighted_reward_samples - center)) /
                (num_episodes_float - 1))
            width = t_statistic_quantile * std_term / tf.math.sqrt(
                num_episodes_float)
            lb = center - width
            ub = center + width
        elif self._ci_method == 'BCa':  # Bootstrap
            # see references
            # https://faculty.washington.edu/heagerty/Courses/b572/public/GregImholte-1.pdf
            # http://users.stat.umn.edu/~helwig/notes/bootci-Notes.pdf
            gaussian_rv = tfp.distributions.Normal(loc=0, scale=1)

            def _compute_bootstrap_lb_ub(reward_samples):
                """Compute Efron's bootstrap lb."""
                sample_mean = tf.reduce_mean(reward_samples)
                # Step 1, sample with replacement and compute subsampled mean
                # All-zero logits make every episode equally likely.
                uniform_log_prob = tf.tile(
                    tf.expand_dims(tf.zeros(num_episodes), 0),
                    [num_bootstraps, 1])
                ind = tf.random.categorical(uniform_log_prob,
                                            num_bootstrap_samples)
                bootstrap_subsamples = tf.gather(reward_samples, ind)
                subsample_means = tf.reduce_mean(bootstrap_subsamples, axis=1)

                # Step 2, sort subsample means, compute y, z_0, and a
                sorted_subsample_means = tf.sort(subsample_means,
                                                 axis=0,
                                                 direction='ASCENDING')

                # bias factor
                z_0 = gaussian_rv.quantile(
                    tf.reduce_sum(
                        tf.cast(
                            tf.greater(sample_mean, sorted_subsample_means),
                            tf.float32)) / float(num_bootstraps))
                # y is the leave-one-out, jackknife sample mean
                mask_matrix = tf.ones([num_episodes, num_episodes
                                       ]) - tf.eye(num_episodes)
                leave_one_out_subsample_sums = tf.einsum(
                    'j,jk->k', reward_samples, mask_matrix)
                ys = leave_one_out_subsample_sums / (num_episodes_float - 1)

                # average of jackknife estimate
                y_bar = tf.reduce_mean(ys)

                # acceleration factor
                d_ys = y_bar - ys
                a = tf.reduce_sum(tf.pow(d_ys, 3.0)) / tf.maximum(
                    eps, 6.0 * tf.pow(tf.reduce_sum(tf.pow(d_ys, 2.0)), 1.5))

                # Step 3, compute z_scores for lb and ub
                z_score_delta_tail = gaussian_rv.quantile(delta_tail_half)
                z_score_1_delta_tail = gaussian_rv.quantile(1.0 -
                                                            delta_tail_half)

                z_lb = z_0 + (z_score_delta_tail + z_0) / tf.maximum(
                    eps, 1 - a * (z_score_delta_tail + z_0))
                z_ub = z_0 + (z_score_1_delta_tail + z_0) / tf.maximum(
                    eps, 1 - a * (z_score_1_delta_tail + z_0))

                # Step 4, compute corresponding quantiles and get bootstrap intervals
                # Indices are clamped to the range [1, num_bootstraps - 1].
                lb_index = tf.cast(
                    tf.maximum(
                        tf.minimum(
                            tf.floor(num_bootstraps * gaussian_rv.cdf(z_lb)),
                            num_bootstraps - 1), 1), tf.int64)
                ub_index = tf.cast(
                    tf.maximum(
                        tf.minimum(
                            tf.floor(num_bootstraps * gaussian_rv.cdf(z_ub)),
                            num_bootstraps - 1), 1), tf.int64)

                lb = tf.gather(sorted_subsample_means, lb_index)
                ub = tf.gather(sorted_subsample_means, ub_index)

                return lb, ub

            lb, ub = _compute_bootstrap_lb_ub(is_weighted_reward_samples)
        else:
            # Previously the exception was constructed but not raised, so the
            # return below failed with an unrelated UnboundLocalError.
            raise ValueError('Confidence interval is not implemented!')
        return [lb, ub]
Example #28
0
 def testUncalibratedLangevinIsNotCalibrated(self):
     """Verify that an UncalibratedLangevin kernel reports `is_calibrated` as False."""

     def log_prob(x):
         # Target log-density: -x^2 / 2.
         return -tf.square(x) / 2.

     kernel = tfp.mcmc.UncalibratedLangevin(
         target_log_prob_fn=log_prob, step_size=0.1)
     self.assertFalse(kernel.is_calibrated)
 def testRWMIsCalibrated(self):
     """Verify that a RandomWalkMetropolis kernel reports `is_calibrated` as True."""

     def log_prob(x):
         # Target log-density: -x^2 / 2.
         return -tf.square(x) / 2.

     kernel = tfp.mcmc.RandomWalkMetropolis(target_log_prob_fn=log_prob)
     self.assertTrue(kernel.is_calibrated)
Example #30
0
def main():
    """Fetch the Inception5h model, load its graph and run DeepDream on 'image.jpg'.

    Steps:
      1. Download 'inception5h.zip' from `url` and extract it into `data_dir`
         unless the archive already exists in the working directory.
      2. Load the extracted frozen graph into a fresh TensorFlow graph whose
         'input' node is fed by a mean-subtracted placeholder.
      3. Print the number of imported Conv2D layers and their total channels.
      4. Run `render_deepdream` on 'image.jpg', maximising the squared
         activations of the tensor returned by `T('mixed4c')`.

    NOTE(review): `url`, `data_dir`, `img_noise`, `resize`, `calc_grad_tiled`,
    `showarray` and `T` are not defined in this function; they are presumably
    module-level names -- confirm. `tfc` is assumed to be
    `tensorflow.compat.v1`.
    """
    # Download Google's pre-trained neural network.
    local_zip_file = 'inception5h.zip'
    if not os.path.exists(local_zip_file):
        # Download; both the HTTP response and the output file are closed
        # deterministically.
        with urllib.request.urlopen(url) as response, \
                open(local_zip_file, 'wb') as output:
            output.write(response.read())

        # Extract
        with zipfile.ZipFile(local_zip_file, 'r') as zip_ref:
            zip_ref.extractall(data_dir)
    # Name of the frozen graph shipped inside the archive.
    model_fn = 'tensorflow_inception_graph.pb'

    # Create the tf session and load the model.
    graph = tf.Graph()
    sess = tfc.InteractiveSession(graph=graph)
    # Parse the extracted protobuf (not the zip archive) into a GraphDef.
    with tfc.gfile.FastGFile(os.path.join(data_dir, model_fn), 'rb') as f:
        graph_def = tfc.GraphDef()
        graph_def.ParseFromString(f.read())
    t_input = tfc.placeholder(np.float32, name='input')  # define input tensor
    imagenet_mean = 117.0
    # Subtract the mean and add a leading batch dimension.
    t_preprocessed = tf.expand_dims(t_input - imagenet_mean, 0)
    tf.import_graph_def(graph_def, {'input': t_preprocessed})

    # Names of all imported convolution ops.
    layers = [
        op.name for op in graph.get_operations()
        if op.type == 'Conv2D' and 'import/' in op.name
    ]
    # Size of the last dimension of each convolution's output.
    feature_nums = [
        int(graph.get_tensor_by_name(name + ':0').get_shape()[-1])
        for name in layers
    ]

    print('Number of layers: ', len(layers))
    print('Total numbers of feature channels:', sum(feature_nums))

    def render_deepdream(t_obj,
                         img0=img_noise,
                         iter_n=10,
                         step=1.5,
                         octave_n=4,
                         octave_scale=1.4):
        """Run gradient ascent on `t_obj` over several image octaves.

        Args:
          t_obj: Tensor whose mean is maximised.
          img0: Starting image array.
          iter_n: Gradient-ascent iterations per octave.
          step: Step size applied to the normalised gradient.
          octave_n: Number of octaves (image scales).
          octave_scale: Size ratio between consecutive octaves.
        """
        t_score = tf.reduce_mean(t_obj)  # defining optimization objective
        t_grad = tf.gradients(t_score, t_input)[0]

        # Split the image into a number of octaves.
        img = img0
        octaves = []
        for _ in range(octave_n - 1):
            hw = img.shape[:2]
            lo = resize(img, np.int32(np.float32(hw) / octave_scale))
            # Detail lost by downscaling: image minus the upscaled low-res copy.
            hi = img - resize(lo, hw)
            img = lo
            octaves.append(hi)

        # Generate details octave by octave.
        for octave in range(octave_n):
            if octave > 0:
                # Upscale and add back the detail stored for this scale.
                hi = octaves[-octave]
                img = resize(img, hi.shape[:2]) + hi
            for _ in range(iter_n):
                g = calc_grad_tiled(img, t_grad)
                # Normalise by the mean absolute gradient; 1e-7 avoids
                # division by zero.
                img += g * (step / (np.abs(g).mean() + 1e-7))
            # Output the deep-dreamed image.
            showarray(img / 255.0)

    # Load the image to enhance as a float32 array, closing the file afterwards.
    with PIL.Image.open('image.jpg') as image:
        img0 = np.float32(image)

    # Apply gradient ascent to the layer.
    render_deepdream(tf.square(T('mixed4c')), img0)