def _update_trajectory_grad(previous_kernel_results, previous_state,
                            proposed_state, proposed_velocity,
                            trajectory_jitter, accept_prob, step_size,
                            criterion_fn, max_leapfrog_steps):
  """Updates the trajectory length."""
  # Compute criterion grads.
  def leapfrog_action(dt):
    # This represents the effect on the criterion value as the state follows the
    # proposed velocity. This implicitly assumes an identity mass matrix.
    return criterion_fn(
        previous_state,
        tf.nest.map_structure(
            lambda x, v:  # pylint: disable=g-long-lambda
            (x + mcmc_util.left_justified_expand_dims_like(dt, v) * v),
            proposed_state,
            proposed_velocity),
        accept_prob)

  criterion, trajectory_grad = gradient.value_and_gradient(
      leapfrog_action, tf.zeros_like(accept_prob))
  trajectory_grad *= trajectory_jitter

  # Weight by acceptance probability.
  trajectory_grad = tf.where(accept_prob > 1e-4, trajectory_grad, 0.)
  trajectory_grad = tf.where(
      tf.math.is_finite(trajectory_grad), trajectory_grad, 0.)
  trajectory_grad = (
      tf.reduce_sum(trajectory_grad * accept_prob) /
      tf.reduce_sum(accept_prob + 1e-20))

  # Compute Adam/RMSProp step size.
  dtype = previous_kernel_results.adaptation_rate.dtype
  iteration_f = tf.cast(previous_kernel_results.step, dtype) + 1.
  msg_adaptation_rate = 0.05
  new_averaged_sq_grad = (
      (1 - msg_adaptation_rate) * previous_kernel_results.averaged_sq_grad +
      msg_adaptation_rate * trajectory_grad**2)
  adjusted_averaged_sq_grad = new_averaged_sq_grad / (
      1. - (1 - msg_adaptation_rate)**iteration_f)
  trajectory_step_size = (
      previous_kernel_results.adaptation_rate /
      tf.sqrt(adjusted_averaged_sq_grad + 1e-20))

  # Apply the gradient. Clip absolute value to ~log(2)/2.
  log_update = tf.clip_by_value(trajectory_step_size * trajectory_grad, -0.35,
                                0.35)
  new_max_trajectory_length = previous_kernel_results.max_trajectory_length * tf.exp(
      log_update)

  # Iterate averaging.
  average_weight = iteration_f**(-0.5)
  new_averaged_max_trajectory_length = tf.exp(
      average_weight * tf.math.log(new_max_trajectory_length) +
      (1 - average_weight) *
      tf.math.log(1e-10 +
                  previous_kernel_results.averaged_max_trajectory_length))

  # Clip the maximum trajectory length.
  new_max_trajectory_length = _clip_max_trajectory_length(
      new_max_trajectory_length, step_size,
      previous_kernel_results.adaptation_rate, max_leapfrog_steps)

  return previous_kernel_results._replace(
      criterion=criterion,
      max_trajectory_length=new_max_trajectory_length,
      averaged_sq_grad=new_averaged_sq_grad,
      averaged_max_trajectory_length=new_averaged_max_trajectory_length)
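A minimal numeric sketch of the clipped log-space update above (hypothetical values): clipping the step to roughly ±log(2)/2 means a single adaptation step can at most halve or double the trajectory length.

import tensorflow as tf

max_trajectory_length = tf.constant(1.5)
trajectory_step_size = tf.constant(0.1)
trajectory_grad = tf.constant(8.0)          # large gradient, so the clip engages

log_update = tf.clip_by_value(trajectory_step_size * trajectory_grad, -0.35, 0.35)
new_max_trajectory_length = max_trajectory_length * tf.exp(log_update)  # ~1.5 * e**0.35 ≈ 2.13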
Example #2
def quantized_distributions(draw,
                            batch_shape=None,
                            event_dim=None,
                            enable_vars=False,
                            eligibility_filter=lambda name: True,
                            validate_args=True):
    """Strategy for drawing `QuantizedDistribution`s.

  The underlying distribution is drawn from the `base_distributions` strategy.

  Args:
    draw: Hypothesis strategy sampler supplied by `@hps.composite`.
    batch_shape: An optional `TensorShape`.  The batch shape of the resulting
      `QuantizedDistribution`. Hypothesis will pick a `batch_shape` if omitted.
    event_dim: Optional Python int giving the size of each of the underlying
      distribution's parameters' event dimensions.  This is shared across all
      parameters, permitting square event matrices, compatible location and
      scale Tensors, etc. If omitted, Hypothesis will choose one.
    enable_vars: TODO(bjp): Make this `True` all the time and put variable
      initialization in slicing_test.  If `False`, the returned parameters are
      all Tensors, never Variables or DeferredTensor.
    eligibility_filter: Optional Python callable.  Blacklists some Distribution
      class names so they will not be drawn.
    validate_args: Python `bool`; whether to enable runtime assertions.

  Returns:
    dists: A strategy for drawing `QuantizedDistribution`s with the specified
      `batch_shape` (or an arbitrary one if omitted).
  """

    if batch_shape is None:
        batch_shape = draw(tfp_hps.shapes())

    low_quantile = draw(
        hps.one_of(hps.just(None), hps.floats(min_value=0.01, max_value=0.7)))
    high_quantile = draw(
        hps.one_of(hps.just(None), hps.floats(min_value=0.3, max_value=.99)))

    def ok(name):
        return eligibility_filter(name) and name in QUANTIZED_BASE_DISTS

    underlyings = base_distributions(
        batch_shape=batch_shape,
        event_dim=event_dim,
        enable_vars=enable_vars,
        eligibility_filter=ok,
    )
    underlying = draw(underlyings)

    if high_quantile is not None:
        high_quantile = tf.convert_to_tensor(high_quantile,
                                             dtype=underlying.dtype)
    if low_quantile is not None:
        low_quantile = tf.convert_to_tensor(low_quantile,
                                            dtype=underlying.dtype)
        if high_quantile is not None:
            high_quantile = ensure_high_gt_low(low_quantile, high_quantile)

    hp.note('Drawing QuantizedDistribution with underlying distribution'
            ' {}'.format(underlying))

    try:
        low = None if low_quantile is None else underlying.quantile(
            low_quantile)
        high = None if high_quantile is None else underlying.quantile(
            high_quantile)
    except NotImplementedError:
        # The following code makes ReproducibilityTest flaky in graph mode (but not
        # eager). Failures are due either to partial mismatch in the samples in
        # ReproducibilityTest or to `low` and/or `high` being NaN. For now, to avoid
        # this, we set `low` and `high` to `None` for distributions not implementing
        # `quantile`.

        # seed = test_util.test_seed(hardcoded_seed=123)
        # low = (None if low_quantile is None
        #        else underlying.sample(low_quantile.shape, seed=seed))
        # high = (None if high_quantile is None else
        #         underlying.sample(high_quantile.shape, seed=seed))
        low = None
        high = None

    # Ensure that `low` and `high` are ints contained in distribution support
    # and span at least a few bins.
    if high is not None:
        high = tf.clip_by_value(high, -2**23, 2**23)
        high = tf.math.ceil(high + 5.)

    if low is not None:
        low = tf.clip_by_value(low, -2**23, 2**23)
        low = tf.math.ceil(low)

    result_dist = tfd.QuantizedDistribution(distribution=underlying,
                                            low=low,
                                            high=high,
                                            validate_args=validate_args)

    return result_dist
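A hedged usage sketch, assuming the surrounding test module's imports (`hypothesis as hp`, `hypothesis.strategies as hps`, `tf`) and the strategy defined above; the test name and body are hypothetical.

@hp.given(hps.data())
@hp.settings(deadline=None, max_examples=3)
def test_quantized_distribution_batch_shape(data):
    dist = data.draw(quantized_distributions(batch_shape=tf.TensorShape([2, 1])))
    assert dist.batch_shape == [2, 1]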
Example #3
    def _sample_channels(self,
                         component_logits,
                         locs,
                         scales,
                         coeffs=None,
                         seed=None):
        """Sample a single pixel-iteration and apply channel conditioning.

    Args:
      component_logits: 4D `Tensor` of logits for the Categorical distribution
        over Quantized Logistic mixture components. Dimensions are `[batch_size,
        height, width, num_logistic_mix]`.
      locs: 5D `Tensor` of location parameters for the Quantized Logistic
        mixture components. Dimensions are `[batch_size, height, width,
        num_logistic_mix, num_channels]`.
      scales: 5D `Tensor` of scale parameters for the Quantized Logistic
        mixture components. Dimensions are `[batch_size, height, width,
        num_logistic_mix, num_channels]`.
      coeffs: 5D `Tensor` of coefficients for the linear dependence among color
        channels, or `None` if there is only one channel. Dimensions are
        `[batch_size, height, width, num_logistic_mix, num_coeffs]`, where
        `num_coeffs = num_channels * (num_channels - 1) // 2`.
      seed: `int`, random seed.

    Returns:
      samples: 4D `Tensor` of sampled image data with autoregression among
        channels. Dimensions are `[batch_size, height, width, num_channels]`.
    """
        num_channels = self.event_shape[-1]

        # sample mixture components once for the entire pixel
        component_dist = categorical.Categorical(logits=component_logits)
        mask = tf.one_hot(indices=component_dist.sample(seed=seed),
                          depth=self._num_logistic_mix)
        mask = tf.cast(mask[..., tf.newaxis], self.dtype)

        # apply mixture component mask and separate out RGB parameters
        masked_locs = tf.reduce_sum(locs * mask, axis=-2)
        loc_tensors = tf.split(masked_locs, num_channels, axis=-1)
        masked_scales = tf.reduce_sum(scales * mask, axis=-2)
        scale_tensors = tf.split(masked_scales, num_channels, axis=-1)

        if coeffs is not None:
            num_coeffs = num_channels * (num_channels - 1) // 2
            masked_coeffs = tf.reduce_sum(coeffs * mask, axis=-2)
            coef_tensors = tf.split(masked_coeffs, num_coeffs, axis=-1)

        channel_samples = []
        coef_count = 0
        for i in range(num_channels):
            loc = loc_tensors[i]
            for c in channel_samples:
                loc += c * coef_tensors[coef_count]
                coef_count += 1

            logistic_samp = logistic.Logistic(
                loc=loc, scale=scale_tensors[i]).sample(seed=seed)
            logistic_samp = tf.clip_by_value(logistic_samp, -1., 1.)
            channel_samples.append(logistic_samp)

        return tf.concat(channel_samples, axis=-1)
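A minimal numeric sketch of the channel-conditioning loop above, with hypothetical shapes and a Normal sample standing in for the Logistic, purely for illustration:

import tensorflow as tf

num_channels = 3
loc_tensors = [tf.zeros([1, 4, 4, 1]) for _ in range(num_channels)]
scale_tensors = [0.1 * tf.ones([1, 4, 4, 1]) for _ in range(num_channels)]
coef_tensors = [0.5 * tf.ones([1, 4, 4, 1]) for _ in range(3)]  # num_coeffs = 3 * (3 - 1) // 2

channel_samples = []
coef_count = 0
for i in range(num_channels):
    loc = loc_tensors[i]
    for c in channel_samples:                 # shift loc by previously sampled channels
        loc = loc + c * coef_tensors[coef_count]
        coef_count += 1
    samp = loc + scale_tensors[i] * tf.random.normal(tf.shape(loc))
    channel_samples.append(tf.clip_by_value(samp, -1., 1.))
image = tf.concat(channel_samples, axis=-1)   # [1, 4, 4, 3]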
Example #4
    def GP_train(self, amp_init=None, len_init=None, num_iters=None):

        def build_gp(amplitude, length_scale):

          kernel = tfk.ExponentiatedQuadratic(amplitude, length_scale)

          return tfd.GaussianProcess(
              kernel=kernel,
              index_points=self.obs_ind) # ,jitter=1e-03

        gp_joint_model = tfd.JointDistributionNamed({
            'amplitude': tfd.LogNormal(loc=amp_init, scale=np.float64(1.)),
            'length_scale': tfd.LogNormal(loc=len_init, scale=np.float64(1.)),
            'observations': build_gp,
        })

        amplitude_ = tf.Variable(
            initial_value=amp_init,
            name='amplitude_',
            dtype=np.float64,
            constraint=lambda z: tf.clip_by_value(z, 1e-4, 10000))
        length_scale_ = tf.Variable(
            initial_value=len_init,
            name='length_scale_',
            dtype=np.float64,
            constraint=lambda z: tf.clip_by_value(z, 1e-4, 10000))

        @tf.function(autograph=False, experimental_compile=False)
        def target_log_prob(amplitude, length_scale):
          return gp_joint_model.log_prob({
              'amplitude': amplitude,
              'length_scale': length_scale,
              'observations': self.obs
          })

        optimizer = tf.optimizers.Adam(learning_rate=.01)

        for i in range(num_iters):
            with tf.GradientTape() as tape:
                loss = -target_log_prob(amplitude_, length_scale_)
            grads = tape.gradient(loss, [amplitude_, length_scale_])
            optimizer.apply_gradients(zip(grads, [amplitude_, length_scale_]))

        print('Trained parameters:')
        print('amplitude: {}'.format(amplitude_.numpy()))
        print('length_scale: {}'.format(length_scale_.numpy()))

        return amplitude_, length_scale_
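The `constraint` passed to each `tf.Variable` above is just a clip; a tiny sketch of what it does to an out-of-range value (hypothetical input):

import tensorflow as tf

constraint = lambda z: tf.clip_by_value(z, 1e-4, 10000)
constraint(tf.constant(-3.0)).numpy()   # -> 1e-4, keeps the kernel parameter positive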
Example #5
 def get_logit_alpha():
     a = tf.clip_by_value(value / 4., 0., 1.)
     logit_alpha = tf.math.log(a / (1. - a))
     return logit_alpha
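A quick numeric sketch of the snippet above (hypothetical `value` captured from the enclosing scope); note that `value >= 4` would give `a == 1` and an infinite logit, which the clip does not guard against:

import tensorflow as tf

value = tf.constant(2.0)
a = tf.clip_by_value(value / 4., 0., 1.)     # -> 0.5
logit_alpha = tf.math.log(a / (1. - a))      # -> 0.0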
Example #6
def _batch_interp_with_gather_nd(x, x_ref_min, x_ref_max, y_ref, nd,
                                 fill_value, batch_dims):
    """N-D interpolation that works with leading batch dims."""
    dtype = x.dtype

    # In this function,
    # x.shape = [A1, ..., An, D, nd], where n = batch_dims
    # and
    # y_ref.shape = [A1, ..., An, C1, C2,..., Cnd, B1,...,BM]
    # y_ref[A1, ..., An, i1,...,ind] is a shape [B1,...,BM] Tensor with the value
    # at index [i1,...,ind] in the interpolation table.
    # x_ref_min and x_ref_max have shapes [A1, ..., An, nd].

    # ny[k] is number of y reference points in interp dim k.
    ny = tf.cast(tf.shape(y_ref)[batch_dims:batch_dims + nd], dtype)

    # Map [x_ref_min, x_ref_max] to [0, ny - 1].
    # This is the (fractional) index of x.
    # x_idx_unclipped[A1, ..., An, d, k] is the fractional index into dim k of
    # interpolation table for the dth x value.
    x_ref_min_expanded = tf.expand_dims(x_ref_min, axis=-2)
    x_ref_max_expanded = tf.expand_dims(x_ref_max, axis=-2)
    x_idx_unclipped = (ny - 1) * (x - x_ref_min_expanded) / (
        x_ref_max_expanded - x_ref_min_expanded)

    # Wherever x is NaN, x_idx_unclipped will be NaN as well.
    # Keep track of the nan indices here (so we can impute NaN later).
    # Also replace any NaN indices with 0, since NaN cannot be cast to an int32 index.
    nan_idx = tf.math.is_nan(x_idx_unclipped)
    x_idx_unclipped = tf.where(nan_idx, 0., x_idx_unclipped)

    # x_idx.shape = [A1, ..., An, D, nd]
    x_idx = tf.clip_by_value(x_idx_unclipped, tf.zeros((), dtype=dtype),
                             ny - 1)

    # Get the index above and below x_idx.
    # Naively we could set idx_below = floor(x_idx), idx_above = ceil(x_idx),
    # however, this results in idx_below == idx_above whenever x is on a grid.
    # This in turn results in y_ref_below == y_ref_above, and then the gradient
    # at this point is zero.  So here we 'jitter' one of idx_below, idx_above,
    # so that they are at different values.  This jittering does not affect the
    # interpolated value, but does make the gradient nonzero (unless of course
    # the y_ref values are the same).
    idx_below = tf.floor(x_idx)
    idx_above = tf.minimum(idx_below + 1, ny - 1)
    idx_below = tf.maximum(idx_above - 1, 0)

    # These are the values of y_ref corresponding to above/below indices.
    # idx_below_int32.shape = x.shape[:-1] + [nd]
    idx_below_int32 = tf.cast(idx_below, dtype=tf.int32)
    idx_above_int32 = tf.cast(idx_above, dtype=tf.int32)

    # idx_below_list is a length nd list of shape x.shape[:-1] int32 tensors.
    idx_below_list = tf.unstack(idx_below_int32, axis=-1)
    idx_above_list = tf.unstack(idx_above_int32, axis=-1)

    # Use t to get a convex combination of the below/above values.
    # t.shape = [A1, ..., An, D, nd]
    t = x_idx - idx_below

    # x, and tensors shaped like x, need to be added to, and selected with
    # (using tf.where) the output y.  This requires appending singletons.
    def _expand_x_fn(tensor):
        # Reshape tensor to tensor.shape + [1] * M.
        extended_shape = tf.concat([
            tf.shape(tensor),
            tf.ones_like(tf.shape(y_ref)[batch_dims + nd:])
        ],
                                   axis=0)
        return tf.reshape(tensor, extended_shape)

    # Now, t.shape = [A1, ..., An, D, nd] + [1] * (rank(y_ref) - nd - batch_dims)
    t = _expand_x_fn(t)
    s = 1 - t

    # Re-insert NaN wherever x was NaN.
    nan_idx = _expand_x_fn(nan_idx)
    t = tf.where(nan_idx, tf.constant(np.nan, dtype), t)

    terms = []
    # Our work above has located x's fractional index inside a cube of above/below
    # indices. The distance to the below indices is t, and to the above indices
    # is s.
    # Drawing lines from x to the cube walls, we get 2**nd smaller cubes. Each
    # term in the result is a product of a reference point, gathered from y_ref,
    # multiplied by a volume.  The volume is that of the cube opposite to the
    # reference point.  E.g. if the reference point is below x in every axis, the
    # volume is that of the cube with corner above x in every axis, s[0]*...*s[nd]
    # We could probably do this with one massive gather, but that would be very
    # unreadable and un-debuggable.  It also would create a large Tensor.
    for zero_ones_list in _binary_count(nd):
        gather_from_y_ref_idx = []
        opposite_volume_t_idx = []
        opposite_volume_s_idx = []
        for k, zero_or_one in enumerate(zero_ones_list):
            if zero_or_one == 0:
                # If the kth iterate has zero_or_one = 0,
                # Will gather from the 'below' reference point along axis k.
                gather_from_y_ref_idx.append(idx_below_list[k])
                # Now append the index to gather for computing opposite_volume.
                # This could be done by initializing opposite_volume to 1, then here:
                #  opposite_volume *= tf.gather(s, indices=k, axis=tf.rank(x) - 1)
                # but that puts a gather in the 'inner loop.'  Better to append the
                # index and do one larger gather down below.
                opposite_volume_s_idx.append(k)
            else:
                gather_from_y_ref_idx.append(idx_above_list[k])
                # Append an index to gather, having the same effect as
                #   opposite_volume *= tf.gather(t, indices=k, axis=tf.rank(x) - 1)
                opposite_volume_t_idx.append(k)

        # Compute opposite_volume (volume of cube opposite the ref point):
        # Recall t.shape = s.shape = [D, nd] + [1, ..., 1]
        # Gather from t and s along the 'nd' axis, which is rank(x) - 1.
        ov_axis = tf.rank(x) - 1
        opposite_volume = (tf.reduce_prod(
            tf.gather(t,
                      indices=tf.cast(opposite_volume_t_idx, dtype=tf.int32),
                      axis=ov_axis),
            axis=ov_axis) * tf.reduce_prod(tf.gather(
                s,
                indices=tf.cast(opposite_volume_s_idx, dtype=tf.int32),
                axis=ov_axis),
                                           axis=ov_axis))  # pyformat: disable

        y_ref_pt = tf.gather_nd(y_ref,
                                tf.stack(gather_from_y_ref_idx, axis=-1),
                                batch_dims=batch_dims)

        terms.append(y_ref_pt * opposite_volume)

    y = tf.math.add_n(terms)

    if tf.debugging.is_numeric_tensor(fill_value):
        # Recall x_idx_unclipped.shape = [D, nd],
        # so here we check if it was out of bounds in any of the nd dims.
        # Thus, oob_idx.shape = [D].
        oob_idx = tf.reduce_any(
            (x_idx_unclipped < 0) | (x_idx_unclipped > ny - 1), axis=-1)

        # Now, y.shape = [D, B1,...,BM], so we'll have to broadcast oob_idx.

        oob_idx = _expand_x_fn(oob_idx)  # Shape [D, 1,...,1]
        oob_idx |= tf.fill(tf.shape(y), False)
        y = tf.where(oob_idx, fill_value, y)
    return y
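A hypothetical 1-D sketch of the core steps above: map x to a fractional grid index, clip it into range, jitter the below/above indices apart, and take a convex combination of the two neighbouring reference values.

import tensorflow as tf

x = tf.constant([[-1.0, 0.25, 2.0]])          # [A1=1, D=3], some points out of range
x_ref_min, x_ref_max = 0.0, 1.0
y_ref = tf.constant([[0.0, 10.0, 20.0]])      # [A1=1, C1=3]
ny = tf.cast(tf.shape(y_ref)[-1], x.dtype)

x_idx = tf.clip_by_value((ny - 1) * (x - x_ref_min) / (x_ref_max - x_ref_min), 0., ny - 1)
idx_below = tf.floor(x_idx)
idx_above = tf.minimum(idx_below + 1, ny - 1)
idx_below = tf.maximum(idx_above - 1, 0)
t = x_idx - idx_below

y_below = tf.gather(y_ref, tf.cast(idx_below, tf.int32), batch_dims=1)
y_above = tf.gather(y_ref, tf.cast(idx_above, tf.int32), batch_dims=1)
y = (1. - t) * y_below + t * y_above          # -> [[0., 5., 20.]]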
Example #7
    def __init__(self,
                 state_dim,
                 action_dim,
                 log_interval,
                 actor_lr=1e-3,
                 critic_lr=1e-3,
                 alpha_init=1.0,
                 learn_alpha=True,
                 algae_alpha=1.0,
                 use_dqn=True,
                 use_init_states=True,
                 exponent=2.0):
        """Creates networks.

    Args:
      state_dim: State size.
      action_dim: Action size.
      log_interval: Log losses every N steps.
      actor_lr: Actor learning rate.
      critic_lr: Critic learning rate.
      alpha_init: Initial temperature value for causal entropy regularization.
      learn_alpha: Whether to learn alpha or not.
      algae_alpha: Algae regularization weight.
      use_dqn: Whether to use double networks for target value.
      use_init_states: Whether to use initial states in objective.
      exponent: Exponent p of function f(x) = |x|^p / p.
    """
        self.actor = Actor(state_dim, action_dim)
        self.actor_optimizer = tf.keras.optimizers.Adam(learning_rate=actor_lr)
        self.avg_actor_loss = tf.keras.metrics.Mean('actor_loss',
                                                    dtype=tf.float32)
        self.avg_alpha_loss = tf.keras.metrics.Mean('alpha_loss',
                                                    dtype=tf.float32)
        self.avg_actor_entropy = tf.keras.metrics.Mean('actor_entropy',
                                                       dtype=tf.float32)
        self.avg_alpha = tf.keras.metrics.Mean('alpha', dtype=tf.float32)
        self.avg_lambda = tf.keras.metrics.Mean('lambda', dtype=tf.float32)
        self.use_init_states = use_init_states

        if use_dqn:
            self.critic = DoubleCritic(state_dim, action_dim)
            self.critic_target = DoubleCritic(state_dim, action_dim)
        else:
            self.critic = Critic(state_dim, action_dim)
            self.critic_target = Critic(state_dim, action_dim)
        soft_update(self.critic, self.critic_target, tau=1.0)
        self._lambda = tf.Variable(0.0, trainable=True)
        self.critic_optimizer = tf.keras.optimizers.Adam(
            learning_rate=critic_lr)
        self.avg_critic_loss = tf.keras.metrics.Mean('critic_loss',
                                                     dtype=tf.float32)

        self.log_alpha = tf.Variable(tf.math.log(alpha_init), trainable=True)
        self.learn_alpha = learn_alpha
        self.alpha_optimizer = tf.keras.optimizers.Adam()

        self.log_interval = log_interval

        self.algae_alpha = algae_alpha
        self.use_dqn = use_dqn
        self.exponent = exponent
        if self.exponent <= 1:
            raise ValueError(
                'Exponent must be greater than 1, but received %f.' %
                self.exponent)
        self.f = lambda resid: tf.pow(tf.abs(resid), self.exponent
                                      ) / self.exponent
        clip_resid = lambda resid: tf.clip_by_value(resid, 0.0, 1e6)
        self.fgrad = lambda resid: tf.pow(clip_resid(resid), self.exponent - 1)
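A quick numeric check of `self.f` and `self.fgrad` above with the default `exponent = 2` (hypothetical residuals); `fgrad` clips negative residuals to zero before applying the power:

import tensorflow as tf

exponent = 2.0
f = lambda resid: tf.pow(tf.abs(resid), exponent) / exponent
clip_resid = lambda resid: tf.clip_by_value(resid, 0.0, 1e6)
fgrad = lambda resid: tf.pow(clip_resid(resid), exponent - 1)

resid = tf.constant([-2.0, 3.0])
f(resid)       # -> [2.0, 4.5]
fgrad(resid)   # -> [0.0, 3.0]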
Example #8
  def call(self, inputs):
    if (not isinstance(inputs, random_variable.RandomVariable) and
        not isinstance(self.kernel, random_variable.RandomVariable) and
        not isinstance(self.bias, random_variable.RandomVariable)):
      return super(DenseDVI, self).call(inputs)
    self.call_weights()
    inputs_mean, inputs_variance, inputs_covariance = get_moments(inputs)
    kernel_mean, kernel_variance, _ = get_moments(self.kernel)
    if self.use_bias:
      bias_mean, _, bias_covariance = get_moments(self.bias)

    # E[outputs] = E[inputs] * E[kernel] + E[bias]
    mean = tf.tensordot(inputs_mean, kernel_mean, [[-1], [0]])
    if self.use_bias:
      mean = tf.nn.bias_add(mean, bias_mean)

    # Cov = E[inputs**2] Cov(kernel) + E[W]^T Cov(inputs) E[W] + Cov(bias)
    # For first term, assume Cov(kernel) = 0 on off-diagonals so we only
    # compute diagonal term.
    covariance_diag = tf.tensordot(inputs_variance + inputs_mean**2,
                                   kernel_variance, [[-1], [0]])
    # Compute quadratic form E[W]^T Cov E[W] from right-to-left. First is
    #  [..., features, features], [features, units] -> [..., features, units].
    cov_w = tf.tensordot(inputs_covariance, kernel_mean, [[-1], [0]])
    # Next is [..., features, units], [features, units] -> [..., units, units].
    w_cov_w = tf.tensordot(cov_w, kernel_mean, [[-2], [0]])
    covariance = w_cov_w
    if self.use_bias:
      covariance += bias_covariance
    covariance = tf.linalg.set_diag(
        covariance, tf.linalg.diag_part(covariance) + covariance_diag)

    if self.activation in (tf.keras.activations.relu, tf.nn.relu):
      # Compute activation's moments with variable names from Wu et al. (2018).
      variance = tf.linalg.diag_part(covariance)
      scale = tf.sqrt(variance)
      mu = mean / (scale + tf.keras.backend.epsilon())
      mean = scale * soft_relu(mu)

      pairwise_variances = (tf.expand_dims(variance, -1) *
                            tf.expand_dims(variance, -2))  # [..., units, units]
      rho = covariance / tf.sqrt(pairwise_variances +
                                 tf.keras.backend.epsilon())
      rho = tf.clip_by_value(rho,
                             -1. / (1. + tf.keras.backend.epsilon()),
                             1. / (1. + tf.keras.backend.epsilon()))
      s = covariance / (rho + tf.keras.backend.epsilon())
      mu1 = tf.expand_dims(mu, -1)  # [..., units, 1]
      mu2 = tf.linalg.matrix_transpose(mu1)  # [..., 1, units]
      a = (soft_relu(mu1) * soft_relu(mu2) +
           rho * tfp.distributions.Normal(0., 1.).cdf(mu1) *
           tfp.distributions.Normal(0., 1.).cdf(mu2))
      gh = tf.asinh(rho)
      bar_rho = tf.sqrt(1. - rho**2)
      gr = gh + rho / (1. + bar_rho)
      # Include numerically stable versions of gr and rho when multiplying or
      # dividing them. The sign of gr*rho and rho/gr is always positive.
      safe_gr = tf.abs(gr) + 0.5 * tf.keras.backend.epsilon()
      safe_rho = tf.abs(rho) + tf.keras.backend.epsilon()
      exp_negative_q = gr / (2. * math.pi) * tf.exp(
          -safe_rho / (2. * safe_gr * (1 + bar_rho)) +
          (gh - rho) / (safe_gr * safe_rho) * mu1 * mu2)
      covariance = s * (a + exp_negative_q)
    elif self.activation not in (tf.keras.activations.linear, None):
      raise NotImplementedError('Activation is {}. Deterministic variational '
                                'inference is only available if activation is '
                                'ReLU or None.'.format(self.activation))

    return generated_random_variables.MultivariateNormalFullCovariance(
        mean, covariance)
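A small sketch of the correlation clip used above (hypothetical value): clipping rho to just inside [-1, 1] keeps `sqrt(1 - rho**2)` real when round-off pushes a correlation slightly past 1.

import tensorflow as tf

eps = tf.keras.backend.epsilon()
rho_raw = tf.constant(1.0000001)                               # slightly above 1 due to round-off
rho = tf.clip_by_value(rho_raw, -1. / (1. + eps), 1. / (1. + eps))
bar_rho = tf.sqrt(1. - rho**2)                                 # finite instead of NaN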
Example #9
 def _mode(self):
     # mode = { loc:         for low <= loc <= high
     #          low: for loc < low
     #          high: for loc > high
     #        }
     return tf.clip_by_value(self.loc, self.low, self.high)
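A small numeric sketch of the mode above (hypothetical loc, low, high):

import tensorflow as tf

loc = tf.constant([-3.0, 0.5, 9.0])
low, high = 0.0, 1.0
mode = tf.clip_by_value(loc, low, high)   # -> [0.0, 0.5, 1.0]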
Example #10
def color_jitter_rand(image,
                      brightness=0,
                      contrast=0,
                      saturation=0,
                      hue=0,
                      impl="simclrv2"):
    """Distorts the color of the image (jittering order is random).
    Args:
      image: The input image tensor.
      brightness: A float, specifying the brightness for color jitter.
      contrast: A float, specifying the contrast for color jitter.
      saturation: A float, specifying the saturation for color jitter.
      hue: A float, specifying the hue for color jitter.
      impl: 'simclrv1' or 'simclrv2'.  Whether to use simclrv1 or simclrv2's
          version of random brightness.
    Returns:
      The distorted image tensor.
    """
    with tf.name_scope("distort_color"):

        def apply_transform(i, x):
            """Apply the i-th transformation."""
            def brightness_foo():
                if brightness == 0:
                    return x
                else:
                    return random_brightness(x,
                                             max_delta=brightness,
                                             impl=impl)

            def contrast_foo():
                if contrast == 0:
                    return x
                else:
                    return tf.image.random_contrast(x,
                                                    lower=1 - contrast,
                                                    upper=1 + contrast)

            def saturation_foo():
                if saturation == 0:
                    return x
                else:
                    return tf.image.random_saturation(x,
                                                      lower=1 - saturation,
                                                      upper=1 + saturation)

            def hue_foo():
                if hue == 0:
                    return x
                else:
                    return tf.image.random_hue(x, max_delta=hue)

            x = tf.cond(
                tf.less(i, 2),
                lambda: tf.cond(tf.less(i, 1), brightness_foo, contrast_foo),
                lambda: tf.cond(tf.less(i, 3), saturation_foo, hue_foo),
            )
            return x

        perm = tf.random.shuffle(tf.range(4))
        for i in range(4):
            image = apply_transform(perm[i], image)
            image = tf.clip_by_value(image, 0.0, 1.0)
        return image
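A hedged usage sketch of the function above (hypothetical image tensor; `brightness=0` keeps the sketch independent of the module's `random_brightness` helper):

import tensorflow as tf

image = tf.random.uniform([224, 224, 3])
jittered = color_jitter_rand(image, brightness=0, contrast=0.4, saturation=0.4, hue=0.1)
# The output stays in [0, 1] because every transform is followed by tf.clip_by_value.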
Example #11
 def _cdf(self, x):
     cdf_in_support = ((special_math.ndtr((x - self.loc) / self.scale) -
                        special_math.ndtr(self._standardized_low)) /
                       self._normalizer)
     return tf.clip_by_value(cdf_in_support, 0., 1.)
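A hedged numeric sketch of the clamping above, using `tfp.distributions.Normal(0., 1.).cdf` in place of `special_math.ndtr` and hypothetical `loc`, `scale`, `low` values:

import tensorflow as tf
import tensorflow_probability as tfp

normal = tfp.distributions.Normal(0., 1.)
loc, scale, low = 0., 1., -1.
standardized_low = (low - loc) / scale
normalizer = 1. - normal.cdf(standardized_low)
x = tf.constant([-2.0, 0.0])                       # first point lies below `low`
cdf_in_support = (normal.cdf((x - loc) / scale) -
                  normal.cdf(standardized_low)) / normalizer
cdf = tf.clip_by_value(cdf_in_support, 0., 1.)     # the negative value at x=-2 is clipped to 0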
Example #12
  def _step(self) -> Dict[str, tf.Tensor]:
    # Get data from replay (dropping extras if any). Note there is no
    # extra data here because we do not insert any into Reverb.
    sample = next(self._iterator)
    o_tm1, a_tm1, r_t, d_t, o_t = sample.data[:5]

    # Cast the additional discount to match the environment discount dtype.
    discount = tf.cast(self._discount, dtype=d_t.dtype)

    q_t = self._target_critic_network(o_t,
                                      self._policy_network(o_t))
    if not self._distributional and self._vmin is not None:
      q_t = tf.clip_by_value(q_t, self._vmin, self._vmax)
      logging.info('Clip target critic network output with [%f, %f]',
                   self._vmin, self._vmax)

    with tf.GradientTape() as tape:
      # Critic learning.
      q_tm1 = self._critic_network(o_tm1, a_tm1)

      # Critic loss.
      if self._distributional:
        critic_loss = losses.categorical(q_tm1, r_t, discount * d_t, q_t)
      else:
        # Squeeze into the shape expected by the td_learning implementation.
        q_tm1 = tf.squeeze(q_tm1, axis=-1)  # [B]
        q_t = tf.squeeze(q_t, axis=-1)  # [B]
        critic_loss = trfl.td_learning(q_tm1, r_t, discount * d_t, q_t).loss

      critic_loss = tf.reduce_mean(critic_loss, axis=[0])

    # Get trainable variables.
    critic_variables = self._critic_network.trainable_variables

    # Compute gradients.
    critic_gradients = tape.gradient(critic_loss, critic_variables)

    # Maybe clip gradients.
    if self._clipping:
      critic_gradients = tf.clip_by_global_norm(critic_gradients, 40.)[0]

    # Apply gradients.
    self._critic_optimizer.apply(critic_gradients, critic_variables)

    source_variables = self._critic_network.variables
    target_variables = self._target_critic_network.variables

    # Make online -> target network update ops.
    if tf.math.mod(self._num_steps, self._target_update_period) == 0:
      for src, dest in zip(source_variables, target_variables):
        dest.assign(src)

    if self._init_observations is not None:
      if tf.math.mod(self._num_steps, 100) == 0:
        # init_obs = tf.convert_to_tensor(self._init_observations, tf.float32)
        init_obs = tree.map_structure(tf.convert_to_tensor,
                                      self._init_observations)
        init_actions = self._policy_network(init_obs)
        init_critic = tf.reduce_mean(self._critic_mean(init_obs, init_actions))
      else:
        init_critic = tf.constant(0.)
    else:
      init_critic = tf.constant(0.)

    self._num_steps.assign_add(1)

    # Losses to track.
    return {
        'critic_loss': critic_loss,
        'q_s0': init_critic,
    }
Example #13
def clip_sequence_value(tensor, lower_limit, upper_limit):
    """Clip values consistently across time with dim[0] as time."""
    return tf.clip_by_value(tensor, lower_limit, upper_limit)
Example #14
def _update_confusion_matrix_variables_optimized(
        variables_to_update,
        y_true,
        y_pred,
        thresholds,
        multi_label=False,
        sample_weights=None,
        label_weights=None,
        thresholds_with_epsilon=False):
    """Update confusion matrix variables with memory efficient alternative.

  Note that the thresholds need to be evenly distributed within the list, i.e.
  the difference between consecutive elements is the same.

  To compute TP/FP/TN/FN, we are measuring a binary classifier
    C(t) = (predictions >= t)
  at each threshold 't'. So we have
    TP(t) = sum( C(t) * true_labels )
    FP(t) = sum( C(t) * false_labels )

  But, computing C(t) requires computation for each t. To make it fast,
  observe that C(t) is a cumulative integral, and so if we have
    thresholds = [t_0, ..., t_{n-1}];  t_0 < ... < t_{n-1}
  where n = num_thresholds, and if we can compute the bucket function
    B(i) = Sum( (predictions == t), t_i <= t < t_{i+1} )
  then we get
    C(t_i) = sum( B(j), j >= i )
  which is the reversed cumulative sum in tf.cumsum().

  We can compute B(i) efficiently by taking advantage of the fact that
  our thresholds are evenly distributed, in that
    width = 1.0 / (num_thresholds - 1)
    thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0]
  Given a prediction value p, we can map it to its bucket by
    bucket_index(p) = floor( p * (num_thresholds - 1) )
  so we can use tf.math.unsorted_segment_sum() to update the buckets in one
  pass.

  Consider following example:
  y_true = [0, 0, 1, 1]
  y_pred = [0.1, 0.5, 0.3, 0.9]
  thresholds = [0.0, 0.5, 1.0]
  num_buckets = 2   # [0.0, 1.0], (1.0, 2.0]
  bucket_index(y_pred) = tf.math.floor(y_pred * num_buckets)
                       = tf.math.floor([0.2, 1.0, 0.6, 1.8])
                       = [0, 0, 0, 1]
  # The meaning of this bucket is that if any of the label is true,
  # then 1 will be added to the corresponding bucket with the index.
  # Eg, if the label for 0.2 is true, then 1 will be added to bucket 0. If the
  # label for 1.8 is true, then 1 will be added to bucket 1.
  #
  # Note the second item "1.0" is mapped to bucket 0, since the value needs to
  # be strictly larger than the bucket lower bound.
  # In the implementation, we use tf.math.ceil() - 1 to achieve this.
  tp_bucket_value = tf.math.unsorted_segment_sum(true_labels, bucket_indices,
                                                 num_segments=num_thresholds)
                  = [1, 1, 0]
  # For [1, 1, 0] here, it means there is 1 true value contributed by bucket 0,
  # and 1 true value contributed by bucket 1. When we aggregate them together,
  # the result becomes [a + b + c, b + c, c], since larger thresholds always
  # contribute to the values for smaller thresholds.
  true_positive = tf.math.cumsum(tp_bucket_value, reverse=True)
                = [2, 1, 0]

  This implementation exhibits a run time and space complexity of O(T + N),
  where T is the number of thresholds and N is the size of predictions.
  Metrics that rely on standard implementation instead exhibit a complexity of
  O(T * N).

  Args:
    variables_to_update: Dictionary with 'tp', 'fn', 'tn', 'fp' as valid keys
      and corresponding variables to update as values.
    y_true: A floating point `Tensor` whose shape matches `y_pred`. Will be cast
      to `bool`.
    y_pred: A floating point `Tensor` of arbitrary shape and whose values are in
      the range `[0, 1]`.
    thresholds: A sorted floating point `Tensor` with values in `[0, 1]`.
      It needs to be evenly distributed (the difference between consecutive
      elements needs to be the same).
    multi_label: Optional boolean indicating whether multidimensional
      prediction/labels should be treated as multilabel responses, or flattened
      into a single label. When True, the values of `variables_to_update` must
      have a second dimension equal to the number of labels in y_true and
      y_pred, and those tensors must not be RaggedTensors.
    sample_weights: Optional `Tensor` whose rank is either 0, or the same rank
      as `y_true`, and must be broadcastable to `y_true` (i.e., all dimensions
      must be either `1`, or the same as the corresponding `y_true` dimension).
    label_weights: Optional tensor of non-negative weights for multilabel
      data. The weights are applied when calculating TP, FP, FN, and TN without
      explicit multilabel handling (i.e. when the data is to be flattened).
    thresholds_with_epsilon: Optional boolean indicating whether the leading and
      trailing thresholds have an epsilon added for floating point imprecision.
      It changes how the leading and trailing buckets are handled.

  Returns:
    Update op.
  """
    num_thresholds = thresholds.shape.as_list()[0]

    if sample_weights is None:
        sample_weights = 1.0
    else:
        sample_weights = tf.__internal__.ops.broadcast_weights(
            tf.cast(sample_weights, dtype=y_pred.dtype), y_pred)
        if not multi_label:
            sample_weights = tf.reshape(sample_weights, [-1])
    if label_weights is None:
        label_weights = 1.0
    else:
        label_weights = tf.expand_dims(label_weights, 0)
        label_weights = tf.__internal__.ops.broadcast_weights(
            label_weights, y_pred)
        if not multi_label:
            label_weights = tf.reshape(label_weights, [-1])
    weights = tf.multiply(sample_weights, label_weights)

    # We shouldn't need this, but just in case there are prediction values that
    # fall outside the range [0.0, 1.0].
    y_pred = tf.clip_by_value(y_pred, clip_value_min=0.0, clip_value_max=1.0)

    y_true = tf.cast(tf.cast(y_true, tf.bool), y_true.dtype)
    if not multi_label:
        y_true = tf.reshape(y_true, [-1])
        y_pred = tf.reshape(y_pred, [-1])

    true_labels = tf.multiply(y_true, weights)
    false_labels = tf.multiply((1.0 - y_true), weights)

    # Compute the bucket indices for each prediction value.
    # Since the prediction value has to be strictly greater than the threshold
    # (e.g. with buckets [0, 0.5], (0.5, 1], the value 0.5 belongs to the first
    # bucket), we have to use math.ceil(val) - 1 for the bucket index.
    bucket_indices = tf.math.ceil(y_pred * (num_thresholds - 1)) - 1

    if thresholds_with_epsilon:
        # In this case, the first bucket should actually be taken into account,
        # since any prediction in [0.0, 1.0] is larger than the first threshold.
        # We change the bucket index from -1 to 0.
        bucket_indices = tf.nn.relu(bucket_indices)

    bucket_indices = tf.cast(bucket_indices, tf.int32)

    if multi_label:
        # We need to run the bucket segment sum for each of the label classes.
        # In the multi_label case the rank of the labels is 2; we first transpose
        # them so that the label dim comes first and we can run the classes in
        # parallel.
        true_labels = tf.transpose(true_labels)
        false_labels = tf.transpose(false_labels)
        bucket_indices = tf.transpose(bucket_indices)

        def gather_bucket(label_and_bucket_index):
            label, bucket_index = label_and_bucket_index[
                0], label_and_bucket_index[1]
            return tf.math.unsorted_segment_sum(data=label,
                                                segment_ids=bucket_index,
                                                num_segments=num_thresholds)

        tp_bucket_v = tf.vectorized_map(gather_bucket,
                                        (true_labels, bucket_indices))
        fp_bucket_v = tf.vectorized_map(gather_bucket,
                                        (false_labels, bucket_indices))
        tp = tf.transpose(tf.cumsum(tp_bucket_v, reverse=True, axis=1))
        fp = tf.transpose(tf.cumsum(fp_bucket_v, reverse=True, axis=1))
    else:
        tp_bucket_v = tf.math.unsorted_segment_sum(data=true_labels,
                                                   segment_ids=bucket_indices,
                                                   num_segments=num_thresholds)
        fp_bucket_v = tf.math.unsorted_segment_sum(data=false_labels,
                                                   segment_ids=bucket_indices,
                                                   num_segments=num_thresholds)
        tp = tf.cumsum(tp_bucket_v, reverse=True)
        fp = tf.cumsum(fp_bucket_v, reverse=True)

    # fn = sum(true_labels) - tp
    # tn = sum(false_labels) - fp
    if (ConfusionMatrix.TRUE_NEGATIVES in variables_to_update
            or ConfusionMatrix.FALSE_NEGATIVES in variables_to_update):
        if multi_label:
            total_true_labels = tf.reduce_sum(true_labels, axis=1)
            total_false_labels = tf.reduce_sum(false_labels, axis=1)
        else:
            total_true_labels = tf.reduce_sum(true_labels)
            total_false_labels = tf.reduce_sum(false_labels)

    update_ops = []
    if ConfusionMatrix.TRUE_POSITIVES in variables_to_update:
        variable = variables_to_update[ConfusionMatrix.TRUE_POSITIVES]
        update_ops.append(variable.assign_add(tp))
    if ConfusionMatrix.FALSE_POSITIVES in variables_to_update:
        variable = variables_to_update[ConfusionMatrix.FALSE_POSITIVES]
        update_ops.append(variable.assign_add(fp))
    if ConfusionMatrix.TRUE_NEGATIVES in variables_to_update:
        variable = variables_to_update[ConfusionMatrix.TRUE_NEGATIVES]
        tn = total_false_labels - fp
        update_ops.append(variable.assign_add(tn))
    if ConfusionMatrix.FALSE_NEGATIVES in variables_to_update:
        variable = variables_to_update[ConfusionMatrix.FALSE_NEGATIVES]
        fn = total_true_labels - tp
        update_ops.append(variable.assign_add(fn))
    return tf.group(update_ops)
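A numeric sketch of the bucketing described in the docstring, using the same toy values (`tp` comes out as [2, 1, 0], matching the walkthrough above):

import tensorflow as tf

y_true = tf.constant([0., 0., 1., 1.])
y_pred = tf.constant([0.1, 0.5, 0.3, 0.9])
num_thresholds = 3                                    # thresholds = [0.0, 0.5, 1.0]
y_pred = tf.clip_by_value(y_pred, 0.0, 1.0)
bucket_indices = tf.cast(tf.math.ceil(y_pred * (num_thresholds - 1)) - 1, tf.int32)
tp_bucket = tf.math.unsorted_segment_sum(y_true, bucket_indices,
                                         num_segments=num_thresholds)   # [1., 1., 0.]
tp = tf.cumsum(tp_bucket, reverse=True)               # [2., 1., 0.]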
Example #15
  def _step(self) -> Dict[str, tf.Tensor]:
    # Get data from replay (dropping extras if any). Note there is no
    # extra data here because we do not insert any into Reverb.
    sample = next(self._iterator)
    o_tm1, a_tm1, r_t, d_t, o_t = sample.data[:5]
    a_t = self._policy_network(o_t)
    if self._clipping_action:
      if not a_t.dtype.is_floating:
        raise ValueError(f'Action dtype ({a_t.dtype}) is not floating.')
      a_t = tf.clip_by_value(a_t, -1., 1.)

    # Cast the additional discount to match the environment discount dtype.
    discount = tf.cast(self._discount, dtype=d_t.dtype)
    d_t = discount * d_t

    if self._use_tilde_critic:
      tilde_td_error = _td_error(
          self._tilde_critic_network, o_tm1, a_tm1, r_t, d_t, o_t, a_t)[0]
      # In the same shape as tilde_td_error.
      f_regularizer = 0.25 * tf.square(tilde_td_error)
    else:
      # Scalar.
      tilde_td_error = 0.
      f_regularizer = self._f_regularizer

    with tf.GradientTape() as tape:
      td_error, q_tm1, q_t = _td_error(
          self._critic_network, o_tm1, a_tm1, r_t, d_t, o_t, a_t)
      f = self._f_network(o_tm1, a_tm1)
      if f.shape != td_error.shape:
        raise ValueError(f'Shape of f {f.shape.as_list()} does not '
                         f'match that of td_error {td_error.shape.as_list()}')

      moment = tf.reduce_mean(f * td_error)
      f_reg_loss = tf.reduce_mean(f_regularizer * tf.square(f))
      u = moment - f_reg_loss

      # Add regularizations.

      # Regularization on critic net output values.
      if self._critic_regularizer > 0.:
        critic_reg_loss = self._critic_regularizer * (
            tf.reduce_mean(tf.square(q_tm1)) +
            tf.reduce_mean(tf.square(q_t))) / 2.
      else:
        critic_reg_loss = 0.

      # Ortho regularization on critic net.
      if self._critic_ortho_regularizer > 0.:
        critic_ortho_reg_loss = (
            self._critic_ortho_regularizer *
            _orthogonal_regularization(self._critic_network))
      else:
        critic_ortho_reg_loss = 0.

      # Ortho regularization on f net.
      if self._f_ortho_regularizer > 0.:
        f_ortho_reg_loss = (
            self._f_ortho_regularizer *
            _orthogonal_regularization(self._f_network))
      else:
        f_ortho_reg_loss = 0.

      # L2 regularization on critic net.
      if self._critic_l2_regularizer > 0.:
        critic_l2_reg_loss = (
            self._critic_l2_regularizer *
            _l2_regularization(self._critic_network))
      else:
        critic_l2_reg_loss = 0.

      # L2 regularization on f net.
      if self._f_l2_regularizer > 0.:
        f_l2_reg_loss = (
            self._f_l2_regularizer *
            _l2_regularization(self._f_network))
      else:
        f_l2_reg_loss = 0.

      loss = (u + critic_reg_loss
              + critic_ortho_reg_loss - f_ortho_reg_loss
              + critic_l2_reg_loss - f_l2_reg_loss)

    bre_mse = self._check_bellman_residual_error(q_tm1, r_t, d_t, o_t)

    # Get trainable variables.
    critic_variables = self._critic_network.trainable_variables
    f_variables = self._f_network.trainable_variables

    # Compute gradients.
    gradients = tape.gradient(loss, critic_variables + f_variables)
    critic_gradients = gradients[:len(critic_variables)]
    f_gradients = gradients[len(critic_variables):]

    # Maybe clip gradients.
    if self._clipping:
      # # clip_by_global_norm
      # critic_gradients = tf.clip_by_global_norm(critic_gradients, 40.)[0]
      # f_gradients = tf.clip_by_global_norm(f_gradients, 40.)[0]

      # clip_by_value
      critic_gradients = [tf.clip_by_value(g, -1.0, 1.0)
                          for g in critic_gradients]
      f_gradients = [tf.clip_by_value(g, -1.0, 1.0) for g in f_gradients]

    # Apply critic gradients to minimize the loss.
    self._critic_optimizer.apply(critic_gradients, critic_variables)

    # Apply f gradients to maximize the loss.
    f_gradients = [-g for g in f_gradients]
    self._f_optimizer.apply(f_gradients, f_variables)

    if self._use_tilde_critic:
      if tf.math.mod(self._num_steps, self._tilde_critic_update_period) == 0:
        source_variables = self._critic_network.variables
        tilde_variables = self._tilde_critic_network.variables

        # Make online -> tilde network update ops.
        for src, dest in zip(source_variables, tilde_variables):
          dest.assign(src)
    self._num_steps.assign_add(1)

    # Losses to track.
    results = {
        'loss': loss,
        'u': u,
        'f_reg_loss': f_reg_loss,
        'td_mse': tf.reduce_mean(tf.square(td_error)),
        'f_ms': tf.reduce_mean(tf.square(f)),
        'moment': moment,
        'global_steps': tf.convert_to_tensor(self._num_steps),
        'bre_mse': bre_mse,
    }
    if self._use_tilde_critic:
      results.update({
          'tilde_td_mse': tf.reduce_mean(tf.square(tilde_td_error))})
    if self._critic_regularizer > 0.:
      results.update({'critic_reg_loss': critic_reg_loss})
    if self._critic_ortho_regularizer > 0.:
      results.update({'critic_ortho_reg_loss': critic_ortho_reg_loss})
    if self._f_ortho_regularizer > 0.:
      results.update({'f_ortho_reg_loss': f_ortho_reg_loss})
    if self._critic_l2_regularizer > 0.:
      results.update({'critic_l2_reg_loss': critic_l2_reg_loss})
    if self._f_l2_regularizer > 0.:
      results.update({'f_l2_reg_loss': f_l2_reg_loss})
    return results
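A minimal sketch contrasting the per-element clipping used above with the commented-out `clip_by_global_norm` alternative (hypothetical gradients):

import tensorflow as tf

grads = [tf.constant([3.0, -0.5]), tf.constant([0.2])]
value_clipped = [tf.clip_by_value(g, -1.0, 1.0) for g in grads]            # clips each element
norm_clipped, global_norm = tf.clip_by_global_norm(grads, clip_norm=1.0)   # rescales all tensors jointly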
Example #16
def get_counts(model,
               frames,
               strides,
               batch_size,
               threshold,
               within_period_threshold,
               constant_speed=False,
               median_filter=False,
               fully_periodic=False):
    """Pass frames through model and conver period predictions to count."""
    seq_len = len(frames)
    raw_scores_list = []
    scores = []
    within_period_scores_list = []

    if fully_periodic:
        within_period_threshold = 0.0

    frames = model.preprocess(frames)

    for stride in strides:
        num_batches = int(
            np.ceil(seq_len / model.num_frames / stride / batch_size))
        raw_scores_per_stride = []
        within_period_score_stride = []
        for batch_idx in range(num_batches):
            idxes = tf.range(batch_idx * model.num_frames * stride,
                             (batch_idx + batch_size) * model.num_frames *
                             stride, stride)
            idxes = tf.clip_by_value(idxes, 0, seq_len - 1)
            curr_frames = tf.gather(frames, idxes)
            curr_frames = tf.reshape(curr_frames, [
                batch_size, model.num_frames, model.image_size,
                model.image_size, 3
            ])

            raw_scores, within_period_scores, _ = model(curr_frames)
            raw_scores_per_stride.append(
                np.reshape(raw_scores.numpy(), [-1, model.num_frames // 2]))
            within_period_score_stride.append(
                np.reshape(within_period_scores.numpy(), [-1, 1]))
        raw_scores_per_stride = np.concatenate(raw_scores_per_stride, axis=0)
        raw_scores_list.append(raw_scores_per_stride)
        within_period_score_stride = np.concatenate(within_period_score_stride,
                                                    axis=0)
        pred_score, within_period_score_stride = get_score(
            raw_scores_per_stride, within_period_score_stride)
        scores.append(pred_score)
        within_period_scores_list.append(within_period_score_stride)

    # Stride chooser
    argmax_strides = np.argmax(scores)
    chosen_stride = strides[argmax_strides]
    raw_scores = np.repeat(raw_scores_list[argmax_strides],
                           chosen_stride,
                           axis=0)[:seq_len]
    within_period = np.repeat(within_period_scores_list[argmax_strides],
                              chosen_stride,
                              axis=0)[:seq_len]
    within_period_binary = np.asarray(within_period > within_period_threshold)
    if median_filter:
        within_period_binary = medfilt(within_period_binary, 5)

    # Select Periodic frames
    periodic_idxes = np.where(within_period_binary)[0]

    if constant_speed:
        # Count by averaging predictions. Smoother but
        # assumes constant speed.
        scores = tf.reduce_mean(tf.nn.softmax(raw_scores[periodic_idxes],
                                              axis=-1),
                                axis=0)
        max_period = np.argmax(scores)
        pred_score = scores[max_period]
        pred_period = chosen_stride * (max_period + 1)
        per_frame_counts = (np.asarray(seq_len * [1. / pred_period]) *
                            np.asarray(within_period_binary))
    else:
        # Count each frame. More noisy but adapts to changes in speed.
        pred_score = tf.reduce_mean(within_period)
        per_frame_periods = tf.argmax(raw_scores, axis=-1) + 1
        per_frame_counts = tf.where(
            tf.math.less(per_frame_periods, 3),
            0.0,
            tf.math.divide(
                1.0, tf.cast(chosen_stride * per_frame_periods, tf.float32)),
        )
        if median_filter:
            per_frame_counts = medfilt(per_frame_counts, 5)

        per_frame_counts *= np.asarray(within_period_binary)

        pred_period = seq_len / np.sum(per_frame_counts)

    if pred_score < threshold:
        print('No repetitions detected in video as score '
              '%0.2f is less than threshold %0.2f.' % (pred_score, threshold))
        per_frame_counts = np.asarray(len(per_frame_counts) * [0.])

    return (pred_period, pred_score, within_period, per_frame_counts,
            chosen_stride)
Example #17
    def _build_target_quantile_values_op(self):
        """Build an op used as a target for return values at given quantiles.

    Returns:
      An op calculating the target quantile return.
    """
        batch_size = tf.shape(self._replay.rewards)[0]

        # Calculate SIL modified rewards.
        replay_action_one_hot = tf.one_hot(self._replay.actions,
                                           self.num_actions,
                                           1.,
                                           0.,
                                           name='action_one_hot')
        replay_target_q = tf.reduce_max(self._replay_target_q_values,
                                        axis=1,
                                        name='replay_chosen_target_q')
        replay_target_q_al = tf.reduce_sum(replay_action_one_hot *
                                           self._replay_target_q_values,
                                           axis=1,
                                           name='replay_chosen_target_q_al')
        comp_value = tf.math.maximum(replay_target_q_al, self._replay.returns)

        if self._clip > 0.:
            sil_bonus = self._alpha * tf.clip_by_value(
                (comp_value - replay_target_q), -self._clip, self._clip)
        else:
            sil_bonus = self._alpha * (comp_value - replay_target_q)

        # Shape of rewards: (num_tau_prime_samples x batch_size) x 1.
        rewards = (self._replay.rewards + sil_bonus)[:, None]
        rewards = tf.tile(rewards, [self.num_tau_prime_samples, 1])

        is_terminal_multiplier = 1. - tf.cast(self._replay.terminals,
                                              tf.float32)
        # Incorporate terminal state to discount factor.
        # size of gamma_with_terminal: (num_tau_prime_samples x batch_size) x 1.
        gamma_with_terminal = self.cumulative_gamma * is_terminal_multiplier
        gamma_with_terminal = tf.tile(gamma_with_terminal[:, None],
                                      [self.num_tau_prime_samples, 1])

        # Get the indices of the maximum Q-value across the action dimension.
        # Shape of replay_next_qt_argmax: (num_tau_prime_samples x batch_size) x 1.

        replay_next_qt_argmax = tf.tile(self._replay_next_qt_argmax[:, None],
                                        [self.num_tau_prime_samples, 1])

        # Shape of batch_indices: (num_tau_prime_samples x batch_size) x 1.
        batch_indices = tf.cast(
            tf.range(self.num_tau_prime_samples * batch_size)[:, None],
            tf.int64)

        # Shape of batch_indexed_target_values:
        # (num_tau_prime_samples x batch_size) x 2.
        batch_indexed_target_values = tf.concat(
            [batch_indices, replay_next_qt_argmax], axis=1)

        # Shape of next_target_values: (num_tau_prime_samples x batch_size) x 1.
        target_quantile_values = tf.gather_nd(
            self._replay_net_target_quantile_values,
            batch_indexed_target_values)[:, None]

        return rewards + gamma_with_terminal * target_quantile_values
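A hypothetical numeric sketch of the clipped self-imitation bonus computed above (made-up values for `alpha`, `clip`, the target Q-values and `comp_value`):

import tensorflow as tf

alpha, clip = 0.1, 1.0
replay_target_q = tf.constant([2.0, 5.0])
comp_value = tf.constant([6.0, 5.5])          # max(chosen target Q, observed returns)
sil_bonus = alpha * tf.clip_by_value(comp_value - replay_target_q, -clip, clip)
# -> [0.1, 0.05]: the first advantage (4.0) is clipped to 1.0 before scaling.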
Example #18
    def _loop_build_sub_tree(self, directions, integrator,
                             current_step_meta_info, iter_,
                             energy_diff_sum_previous,
                             momentum_cumsum_previous, leapfrogs_taken,
                             prev_tree_state, candidate_tree_state,
                             continue_tree_previous, not_divergent_previous,
                             momentum_state_memory):
        """Base case in tree doubling."""
        with tf.name_scope('loop_build_sub_tree'):
            # Take one leapfrog step in the direction v and check divergence
            [
                next_momentum_parts, next_state_parts, next_target,
                next_target_grad_parts
            ] = integrator(prev_tree_state.momentum, prev_tree_state.state,
                           prev_tree_state.target,
                           prev_tree_state.target_grad_parts)

            next_tree_state = TreeDoublingState(
                momentum=next_momentum_parts,
                state=next_state_parts,
                target=next_target,
                target_grad_parts=next_target_grad_parts)
            momentum_cumsum = [
                p0 + p1 for p0, p1 in zip(momentum_cumsum_previous,
                                          next_momentum_parts)
            ]
            # If the tree has not terminated previously, we count this leapfrog.
            leapfrogs_taken = tf.where(continue_tree_previous,
                                       leapfrogs_taken + 1, leapfrogs_taken)

            write_instruction = current_step_meta_info.write_instruction
            read_instruction = current_step_meta_info.read_instruction
            init_energy = current_step_meta_info.init_energy

            # Save state and momentum at odd step, check U turn at even step.
            # Note that here we also write to a Placeholder at even step
            write_index = tf.where(tf.equal(iter_ % 2, 0),
                                   write_instruction.gather([iter_ // 2]),
                                   self.max_tree_depth)

            if GENERALIZED_UTURN:
                state_to_write = momentum_cumsum
            else:
                state_to_write = next_state_parts

            momentum_state_memory = MomentumStateSwap(
                momentum_swap=[
                    tf.tensor_scatter_nd_update(old, [write_index], [new])
                    for old, new in zip(momentum_state_memory.momentum_swap,
                                        next_momentum_parts)
                ],
                state_swap=[
                    tf.tensor_scatter_nd_update(old, [write_index], [new])
                    for old, new in zip(momentum_state_memory.state_swap,
                                        state_to_write)
                ])
            batch_shape = prefer_static.shape(next_target)
            has_not_u_turn_at_even_step = tf.ones(batch_shape, dtype=tf.bool)

            read_index = read_instruction.gather([iter_ // 2])[0]
            no_u_turns_within_tree = tf.cond(
                tf.equal(iter_ % 2, 0),
                lambda: has_not_u_turn_at_even_step,
                lambda: has_not_u_turn_at_odd_step(  # pylint: disable=g-long-lambda
                    read_index,
                    directions,
                    momentum_state_memory,
                    next_momentum_parts,
                    state_to_write,
                    has_not_u_turn_at_even_step,
                    log_prob_rank=prefer_static.rank(next_target)))

            energy = compute_hamiltonian(next_target, next_momentum_parts)
            current_energy = tf.where(tf.math.is_nan(energy),
                                      tf.constant(-np.inf, dtype=energy.dtype),
                                      energy)
            energy_diff = current_energy - init_energy

            if MULTINOMIAL_SAMPLE:
                not_divergent = -energy_diff < self.max_energy_diff
                weight_sum = log_add_exp(candidate_tree_state.weight,
                                         energy_diff)
                log_accept_thresh = energy_diff - weight_sum
            else:
                log_slice_sample = current_step_meta_info.log_slice_sample
                not_divergent = log_slice_sample - energy_diff < self.max_energy_diff
                # Uniform sampling on the trajectory within the subtree across valid
                # samples.
                is_valid = log_slice_sample <= energy_diff
                weight_sum = tf.where(is_valid,
                                      candidate_tree_state.weight + 1,
                                      candidate_tree_state.weight)
                log_accept_thresh = tf.where(
                    is_valid,
                    -tf.math.log(tf.cast(weight_sum, dtype=tf.float32)),
                    tf.constant(-np.inf, dtype=tf.float32))
            u = tf.math.log1p(-tf.random.uniform(shape=batch_shape,
                                                 dtype=log_accept_thresh.dtype,
                                                 seed=self._seed_stream()))
            is_sample_accepted = u <= log_accept_thresh

            next_candidate_tree_state = TreeDoublingStateCandidate(
                state=[
                    tf.where(  # pylint: disable=g-complex-comprehension
                        _rightmost_expand_to_rank(is_sample_accepted,
                                                  prefer_static.rank(s0)), s0,
                        s1) for s0, s1 in zip(next_state_parts,
                                              candidate_tree_state.state)
                ],
                target=tf.where(
                    _rightmost_expand_to_rank(is_sample_accepted,
                                              prefer_static.rank(next_target)),
                    next_target, candidate_tree_state.target),
                target_grad_parts=[
                    tf.where(  # pylint: disable=g-complex-comprehension
                        _rightmost_expand_to_rank(is_sample_accepted,
                                                  prefer_static.rank(grad0)),
                        grad0, grad1) for grad0, grad1 in zip(
                            next_target_grad_parts,
                            candidate_tree_state.target_grad_parts)
                ],
                energy=tf.where(
                    _rightmost_expand_to_rank(is_sample_accepted,
                                              prefer_static.rank(next_target)),
                    current_energy, init_energy),
                weight=weight_sum)

            continue_tree = not_divergent & continue_tree_previous
            continue_tree_next = no_u_turns_within_tree & continue_tree

            not_divergent_tokeep = tf.where(
                continue_tree_previous, not_divergent,
                tf.ones(batch_shape, dtype=tf.bool))

            # min(1., exp(energy_diff)).
            exp_energy_diff = tf.clip_by_value(tf.exp(energy_diff), 0., 1.)
            energy_diff_sum = tf.where(
                continue_tree, energy_diff_sum_previous + exp_energy_diff,
                energy_diff_sum_previous)

            return (
                iter_ + 1,
                energy_diff_sum,
                momentum_cumsum,
                leapfrogs_taken,
                next_tree_state,
                next_candidate_tree_state,
                continue_tree_next,
                not_divergent_previous & not_divergent_tokeep,
                momentum_state_memory,
            )
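
# A simplified NumPy sketch (not the method above) of the MULTINOMIAL_SAMPLE
# acceptance rule: the new state replaces the candidate with probability
# exp(energy_diff) / (existing weight + exp(energy_diff)), evaluated in log
# space for numerical stability.  The concrete weights are hypothetical.
import numpy as np


def log_add_exp(a, b):
    m = np.maximum(a, b)
    return m + np.log(np.exp(a - m) + np.exp(b - m))


rng = np.random.default_rng(0)
candidate_log_weight = np.log(3.)    # accumulated subtree weight (log space)
energy_diff = 0.                     # log-weight of the newly proposed state
weight_sum = log_add_exp(candidate_log_weight, energy_diff)
log_accept_thresh = energy_diff - weight_sum      # log(1 / 4) here
u = np.log1p(-rng.uniform())                      # log of a Uniform(0, 1) draw
is_sample_accepted = u <= log_accept_thresh
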
Ejemplo n.º 19
0
def _clip_dirichlet_parameters(x):
    """Clips Dirichlet param for numerically stable KL and nonzero samples."""
    return tf.clip_by_value(x, .1, 1e3)
Ejemplo n.º 20
0
def _bates_cdf(total_count, low, high, dtype, value):
    """Compute the Bates cdf.

  Internally, the (standard, unnormalized) cdf is computed by the formula

  ```none
  cdf = sum_{k=0}^j (-1)^k (n choose k) (nx - k)^n
  ```

  where
  * `n = total_count`,
  * `x = value` the value to compute the cumulative probability of, and
  * `j = floor(nx)`.

  This is shifted to `[low, high]` and normalized. Since the pdf is symmetric,
  we have `cdf(x) = 1 - cdf(1 - x)` for `x > .5`, hence we only compute the left
  half, which keeps the number of terms lower.

  Computation is batched, using `tf.math.segment_sum()`. For this reason this is
  not compatible with `tf.vectorized_map()`.

  All input parameters should have compatible dtypes and shapes.

  Args:
    total_count: `Tensor` with integer values, as given to the `Bates`
      constructor.
    low: Float `Tensor`, as given to the `Bates` constructor.
    high: Float `Tensor`, as given to the `Bates` constructor.
    dtype: The dtype of the output.
    value: Float `Tensor`. Input value to `cdf()`.
  Returns:
    cdf: Float `Tensor`. See above formula.
  """
    total_count = tf.cast(total_count, dtype)
    low = tf.convert_to_tensor(low)
    high = tf.convert_to_tensor(high)

    # Warn the user if they try to compute a pdf with high `total_count`.  This
    # warning is here instead of `_parameter_control_dependencies()` because
    # nested calls to `_name_and_control_scope` (e.g. `log_survival_function`) can
    # result in multiple warnings being added and multiple tensor
    # conversions. Also `sample()` does not have the same numerical issues.
    with tf.control_dependencies([_stability_limit_tensor(total_count,
                                                          dtype)]):
        # Center and adjust `value` using limits and symmetry.
        value_centered = (value - low) / (high - low)
        value_adj = tf.clip_by_value(value_centered, 0., 1.)
        value_adj = tf.where(value_adj < .5, value_adj, 1. - value_adj)
        value_adj = tf.where(tf.math.is_finite(value_adj), value_adj, 0.)
        # Flatten to make segments; need to broadcast before flattening.
        shape = ps.broadcast_shape(ps.shape(value_adj), ps.shape(total_count))
        total_count_b = ps.broadcast_to(total_count, shape)
        total_count_x_value_adj_b = total_count * value_adj
        total_count_f = tf.reshape(total_count_b, [-1])
        total_count_x_value_adj_f = tf.reshape(total_count_x_value_adj_b, [-1])
        # Create segmented terms of summation.
        num_terms_f = tf.cast(tf.math.floor(total_count_x_value_adj_f + 1),
                              dtype=tf.int32)
        term_idx_s = tf.cast(_segmented_range(num_terms_f), dtype)  # aka `k`
        total_count_s = tf.repeat(total_count_f, num_terms_f)
        total_count_x_value_adj_s = tf.repeat(total_count_x_value_adj_f,
                                              num_terms_f)
        terms = (tf.cast(-1., dtype)**term_idx_s *
                 (1. / ((total_count_s + 1.) * tf.math.exp(
                     tfp_math.lbeta(total_count_s - term_idx_s + 1.,
                                    term_idx_s + 1.)))) *
                 (total_count_x_value_adj_s - term_idx_s)**total_count_s)
        # Segment sum.
        segment_ids = tf.repeat(tf.range(tf.size(num_terms_f)), num_terms_f)
        cdf_s = tf.math.segment_sum(terms, segment_ids)
        # Reshape back.
        cdf = tf.reshape(cdf_s, shape)
        # Normalize.
        cdf = cdf / tf.math.exp(
            tf.math.lgamma(total_count_b + tf.cast(1., dtype)))
        # cdf symmetry adjustment: cdf(x) = 1 - cdf(1 - x) for x > 0.5
        cdf = tf.where(value_centered > .5, 1. - cdf, cdf)
        # Fix out-of-support queries.
        cdf = tf.where(value_centered < 0., tf.cast(0., dtype), cdf)
        cdf = tf.where(value_centered > 1., tf.cast(1., dtype), cdf)
        cdf = tf.where(tf.math.is_finite(value_centered), cdf, np.nan)
        return cdf
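
# A small NumPy cross-check (an independent sketch, not calling `_bates_cdf`)
# of the docstring formula for the standard Bates cdf on [0, 1]:
#   cdf(x) = (1 / n!) * sum_{k=0}^{floor(n x)} (-1)^k (n choose k) (n x - k)^n
import numpy as np
from math import comb, factorial


def bates_cdf_reference(x, n):
    nx = n * x
    j = int(np.floor(nx))
    total = sum((-1.)**k * comb(n, k) * (nx - k)**n for k in range(j + 1))
    return total / factorial(n)


# Mean of two Uniform(0, 1) draws: P(mean <= 0.25) = 0.5**2 / 2 = 0.125.
assert np.isclose(bates_cdf_reference(0.25, 2), 0.125)
assert np.isclose(bates_cdf_reference(0.5, 2), 0.5)
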
Ejemplo n.º 21
0
def interpolate1d(x, values, tangents):
    r"""Perform cubic hermite spline interpolation on a 1D spline.

  The x coordinates of the spline knots are at [0 : 1 : len(values)-1].
  Queries outside of the range of the spline are computed using linear
  extrapolation. See https://en.wikipedia.org/wiki/Cubic_Hermite_spline
  for details, where "x" corresponds to `x`, "p" corresponds to `values`, and
  "m" corresponds to `tangents`.

  Args:
    x: A tensor of any size of single or double precision floats containing the
      set of values to be used for interpolation into the spline.
    values: A vector of single or double precision floats containing the value
      of each knot of the spline being interpolated into. Must be the same
      length as `tangents` and the same type as `x`.
    tangents: A vector of single or double precision floats containing the
      tangent (derivative) of each knot of the spline being interpolated into.
      Must be the same length as `values` and the same type as `x`.

  Returns:
    The result of interpolating along the spline defined by `values`, and
    `tangents`, using `x` as the query values. Will be the same length and type
    as `x`.
  """
    # `values` and `tangents` must have the same type as `x`.
    tf.debugging.assert_type(values, x.dtype)
    tf.debugging.assert_type(tangents, x.dtype)
    float_dtype = x.dtype
    assert_ops = [
        # `values` must be a vector.
        tf.Assert(tf.equal(tf.rank(values), 1), [tf.shape(values)]),
        # `tangents` must be a vector.
        tf.Assert(tf.equal(tf.rank(tangents), 1), [tf.shape(tangents)]),
        # `values` and `tangents` must have the same length.
        tf.Assert(
            tf.equal(tf.shape(values)[0],
                     tf.shape(tangents)[0]),
            [tf.shape(values)[0], tf.shape(tangents)[0]]),
    ]
    with tf.control_dependencies(assert_ops):
        # Find the indices of the knots below and above each x.
        x_lo = tf.cast(
            tf.floor(
                tf.clip_by_value(x, 0.,
                                 tf.cast(tf.shape(values)[0] - 2,
                                         float_dtype))), tf.int32)
        x_hi = x_lo + 1

        # Compute the relative distance between each `x` and the knot below it.
        t = x - tf.cast(x_lo, float_dtype)

        # Compute the cubic hermite expansion of `t`.
        t_sq = tf.square(t)
        t_cu = t * t_sq
        h01 = -2. * t_cu + 3. * t_sq
        h00 = 1. - h01
        h11 = t_cu - t_sq
        h10 = h11 - t_sq + t

        # Linearly extrapolate above and below the extents of the spline for all
        # values.
        value_before = tangents[0] * t + values[0]
        value_after = tangents[-1] * (t - 1.) + values[-1]

        # Cubically interpolate between the knots below and above each query point.
        neighbor_values_lo = tf.gather(values, x_lo)
        neighbor_values_hi = tf.gather(values, x_hi)
        neighbor_tangents_lo = tf.gather(tangents, x_lo)
        neighbor_tangents_hi = tf.gather(tangents, x_hi)
        value_mid = (neighbor_values_lo * h00 + neighbor_values_hi * h01 +
                     neighbor_tangents_lo * h10 + neighbor_tangents_hi * h11)

        # Return the interpolated or extrapolated values for each query point,
        # depending on whether or not the query lies within the span of the spline.
        return tf.where(t < 0., value_before,
                        tf.where(t > 1., value_after, value_mid))
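
# A small usage example of `interpolate1d` above: three knots at x = 0, 1, 2
# with zero tangents give a smooth bump, and queries outside [0, 2] are
# linearly extrapolated (flat here, because the boundary tangents are zero).
import tensorflow as tf

values = tf.constant([0., 1., 0.], tf.float32)
tangents = tf.constant([0., 0., 0.], tf.float32)
queries = tf.constant([-1., 0., 0.5, 1., 2., 3.], tf.float32)
interpolated = interpolate1d(queries, values, tangents)
# interpolated ~= [0., 0., 0.5, 1., 0., 0.]
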
Ejemplo n.º 22
0
def _update_trajectory_grad(previous_kernel_results,
                            previous_state,
                            proposed_state,
                            proposed_velocity,
                            trajectory_jitter,
                            accept_prob,
                            step_size,
                            criterion_fn,
                            max_leapfrog_steps,
                            experimental_shard_axis_names=None,
                            experimental_chain_axis_names=None):
    """Updates the trajectory length."""

    # Compute criterion grads.
    def leapfrog_action(dt):
        # This represents the effect on the criterion value as the state follows the
        # proposed velocity. This implicitly assumes an identity mass matrix.
        def adjust_state(x, v, shard_axes=None):
            broadcasted_dt = distribute_lib.pbroadcast(
                bu.left_justified_expand_dims_like(dt, v), shard_axes)
            return x + broadcasted_dt * v

        adjusted_state = _map_structure_up_to_with_axes(
            proposed_state,
            adjust_state,
            proposed_state,
            proposed_velocity,
            experimental_shard_axis_names=experimental_shard_axis_names)
        return criterion_fn(previous_state, adjusted_state, accept_prob)

    criterion, trajectory_grad = gradient.value_and_gradient(
        leapfrog_action, tf.zeros_like(accept_prob))
    trajectory_grad *= trajectory_jitter

    # Weight by acceptance probability.
    experimental_chain_axis_names = distribute_lib.canonicalize_named_axis(
        experimental_chain_axis_names)
    trajectory_grad = tf.where(accept_prob > 1e-4, trajectory_grad, 0.)
    trajectory_grad = tf.where(tf.math.is_finite(trajectory_grad),
                               trajectory_grad, 0.)
    trajectory_grad = (_reduce_sum_with_axes(
        trajectory_grad * accept_prob, None, experimental_chain_axis_names) /
                       _reduce_sum_with_axes(accept_prob + 1e-20, None,
                                             experimental_chain_axis_names))

    # Compute Adam/RMSProp step size.
    dtype = previous_kernel_results.adaptation_rate.dtype
    iteration_f = tf.cast(previous_kernel_results.step, dtype) + 1.
    msg_adaptation_rate = 0.05
    new_averaged_sq_grad = (
        (1 - msg_adaptation_rate) * previous_kernel_results.averaged_sq_grad +
        msg_adaptation_rate * trajectory_grad**2)
    adjusted_averaged_sq_grad = new_averaged_sq_grad / (
        1. - (1 - msg_adaptation_rate)**iteration_f)
    trajectory_step_size = (previous_kernel_results.adaptation_rate /
                            tf.sqrt(adjusted_averaged_sq_grad + 1e-20))

    # Apply the gradient. Clip absolute value to ~log(2)/2.
    log_update = tf.clip_by_value(trajectory_step_size * trajectory_grad,
                                  -0.35, 0.35)
    new_max_trajectory_length = previous_kernel_results.max_trajectory_length * tf.exp(
        log_update)

    # Iterate averaging.
    average_weight = iteration_f**(-0.5)
    new_averaged_max_trajectory_length = tf.exp(
        average_weight * tf.math.log(new_max_trajectory_length) +
        (1 - average_weight) *
        tf.math.log(1e-10 +
                    previous_kernel_results.averaged_max_trajectory_length))

    # Clip the maximum trajectory length.
    new_max_trajectory_length = _clip_max_trajectory_length(
        new_max_trajectory_length, step_size,
        previous_kernel_results.adaptation_rate, max_leapfrog_steps)

    return previous_kernel_results._replace(
        criterion=criterion,
        max_trajectory_length=new_max_trajectory_length,
        averaged_sq_grad=new_averaged_sq_grad,
        averaged_max_trajectory_length=new_averaged_max_trajectory_length)
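
# A scalar NumPy sketch (assumed single chain, illustrative constants) of the
# update rule above: an RMSProp-style normalization of the criterion gradient,
# a clipped step on log(max_trajectory_length), and iterate averaging in log
# space.  The final `_clip_max_trajectory_length` step is omitted.
import numpy as np


def update_trajectory_length(max_len, avg_sq_grad, avg_max_len, grad, step,
                             adaptation_rate=0.025, decay=0.05):
    iteration = step + 1.
    avg_sq_grad = (1 - decay) * avg_sq_grad + decay * grad**2
    corrected = avg_sq_grad / (1. - (1 - decay)**iteration)  # bias correction
    step_size = adaptation_rate / np.sqrt(corrected + 1e-20)
    log_update = np.clip(step_size * grad, -0.35, 0.35)      # ~log(2)/2
    max_len = max_len * np.exp(log_update)
    w = iteration**-0.5                                      # iterate-averaging weight
    avg_max_len = np.exp(w * np.log(max_len) +
                         (1 - w) * np.log(1e-10 + avg_max_len))
    return max_len, avg_sq_grad, avg_max_len
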
Ejemplo n.º 23
0
    def __call__(self, states, actions):
        dist, _ = self.actor.get_dist_and_mode(states)
        actions = tf.clip_by_value(actions, 1e-4 + self.action_spec.minimum,
                                   -1e-4 + self.action_spec.maximum)
        log_probs = dist.log_prob(actions)
        return dist, log_probs
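
# A hedged illustration (TensorFlow Probability assumed; the policy class is
# not shown above) of why actions are nudged 1e-4 inside the bounds: if the
# actor distribution is, say, a tanh-squashed Gaussian, its log-density is
# non-finite exactly at the boundary but finite just inside it.
import tensorflow as tf
import tensorflow_probability as tfp

tfd, tfb = tfp.distributions, tfp.bijectors
squashed = tfd.TransformedDistribution(tfd.Normal(0., 1.), tfb.Tanh())
at_boundary = squashed.log_prob(1.0)          # non-finite
just_inside = squashed.log_prob(1.0 - 1e-4)   # finite
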
Ejemplo n.º 24
0
def find_bins(x,
              edges,
              extend_lower_interval=False,
              extend_upper_interval=False,
              dtype=None,
              name=None):
    """Bin values into discrete intervals.

  Given `edges = [c0, ..., cK]`, defining intervals
  `I0 = [c0, c1)`, `I1 = [c1, c2)`, ..., `I_{K-1} = [c_{K-1}, cK]`,
  this function returns `bins`, such that:
  `edges[bins[i]] <= x[i] < edges[bins[i] + 1]`.

  Args:
    x:  Numeric `N-D` `Tensor` with `N > 0`.
    edges:  `Tensor` of same `dtype` as `x`.  The first dimension indexes edges
      of intervals.  Must either be `1-D` or have
      `x.shape[1:] == edges.shape[1:]`.  If `rank(edges) > 1`, `edges[k]`
      designates a shape `edges.shape[1:]` `Tensor` of bin edges for the
      corresponding dimensions of `x`.
    extend_lower_interval:  Python `bool`.  If `True`, extend the lowest
      interval `I0` to `(-inf, c1]`.
    extend_upper_interval:  Python `bool`.  If `True`, extend the upper
      interval `I_{K-1}` to `[c_{K-1}, +inf)`.
    dtype: The output type (`int32` or `int64`). `Default value:` `x.dtype`.
      This affects the output values when `x` is below/above the intervals,
      which will be `-1/K+1` for `int` types and `NaN` for `float`s.
      At indices where `x` is `NaN`, the output values will be `0` for `int`
      types and `NaN` for floats.
    name:  A Python string name to prepend to created ops. Default: 'find_bins'

  Returns:
    bins: `Tensor` with same `shape` as `x` and `dtype`.
      Has whole number values.  `bins[i] = k` means the `x[i]` falls into the
      `kth` bin, ie, `edges[bins[i]] <= x[i] < edges[bins[i] + 1]`.

  Raises:
    ValueError:  If `edges.shape[0]` is determined to be less than 2.

  #### Examples

  Cut a `1-D` array

  ```python
  x = [0., 5., 6., 10., 20.]
  edges = [0., 5., 10.]
  tfp.stats.find_bins(x, edges)
  ==> [0., 1., 1., 1., np.nan]
  ```

  Cut `x` into its deciles

  ```python
  x = tf.random.uniform(shape=(100, 200))
  decile_edges = tfp.stats.quantiles(x, num_quantiles=10)
  bins = tfp.stats.find_bins(x, edges=decile_edges)
  bins.shape
  ==> (100, 200)
  tf.reduce_mean(bins == 0.)
  ==> approximately 0.1
  tf.reduce_mean(bins == 1.)
  ==> approximately 0.1
  ```

  """
    # TFP users may be surprised to see the "action" in the leftmost dim of
    # edges, rather than the rightmost (event) dim.  Why?
    # 1. Most likely you created edges by getting quantiles over samples, and
    #    quantile/percentile return these edges in the leftmost (sample) dim.
    # 2. Say you have event_shape = [5], then we expect the bin will be different
    #    for all 5 events, so the index of the bin should not be in the event dim.
    with tf.name_scope(name or 'find_bins'):
        in_type = dtype_util.common_dtype([x, edges], dtype_hint=tf.float32)
        edges = tf.convert_to_tensor(edges, name='edges', dtype=in_type)
        x = tf.convert_to_tensor(x, name='x', dtype=in_type)

        if (tf.compat.dimension_value(edges.shape[0]) is not None
                and tf.compat.dimension_value(edges.shape[0]) < 2):
            raise ValueError(
                'First dimension of `edges` must have length > 1 to index 1 or '
                'more bins. Found: {}'.format(edges.shape))

        flattening_x = (tensorshape_util.rank(edges.shape) == 1
                        and tensorshape_util.rank(x.shape) > 1)

        if flattening_x:
            x_orig_shape = ps.shape(x)
            x = tf.reshape(x, [-1])

        if dtype is None:
            dtype = in_type
        dtype = tf.as_dtype(dtype)

        # Move first dims into the rightmost.
        x_permed = distribution_util.rotate_transpose(x, shift=-1)
        edges_permed = distribution_util.rotate_transpose(edges, shift=-1)

        # If...
        #   x_permed = [0, 1, 6., 10]
        #   edges = [0, 5, 10.]
        #   ==> almost_output = [1, 1, 2, 3]
        searchsorted_type = dtype if dtype in [tf.int32, tf.int64] else None
        almost_output_permed = tf.searchsorted(sorted_sequence=edges_permed,
                                               values=x_permed,
                                               side='right',
                                               out_type=searchsorted_type)
        # Move the rightmost dims back to the leftmost.
        almost_output = tf.cast(
            distribution_util.rotate_transpose(almost_output_permed, shift=1),
            dtype)

        # In above example, we want [0, 0, 1, 1], so correct this here.
        bins = tf.clip_by_value(almost_output - 1, tf.cast(0, dtype),
                                tf.cast(tf.shape(edges)[0] - 2, dtype))

        if not extend_lower_interval:
            low_fill = np.nan if dtype_util.is_floating(dtype) else -1
            bins = tf.where(x < tf.expand_dims(edges[0], 0),
                            tf.cast(low_fill, dtype), bins)

        if not extend_upper_interval:
            up_fill = (np.nan if dtype_util.is_floating(dtype) else
                       tf.shape(edges)[0] - 1)
            bins = tf.where(x > tf.expand_dims(edges[-1], 0),
                            tf.cast(up_fill, dtype), bins)

        if flattening_x:
            bins = tf.reshape(bins, x_orig_shape)

        return bins
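
# A minimal sketch of the searchsorted trick described in the comments above:
# `side='right'` returns the insertion point, so subtracting one and clipping
# to [0, K - 1] yields the bin index.
import tensorflow as tf

edges = tf.constant([0., 5., 10.])
x = tf.constant([0., 1., 6., 10.])
almost = tf.searchsorted(edges, x, side='right')                # [1, 1, 2, 3]
bins = tf.clip_by_value(almost - 1, 0, tf.shape(edges)[0] - 2)  # [0, 0, 1, 1]
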
Ejemplo n.º 25
0
        def maybe_step(accepted, diagnostics, iterand, solver_internal_state):
            """Takes a single step only if the outcome has a low enough error."""
            [
                num_jacobian_evaluations, num_matrix_factorizations,
                num_ode_fn_evaluations, status
            ] = diagnostics
            [
                jacobian_mat, jacobian_is_up_to_date, new_step_size, num_steps,
                num_steps_same_size, should_update_jacobian,
                should_update_step_size, time, unitary, upper
            ] = iterand
            [backward_differences, order, step_size] = solver_internal_state

            if max_num_steps is not None:
                status = tf1.where(tf.equal(num_steps, max_num_steps), -1, 0)

            backward_differences = tf1.where(
                should_update_step_size,
                bdf_util.interpolate_backward_differences(
                    backward_differences, order, new_step_size / step_size),
                backward_differences)
            step_size = tf1.where(should_update_step_size, new_step_size,
                                  step_size)
            should_update_factorization = should_update_step_size
            num_steps_same_size = tf1.where(should_update_step_size, 0,
                                            num_steps_same_size)

            def update_factorization():
                return bdf_util.newton_qr(
                    jacobian_mat, newton_coefficients_array.read(order),
                    step_size)

            if self._evaluate_jacobian_lazily:

                def update_jacobian_and_factorization():
                    new_jacobian_mat = jacobian_fn_mat(time,
                                                       backward_differences[0])
                    new_unitary, new_upper = update_factorization()
                    return [
                        new_jacobian_mat, True, num_jacobian_evaluations + 1,
                        new_unitary, new_upper
                    ]

                def maybe_update_factorization():
                    new_unitary, new_upper = tf.cond(
                        should_update_factorization, update_factorization,
                        lambda: [unitary, upper])
                    return [
                        jacobian_mat, jacobian_is_up_to_date,
                        num_jacobian_evaluations, new_unitary, new_upper
                    ]

                [
                    jacobian_mat, jacobian_is_up_to_date,
                    num_jacobian_evaluations, unitary, upper
                ] = tf.cond(should_update_jacobian,
                            update_jacobian_and_factorization,
                            maybe_update_factorization)
            else:
                unitary, upper = update_factorization()
                num_matrix_factorizations += 1

            tol = p.atol + p.rtol * tf.abs(backward_differences[0])
            newton_tol = newton_tol_factor * tf.norm(tol)

            [
                newton_converged, next_backward_difference, next_state_vec,
                newton_num_iters
            ] = bdf_util.newton(backward_differences, max_num_newton_iters,
                                newton_coefficients_array.read(order),
                                p.ode_fn_vec, order, step_size, time,
                                newton_tol, unitary, upper)
            num_steps += 1
            num_ode_fn_evaluations += newton_num_iters

            # If Newton's method failed and the Jacobian was up to date, decrease the
            # step size.
            newton_failed = tf.logical_not(newton_converged)
            should_update_step_size = newton_failed & jacobian_is_up_to_date
            new_step_size = step_size * tf1.where(should_update_step_size,
                                                  newton_step_size_factor, 1.)

            # If Newton's method failed and the Jacobian was NOT up to date, update
            # the Jacobian.
            should_update_jacobian = newton_failed & tf.logical_not(
                jacobian_is_up_to_date)

            error_ratio = tf1.where(
                newton_converged,
                bdf_util.error_ratio(next_backward_difference,
                                     error_coefficients_array.read(order),
                                     tol), np.nan)
            accepted = error_ratio < 1.
            converged_and_rejected = newton_converged & tf.logical_not(
                accepted)

            # If Newton's method converged but the solution was NOT accepted, decrease
            # the step size.
            new_step_size = tf1.where(
                converged_and_rejected,
                util.next_step_size(step_size, order, error_ratio,
                                    p.safety_factor, min_step_size_factor,
                                    max_step_size_factor), new_step_size)
            should_update_step_size = should_update_step_size | converged_and_rejected

            # If Newton's method converged and the solution was accepted, update the
            # matrix of backward differences.
            time = tf1.where(accepted, time + step_size, time)
            backward_differences = tf1.where(
                accepted,
                bdf_util.update_backward_differences(backward_differences,
                                                     next_backward_difference,
                                                     next_state_vec, order),
                backward_differences)
            jacobian_is_up_to_date = jacobian_is_up_to_date & tf.logical_not(
                accepted)
            num_steps_same_size = tf1.where(accepted, num_steps_same_size + 1,
                                            num_steps_same_size)

            # Order and step size are only updated if we have taken strictly more than
            # order + 1 steps of the same size. This is to prevent the order from
            # being throttled.
            should_update_order_and_step_size = accepted & (num_steps_same_size
                                                            > order + 1)

            backward_differences_array = tf.TensorArray(
                backward_differences.dtype,
                size=bdf_util.MAX_ORDER + 3,
                clear_after_read=False,
                element_shape=next_backward_difference.get_shape()).unstack(
                    backward_differences)
            new_order = order
            new_error_ratio = error_ratio
            for offset in [-1, +1]:
                proposed_order = tf.clip_by_value(order + offset, 1, max_order)
                proposed_error_ratio = bdf_util.error_ratio(
                    backward_differences_array.read(proposed_order + 1),
                    error_coefficients_array.read(proposed_order), tol)
                proposed_error_ratio_is_lower = proposed_error_ratio < new_error_ratio
                new_order = tf1.where(
                    should_update_order_and_step_size
                    & proposed_error_ratio_is_lower, proposed_order, new_order)
                new_error_ratio = tf1.where(
                    should_update_order_and_step_size
                    & proposed_error_ratio_is_lower, proposed_error_ratio,
                    new_error_ratio)
            order = new_order
            error_ratio = new_error_ratio

            new_step_size = tf1.where(
                should_update_order_and_step_size,
                util.next_step_size(step_size, order, error_ratio,
                                    p.safety_factor, min_step_size_factor,
                                    max_step_size_factor), new_step_size)
            should_update_step_size = (should_update_step_size
                                       | should_update_order_and_step_size)

            diagnostics = _BDFDiagnostics(num_jacobian_evaluations,
                                          num_matrix_factorizations,
                                          num_ode_fn_evaluations, status)
            iterand = _BDFIterand(jacobian_mat, jacobian_is_up_to_date,
                                  new_step_size, num_steps,
                                  num_steps_same_size, should_update_jacobian,
                                  should_update_step_size, time, unitary,
                                  upper)
            solver_internal_state = _BDFSolverInternalState(
                backward_differences, order, step_size)
            return accepted, diagnostics, iterand, solver_internal_state
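
# A hedged NumPy sketch of the control pattern in `maybe_step`: a step is
# accepted when the error ratio is below one, and the step size is rescaled
# by a safety-factored power of the error ratio.  The exact rule lives in
# `util.next_step_size`; the formula below is the common textbook heuristic
# and is only an assumption about its shape.
import numpy as np


def next_step_size_sketch(step_size, order, error_ratio, safety_factor=0.9,
                          min_factor=0.1, max_factor=10.):
    factor = safety_factor * error_ratio**(-1. / (order + 1))
    return step_size * np.clip(factor, min_factor, max_factor)


error_ratio = 0.25
accepted = error_ratio < 1.   # True: the step is kept
new_h = next_step_size_sketch(0.01, order=2, error_ratio=error_ratio)
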
Ejemplo n.º 26
0
def _interp_regular_1d_grid_impl(x,
                                 x_ref_min,
                                 x_ref_max,
                                 y_ref,
                                 axis=-1,
                                 batch_y_ref=False,
                                 fill_value='constant_extension',
                                 fill_value_below=None,
                                 fill_value_above=None,
                                 grid_regularizing_transform=None,
                                 name=None):
    """1-D interpolation that works with/without batching."""
    # Note: we do *not* make the no-batch version a special case of the batch
    # version, because that would be an inefficient use of batch_gather with
    # unnecessarily broadcast args.
    with tf.name_scope(name or 'interp_regular_1d_grid_impl'):

        # Arg checking.
        allowed_fv_st = ('constant_extension', 'extrapolate')
        for fv in (fill_value, fill_value_below, fill_value_above):
            if isinstance(fv, str) and fv not in allowed_fv_st:
                raise ValueError(
                    'A fill value ({}) was not an allowed string ({})'.format(
                        fv, allowed_fv_st))

        # Separate value fills for below/above incurs extra cost, so keep track of
        # whether this is needed.
        need_separate_fills = (
            fill_value_above is not None or fill_value_below is not None or
            fill_value == 'extrapolate'  # always requires separate below/above
        )
        if need_separate_fills and fill_value_above is None:
            fill_value_above = fill_value
        if need_separate_fills and fill_value_below is None:
            fill_value_below = fill_value

        dtype = dtype_util.common_dtype([x, x_ref_min, x_ref_max, y_ref],
                                        dtype_hint=tf.float32)
        x = tf.convert_to_tensor(x, name='x', dtype=dtype)

        x_ref_min = tf.convert_to_tensor(x_ref_min,
                                         name='x_ref_min',
                                         dtype=dtype)
        x_ref_max = tf.convert_to_tensor(x_ref_max,
                                         name='x_ref_max',
                                         dtype=dtype)
        if not batch_y_ref:
            _assert_ndims_statically(x_ref_min, expect_ndims=0)
            _assert_ndims_statically(x_ref_max, expect_ndims=0)

        y_ref = tf.convert_to_tensor(y_ref, name='y_ref', dtype=dtype)

        if batch_y_ref:
            # If we're batching,
            #   x.shape ~ [A1,...,AN, D],  x_ref_min/max.shape ~ [A1,...,AN]
            # So to add together we'll append a singleton.
            # If not batching, x_ref_min/max are scalar, so this isn't an issue,
            # moreover, if not batching, x can be scalar, and expanding x_ref_min/max
            # would cause a bad expansion of x when added to x (confused yet?).
            x_ref_min = x_ref_min[..., tf.newaxis]
            x_ref_max = x_ref_max[..., tf.newaxis]

        axis = tf.convert_to_tensor(axis, name='axis', dtype=tf.int32)
        axis = prefer_static.non_negative_axis(axis, tf.rank(y_ref))
        _assert_ndims_statically(axis, expect_ndims=0)

        ny = tf.cast(tf.shape(y_ref)[axis], dtype)

        # Map [x_ref_min, x_ref_max] to [0, ny - 1].
        # This is the (fractional) index of x.
        if grid_regularizing_transform is None:
            g = lambda x: x
        else:
            g = grid_regularizing_transform
        fractional_idx = ((g(x) - g(x_ref_min)) /
                          (g(x_ref_max) - g(x_ref_min)))
        x_idx_unclipped = fractional_idx * (ny - 1)

        # Wherever x is NaN, x_idx_unclipped will be NaN as well.
        # Keep track of the nan indices here (so we can impute NaN later).
        # Also replace any NaN index with zero, since NaN has no representation
        # once cast to an integer index.
        nan_idx = tf.math.is_nan(x_idx_unclipped)
        zero = tf.zeros((), dtype=dtype)
        x_idx_unclipped = tf.where(nan_idx, zero, x_idx_unclipped)
        x_idx = tf.clip_by_value(x_idx_unclipped, zero, ny - 1)

        # Get the index above and below x_idx.
        # Naively we could set idx_below = floor(x_idx), idx_above = ceil(x_idx),
        # however, this results in idx_below == idx_above whenever x is on a grid.
        # This in turn results in y_ref_below == y_ref_above, and then the gradient
        # at this point is zero.  So here we 'jitter' one of idx_below, idx_above,
        # so that they are at different values.  This jittering does not affect the
        # interpolated value, but does make the gradient nonzero (unless of course
        # the y_ref values are the same).
        idx_below = tf.floor(x_idx)
        idx_above = tf.minimum(idx_below + 1, ny - 1)
        idx_below = tf.maximum(idx_above - 1, 0)

        # These are the values of y_ref corresponding to above/below indices.
        idx_below_int32 = tf.cast(idx_below, dtype=tf.int32)
        idx_above_int32 = tf.cast(idx_above, dtype=tf.int32)
        if batch_y_ref:
            # If y_ref.shape ~ [A1,...,AN, C, B1,...,BN],
            # and x.shape, x_ref_min/max.shape ~ [A1,...,AN, D]
            # Then y_ref_below.shape ~ [A1,...,AN, D, B1,...,BN]
            y_ref_below = _batch_gather_with_broadcast(y_ref, idx_below_int32,
                                                       axis)
            y_ref_above = _batch_gather_with_broadcast(y_ref, idx_above_int32,
                                                       axis)
        else:
            # Here, y_ref_below.shape =
            #   y_ref.shape[:axis] + x.shape + y_ref.shape[axis + 1:]
            y_ref_below = tf.gather(y_ref, idx_below_int32, axis=axis)
            y_ref_above = tf.gather(y_ref, idx_above_int32, axis=axis)

        # Use t to get a convex combination of the below/above values.
        t = x_idx - idx_below

        # x, and tensors shaped like x, need to be added to, and selected with
        # (using tf.where) the output y.  This requires appending singletons.
        # Make functions appropriate for batch/no-batch.
        if batch_y_ref:
            # In the batch case, the output shape is going to be
            #   Broadcast(y_ref.shape[:axis], x.shape[:-1]) +
            #   x.shape[-1:] +  y_ref.shape[axis+1:]
            expand_x_fn = _make_expand_x_fn_for_batch_interpolation(
                y_ref, axis)
        else:
            # In the non-batch case, the output shape is going to be
            #   y_ref.shape[:axis] + x.shape + y_ref.shape[axis+1:]
            expand_x_fn = _make_expand_x_fn_for_non_batch_interpolation(
                y_ref, axis)

        t = expand_x_fn(t)
        nan_idx = expand_x_fn(nan_idx, broadcast=True)
        x_idx_unclipped = expand_x_fn(x_idx_unclipped, broadcast=True)

        y = t * y_ref_above + (1 - t) * y_ref_below

        # Now begins a long excursion to fill values outside [x_min, x_max].

        # Re-insert NaN wherever x was NaN.
        y = tf.where(nan_idx, tf.constant(np.nan, y.dtype), y)

        if not need_separate_fills:
            if fill_value == 'constant_extension':
                pass  # Already handled by clipping x_idx_unclipped.
            else:
                y = tf.where(
                    (x_idx_unclipped < 0) | (x_idx_unclipped > ny - 1),
                    fill_value, y)
        else:
            # Fill values below x_ref_min <==> x_idx_unclipped < 0.
            if fill_value_below == 'constant_extension':
                pass  # Already handled by the clipping that produced `x_idx`.
            elif fill_value_below == 'extrapolate':
                if batch_y_ref:
                    # For every batch member, gather the first two elements of y across
                    # `axis`.
                    y_0 = tf.gather(y_ref, [0], axis=axis)
                    y_1 = tf.gather(y_ref, [1], axis=axis)
                else:
                    # If not batching, we want to gather the first two elements, just like
                    # above.  However, these results need to be replicated for every
                    # member of x.  An easy way to do that is to gather using
                    # indices = zeros/ones(x.shape).
                    y_0 = tf.gather(y_ref,
                                    tf.zeros(tf.shape(x), dtype=tf.int32),
                                    axis=axis)
                    y_1 = tf.gather(y_ref,
                                    tf.ones(tf.shape(x), dtype=tf.int32),
                                    axis=axis)
                x_delta = (x_ref_max - x_ref_min) / (ny - 1)
                x_factor = expand_x_fn((x - x_ref_min) / x_delta,
                                       broadcast=True)
                y = tf.where(x_idx_unclipped < 0, y_0 + x_factor * (y_1 - y_0),
                             y)
            else:
                y = tf.where(x_idx_unclipped < 0, fill_value_below, y)
            # Fill values above x_ref_max <==> x_idx_unclipped > ny - 1.
            if fill_value_above == 'constant_extension':
                pass  # Already handled by the clipping that produced `x_idx`.
            elif fill_value_above == 'extrapolate':
                ny_int32 = tf.shape(y_ref)[axis]
                if batch_y_ref:
                    y_n1 = tf.gather(y_ref, [tf.shape(y_ref)[axis] - 1],
                                     axis=axis)
                    y_n2 = tf.gather(y_ref, [tf.shape(y_ref)[axis] - 2],
                                     axis=axis)
                else:
                    y_n1 = tf.gather(y_ref,
                                     tf.fill(tf.shape(x), ny_int32 - 1),
                                     axis=axis)
                    y_n2 = tf.gather(y_ref,
                                     tf.fill(tf.shape(x), ny_int32 - 2),
                                     axis=axis)
                x_delta = (x_ref_max - x_ref_min) / (ny - 1)
                x_factor = expand_x_fn((x - x_ref_max) / x_delta,
                                       broadcast=True)
                y = tf.where(x_idx_unclipped > ny - 1,
                             y_n1 + x_factor * (y_n1 - y_n2), y)
            else:
                y = tf.where(x_idx_unclipped > ny - 1, fill_value_above, y)

        return y
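
# A minimal sketch of the core computation above (no batching, no NaN or fill
# handling): map x onto a fractional index into y_ref, gather the neighbours
# below and above, and take the convex combination.
import tensorflow as tf

x = tf.constant([0.0, 0.25, 1.0])
x_ref_min, x_ref_max = 0., 1.
y_ref = tf.constant([0., 10., 20., 30.])            # regular grid of 4 points
ny = tf.cast(tf.shape(y_ref)[0], x.dtype)

x_idx = tf.clip_by_value(
    (x - x_ref_min) / (x_ref_max - x_ref_min) * (ny - 1.), 0., ny - 1.)
idx_below = tf.floor(x_idx)
idx_above = tf.minimum(idx_below + 1., ny - 1.)
idx_below = tf.maximum(idx_above - 1., 0.)
t = x_idx - idx_below
y_below = tf.gather(y_ref, tf.cast(idx_below, tf.int32))
y_above = tf.gather(y_ref, tf.cast(idx_above, tf.int32))
y = t * y_above + (1. - t) * y_below                # [0., 7.5, 30.]
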
Ejemplo n.º 27
0
def data_postprocess(x):
    """Postprocess the samples from the model before plotting."""
    return tf.cast(
        tf.clip_by_value(tf.floor((x + .5) * IMAGE_BINS), 0, IMAGE_BINS - 1),
        tf.uint8)
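
# A hedged usage example of `data_postprocess`, assuming IMAGE_BINS = 256 is
# the constant defined elsewhere in the module: samples in roughly
# [-0.5, 0.5) map to uint8 pixel values in [0, 255].
import tensorflow as tf

IMAGE_BINS = 256  # assumed value for illustration
pixels = data_postprocess(tf.constant([-0.5, 0.0, 0.49]))
# pixels ~= [0, 128, 253]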