Code example #1
def _compute_range_weights(guide, grid_shape):
  """Computes range weights for the given guide image and grid shape.

  Args:
    guide: The guide image with shape (h, w).
    grid_shape: The grid shape, an array-like containing [gh, gw, gd, gc].

  Returns:
    An (image_extent, grid_extent) array with the range weight for each
    spatial and grid position.
  """
  guide_padded = _symmetric_pad_ij(guide, grid_shape)

  # Rescale the guide from [0, 1] to [0, grid_depth].
  # These are the floating point k coordinates of each sample.
  grid_depth = grid_shape[2]
  gk_float = guide_padded * grid_depth

  # Each sample with float value kf can splat onto locations:
  # k0 = floor(kf - 0.5)
  # k1 = ceil(kf - 0.5)
  #
  # The subtraction by 0.5 is necessary:
  # - Grid samples are located at half-integer coordinates:
  #   k = 0 places its sample at kf = 0.5.
  # - If kf = 1.4, the tent weight function is nonzero in the range [0.4, 1.4].
  #   Therefore, we need to splat to k0 = 0 and k1 = 1.
  # - If kf = 1.9, the tent weight function is nonzero in the range [0.9, 1.9].
  #   Therefore, we need to splat to k0 = 1 and k1 = 2.
  gk_floor = jnp.floor(gk_float - 0.5)
  gk_ceil = jnp.ceil(gk_float - 0.5)

  # Compute tent weights before clipping.
  wk_floor = smoothed_lerp_weight(gk_floor + 0.5, gk_float)
  wk_ceil = smoothed_lerp_weight(gk_ceil + 0.5, gk_float)

  # Cast to int for indexing.
  gk_floor = gk_floor.astype(jnp.int32)
  gk_ceil = gk_ceil.astype(jnp.int32)

  # Handle boundary conditions:
  # - Set the weight to 0 where the tent weight is positive but outside
  #   [0, grid_depth].
  # - Set the weight to 1 where the sample lies in [0, 0.5) or in
  #   (grid_depth - 0.5, grid_depth].
  wk_floor = jnp.where((gk_ceil == 0) & (gk_float < 0.5), 0, wk_floor)
  wk_ceil = jnp.where(
      (gk_floor == grid_depth - 1) & (gk_float > grid_depth - 0.5), 0, wk_ceil)
  wk_ceil = jnp.where((gk_ceil == 0) & (gk_float < 0.5), 1, wk_ceil)
  wk_floor = jnp.where(
      (gk_floor == grid_depth - 1) & (gk_float > grid_depth - 0.5), 1, wk_floor)

  # Now clip int coordinates for splatting. Coordinates outside [0, grid_depth)
  # will have zero weight so splatting to them does nothing.
  gk_floor_clipped = gk_floor.clip(0, grid_depth - 1)
  gk_ceil_clipped = gk_ceil.clip(0, grid_depth - 1)

  # Compute the i and j indices where we want to splat the weights wk with +=.
  # grid[ii, jj, gk_floor] += wk_floor
  # grid[ii, jj, gk_ceil] += wk_ceil
  ii, jj = jnp.meshgrid(
      jnp.arange(guide_padded.shape[0]),
      jnp.arange(guide_padded.shape[1]),
      indexing='ij')

  range_weights = jnp.zeros(
      (guide_padded.shape[0], guide_padded.shape[1], grid_depth))
  # Scatter-add the tent weights into the range-weight volume (the
  # .at[...].add(...) form is the modern equivalent of jax.ops.index_add).
  range_weights = range_weights.at[ii, jj, gk_floor_clipped].add(wk_floor)
  range_weights = range_weights.at[ii, jj, gk_ceil_clipped].add(wk_ceil)

  return range_weights
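A minimal 1-D sketch of the same splatting idea, assuming plain linear tent weights in place of `smoothed_lerp_weight` and using the in-place-style `.at[...].add(...)` scatter (the helper name and toy values are illustrative):

import jax.numpy as jnp

def splat_range_weights_1d(guide, grid_depth):
    # Hypothetical helper: guide is a 1-D array of values in [0, 1].
    gk_float = guide * grid_depth                      # float k coordinate of each sample
    gk_floor = jnp.floor(gk_float - 0.5)
    gk_ceil = gk_floor + 1.0
    # Linear tent weights measured from the cell centers at k + 0.5.
    wk_ceil = jnp.clip(gk_float - (gk_floor + 0.5), 0.0, 1.0)
    wk_floor = 1.0 - wk_ceil
    k0 = jnp.clip(gk_floor.astype(jnp.int32), 0, grid_depth - 1)
    k1 = jnp.clip(gk_ceil.astype(jnp.int32), 0, grid_depth - 1)
    ii = jnp.arange(guide.shape[0])
    weights = jnp.zeros(guide.shape + (grid_depth,))
    weights = weights.at[ii, k0].add(wk_floor)
    weights = weights.at[ii, k1].add(wk_ceil)
    return weights

print(splat_range_weights_1d(jnp.array([0.0, 0.3, 0.9]), grid_depth=4))

Clipping the indices reproduces the boundary handling described above: samples below 0.5 or above grid_depth - 0.5 end up with a total weight of 1 in the edge cell.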
Code example #2
File: discrete.py Project: tbsexton/numpyro
 def sample(self, key, sample_shape=()):
     key_bern, key_poisson = random.split(key)
     shape = sample_shape + self.batch_shape
     mask = random.bernoulli(key_bern, self.gate, shape)
     samples = random.poisson(key_poisson, device_put(self.rate), shape)
     return jnp.where(mask, 0, samples)
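A self-contained sketch of the same pattern outside the distribution class (function name and parameter values are illustrative):

import jax.numpy as jnp
from jax import random

def zero_inflated_poisson_sample(key, gate, rate, shape):
    # gate: probability of an excess zero; rate: Poisson rate.
    key_bern, key_poisson = random.split(key)
    mask = random.bernoulli(key_bern, gate, shape)       # True -> forced zero
    samples = random.poisson(key_poisson, rate, shape)
    return jnp.where(mask, 0, samples)

draws = zero_inflated_poisson_sample(random.PRNGKey(0), gate=0.3, rate=4.0, shape=(8,))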
Code example #3
File: decode.py Project: shafiahmed/flax
def beam_search(inputs,
                cache,
                tokens_to_logits,
                beam_size=4,
                alpha=0.6,
                eos_id=EOS_ID,
                max_decode_len=None):
    """Beam search for transformer machine translation.

  Args:
    inputs: array: [batch_size, length] int32 sequence of tokens.
    cache: flax attention cache.
    tokens_to_logits: fast autoregressive decoder function taking single token
      slices and cache and returning next-token logits and updated cache.
    beam_size: int: number of beams to use in beam search.
    alpha: float: scaling factor for brevity penalty.
    eos_id: int: id of end-of-sentence token for target vocabulary.
    max_decode_len: int: maximum length of decoded translations.

  Returns:
     Tuple of:
       [batch_size, beam_size, max_decode_len] top-scoring sequences
       [batch_size, beam_size] beam-search scores.
  """
    # We liberally annotate shape information for clarity below.

    batch_size = inputs.shape[0]
    if max_decode_len is None:
        max_decode_len = inputs.shape[1]
    end_marker = jnp.array(eos_id)

    # initialize beam search state
    beam_search_init_state = beam_init(batch_size, beam_size, max_decode_len,
                                       cache)

    def beam_search_loop_cond_fn(state):
        """Beam search loop termination condition."""
        # Have we reached max decoding length?
        not_at_end = (state.cur_index < max_decode_len - 1)

        # Is no further progress in the beam search possible?
        # Get the best possible scores from alive sequences.
        min_brevity_penalty = brevity_penalty(alpha, max_decode_len)
        best_live_scores = state.live_logprobs[:, -1:] / min_brevity_penalty
        # Get the worst scores from finished sequences.
        worst_finished_scores = jnp.min(state.finished_scores,
                                        axis=1,
                                        keepdims=True)
        # Mask out scores from slots without any actual finished sequences.
        worst_finished_scores = jnp.where(state.finished_flags,
                                          worst_finished_scores, NEG_INF)
        # If no best possible live score is better than current worst finished
        # scores, the search cannot improve the finished set further.
        search_terminated = jnp.all(worst_finished_scores > best_live_scores)

        # If we're not at the max decode length, and the search hasn't terminated,
        # continue looping.
        return not_at_end & (~search_terminated)

    def beam_search_loop_body_fn(state):
        """Beam search loop state update function."""
        # Collect the current position slice along length to feed the fast
        # autoregressive decoder model.  Flatten the beam dimension into batch
        # dimension for feeding into the model.
        # --> [batch * beam, 1]
        flat_ids = flatten_beam_dim(
            lax.dynamic_slice(state.live_seqs, (0, 0, state.cur_index),
                              (batch_size, beam_size, 1)))
        # Flatten beam dimension into batch to be compatible with model.
        # {[batch, beam, ...], ...} --> {[batch * beam, ...], ...}
        flat_cache = jax.tree_map(flatten_beam_dim, state.cache)

        # Call fast-decoder model on current tokens to get next-position logits.
        # --> [batch * beam, vocab]
        flat_logits, new_flat_cache = tokens_to_logits(flat_ids, flat_cache)

        # unflatten beam dimension
        # [batch * beam, vocab] --> [batch, beam, vocab]
        logits = unflatten_beam_dim(flat_logits, batch_size, beam_size)
        # Unflatten beam dimension in attention cache arrays
        # {[batch * beam, ...], ...} --> {[batch, beam, ...], ...}
        new_cache = jax.tree_map(
            lambda x: unflatten_beam_dim(x, batch_size, beam_size),
            new_flat_cache)

        # Gather log probabilities from logits
        candidate_log_probs = jax.nn.log_softmax(logits)
        # Add new logprobs to existing prefix logprobs.
        # --> [batch, beam, vocab]
        log_probs = (candidate_log_probs +
                     jnp.expand_dims(state.live_logprobs, axis=2))

        # We'll need the vocab size, gather it from the log probability dimension.
        vocab_size = log_probs.shape[2]

        # Each item in batch has beam_size * vocab_size candidate sequences.
        # For each item, get the top 2*k candidates with the highest log-
        # probabilities. We gather the top 2*K beams here so that even if the best
        # K sequences reach EOS simultaneously, we have another K sequences
        # remaining to continue the live beam search.
        beams_to_keep = 2 * beam_size
        # Flatten beam and vocab dimensions.
        flat_log_probs = log_probs.reshape(
            (batch_size, beam_size * vocab_size))
        # Gather the top 2*K scores from _all_ beams.
        # --> [batch, 2*beams], [batch, 2*beams]
        topk_log_probs, topk_indices = lax.top_k(flat_log_probs,
                                                 k=beams_to_keep)
        # Recover the beam index by floor division.
        topk_beam_indices = topk_indices // vocab_size
        # Gather 2*k top beams.
        # --> [batch, 2*beams, length]
        topk_seq = gather_beams(state.live_seqs, topk_beam_indices, batch_size,
                                beams_to_keep)

        # Append the most probable 2*K token IDs to the top 2*K sequences.
        # Recover the token ID by modulo division and expand the ID array for
        # broadcasting.
        # --> [batch, 2*beams, 1]
        topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
        # Update sequences for the 2*K top-k new sequences.
        # --> [batch, 2*beams, length]
        topk_seq = lax.dynamic_update_slice(topk_seq, topk_ids,
                                            (0, 0, state.cur_index + 1))

        # Update LIVE (in-progress) sequences:
        # Did any of these sequences reach an end marker?
        # --> [batch, 2*beams]
        newly_finished = (topk_seq[:, :, state.cur_index + 1] == end_marker)
        # To prevent these newly finished sequences from being added to the LIVE
        # set of active beam search sequences, set their log probs to a very large
        # negative value.
        new_log_probs = topk_log_probs + newly_finished * NEG_INF
        # Determine the top k beam indices (from top 2*k beams) from log probs.
        # --> [batch, beams]
        _, new_topk_indices = lax.top_k(new_log_probs, k=beam_size)
        new_topk_indices = jnp.flip(new_topk_indices, axis=1)
        # Gather the top k beams (from top 2*k beams).
        # --> [batch, beams, length], [batch, beams]
        top_alive_seq, top_alive_log_probs = gather_beams(
            [topk_seq, new_log_probs], new_topk_indices, batch_size, beam_size)

        # Determine the top k beam indices from the original set of all beams.
        # --> [batch, beams]
        top_alive_indices = gather_beams(topk_beam_indices, new_topk_indices,
                                         batch_size, beam_size)
        # With these, gather the top k beam-associated caches.
        # --> {[batch, beams, ...], ...}
        top_alive_cache = gather_beams(new_cache, top_alive_indices,
                                       batch_size, beam_size)

        # Update FINISHED (reached end of sentence) sequences:
        # Calculate new seq scores from log probabilities.
        new_scores = topk_log_probs / brevity_penalty(alpha,
                                                      state.cur_index + 1)
        # Mask out the still unfinished sequences by adding large negative value.
        # --> [batch, 2*beams]
        new_scores += (~newly_finished) * NEG_INF

        # Combine sequences, scores, and flags along the beam dimension and compare
        # new finished sequence scores to existing finished scores and select the
        # best from the new set of beams.
        finished_seqs = jnp.concatenate(  # --> [batch, 3*beams, length]
            [state.finished_seqs, topk_seq],
            axis=1)
        finished_scores = jnp.concatenate(  # --> [batch, 3*beams]
            [state.finished_scores, new_scores],
            axis=1)
        finished_flags = jnp.concatenate(  # --> [batch, 3*beams]
            [state.finished_flags, newly_finished],
            axis=1)
        # --> [batch, beams, length], [batch, beams], [batch, beams]
        top_finished_seq, top_finished_scores, top_finished_flags = (
            gather_topk_beams([finished_seqs, finished_scores, finished_flags],
                              finished_scores, batch_size, beam_size))

        return BeamState(cur_index=state.cur_index + 1,
                         live_logprobs=top_alive_log_probs,
                         finished_scores=top_finished_scores,
                         live_seqs=top_alive_seq,
                         finished_seqs=top_finished_seq,
                         finished_flags=top_finished_flags,
                         cache=top_alive_cache)

    # Run while loop and get final beam search state.
    final_state = lax.while_loop(beam_search_loop_cond_fn,
                                 beam_search_loop_body_fn,
                                 beam_search_init_state)

    # Account for the edge-case where there are no finished sequences for a
    # particular batch item. If so, return live sequences for that batch item.
    # --> [batch]
    none_finished = jnp.any(final_state.finished_flags, axis=1)
    # --> [batch, beams, length]
    finished_seqs = jnp.where(none_finished[:, None, None],
                              final_state.finished_seqs, final_state.live_seqs)
    # --> [batch, beams]
    finished_scores = jnp.where(none_finished[:, None],
                                final_state.finished_scores,
                                final_state.live_logprobs)

    return finished_seqs, finished_scores
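The final jnp.where calls select, per batch element, between the finished and live beams by broadcasting a [batch] flag over the trailing axes; a small sketch with dummy shapes:

import jax.numpy as jnp

has_finished = jnp.array([True, False])          # [batch]
finished_seqs = jnp.ones((2, 4, 6), jnp.int32)   # [batch, beams, length]
live_seqs = jnp.zeros((2, 4, 6), jnp.int32)
chosen = jnp.where(has_finished[:, None, None], finished_seqs, live_seqs)
# chosen[0] is taken from finished_seqs, chosen[1] from live_seqs.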
Code example #4
File: _glm.py Project: berenslab/RFEst
 def leaky_relu(_x):
     return jnp.where(_x > 0., _x, _x * 0.01)
Code example #5
def _vectorized_cond(pred: Array, fn: Callable[[Array], Array],
                     operand: Array) -> Array:
    masked = jnp.where(pred, operand, 1)
    return jnp.where(pred, fn(masked), 0)
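This is the standard "double where" trick for keeping reverse-mode gradients finite: the inner jnp.where feeds `fn` a safe dummy operand wherever `pred` is False. A sketch with a hypothetical `safe_sqrt`:

import jax
import jax.numpy as jnp

def safe_sqrt(x):
    masked = jnp.where(x > 0, x, 1.0)               # inner where: hide invalid inputs from sqrt
    return jnp.where(x > 0, jnp.sqrt(masked), 0.0)  # outer where: select the result

print(jax.grad(lambda x: jnp.sum(safe_sqrt(x)))(jnp.array([4.0, 0.0])))
# [0.25 0.  ] -- a single jnp.where(x > 0, jnp.sqrt(x), 0.0) would give nan at x == 0.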
Code example #6
File: evaluators.py Project: lilujunai/dnet
def binary_crossentropy(outputs: tensor.array, targets: tensor.array) -> float:
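    # Note: despite its name, this thresholds `outputs` at 0.5 and returns the
    # fraction of matches with `targets` (an accuracy-style metric), not a
    # cross-entropy value.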
    output_labels: tensor.array = tensor.where(outputs > 0.50, 1.0, 0.0)
    return tensor.mean(output_labels == targets)
Code example #7
    def predict_fn(
        t: ArrayOrScalar = None,
        fx_train_or_state_0: Union[ArrayOrScalar, ODEState] = 0.,
        fx_test_0: ArrayOrScalar = None,
        k_test_train: np.ndarray = None
    ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray], ODEState]:
        """Return output predictions on train [and test] set[s] at time[s] `t`.

    Args:
      t:
        a scalar or array of scalars of any shape in strictly increasing order.
        `t=None` is equivalent to `t=np.inf` and may not converge. Equivalent of
        training steps (but can be fractional).
      fx_train_or_state_0:
        either (a) output of the network at `t == 0` on the training set or (b)
        complete ODE state (`predict.ODEState`). Pass an ODE state if you want
        to operate on the full ODE state instead of output variables only
        (useful for inspecting auxiliary variables or resuming an optimizer with
        auxiliary variables from a specific state). Note that only the
        `momentum != None` optimizer currently has auxiliary variables. To
        initialize an ODE state from scratch, call
        `predict.ODEState(fx_train_0, fx_test_0)`. If an ODE state is passed, an
        ODE state is returned. `fx_train_0=None` means to not compute
        predictions on the training set.
      fx_test_0:
        output of the network at `t == 0` on the test set. `fx_test_0=None`
        means to not compute predictions on the test set.
      k_test_train:
        kernel relating test data with training data. Must have the shape of
        `zip(y_test.shape, y_train.shape)` with `trace_axes` absent. Pass
        `k_test_train=None` if you only need predictions on the training set.

    Returns:
      `fx_train_t` or `(fx_train_t, fx_test_t)` if `fx_test_0 != None` with
      potentially additional leading time dimensions matching `t.shape`.
      Alternatively can return an `ODEState` at time[s] `t`.

    Raises:
      ValueError: if `fx_test_0` is not `None`, but `k_test_train` is `None`.
    """
        _check_inputs(fx_train_or_state_0, fx_test_0, k_test_train)

        t = np.array(t if t is not None else np.inf, dtype) * learning_rate
        t_shape = t.shape
        t = t.reshape((-1, ))

        # ODE solver requires `t[0]` to be the time where `fx_train_0` [and
        # `fx_test_0`] are evaluated, but also a strictly increasing sequence of
        # timesteps, so we always temporarily append an [almost] `0` at the start.
        t0 = np.where(t[0] == 0, np.full((1, ), -1e-24, t.dtype),
                      np.zeros((1, ), t.dtype))
        t = np.concatenate([t0, t])

        # Solve the ODE.
        fx_test_shape = _get_fx_test_shape(y_train, k_test_train, trace_axes)
        state_0 = get_state_0(fx_train_or_state_0, fx_test_0, fx_test_shape)
        state_t = ode.odeint(get_dstate_dt(k_test_train), state_0, t)

        # Remove the added `t0`.
        trim = lambda x: x[1:].reshape(t_shape + x.shape[1:])
        trim_tree = lambda tree: tree_map(trim, tree)
        state_t = trim_tree(state_t)

        # `ODEState` -> `ODEState`
        if isinstance(fx_train_or_state_0, ODEState):
            return state_t

        # `np.ndarray` -> `np.ndarray`
        fx_train_t, fx_test_t = state_t.fx_train, state_t.fx_test

        if fx_train_or_state_0 is not None and fx_test_0 is None:
            return fx_train_t
        if fx_test_0 is not None and fx_train_or_state_0 is None:
            return fx_test_t
        return fx_train_t, fx_test_t
Code example #8
File: continuous.py Project: dsheldon/numpyro
 def mean(self):
     # for df <= 1. should be np.nan (keeping np.inf for consistency with scipy)
     return np.broadcast_to(np.where(self.df <= 1, np.inf, self.loc),
                            self.batch_shape)
Code example #9
File: continuous.py Project: dsheldon/numpyro
 def variance(self):
     var = np.where(self.df > 2, self.scale**2 * self.df / (self.df - 2.0),
                    np.inf)
     var = np.where(self.df <= 1, np.nan, var)
     return np.broadcast_to(var, self.batch_shape)
Code example #10
 def wrapper(self, *args):
     log_prob = logpdf(self, *args)
     value = args[0]
     mask = self.support(value)
     log_prob = np.where(mask, log_prob, -np.inf)
     return log_prob
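A self-contained sketch of the same support-masking pattern, using a hand-written exponential log-density (names and values are illustrative):

import jax.numpy as jnp

def exponential_log_prob(value, rate=1.0):
    log_prob = jnp.log(rate) - rate * value     # finite everywhere, but only meaningful for value >= 0
    mask = value >= 0                           # support of the exponential distribution
    return jnp.where(mask, log_prob, -jnp.inf)

print(exponential_log_prob(jnp.array([-1.0, 0.5, 2.0])))   # [-inf, -0.5, -2.0]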
Code example #11
def cal_inds(r, _1, _2, N):
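    # For each of the N entries, return its index if r lies strictly inside the
    # open interval (_1[i], _2[i]); otherwise return the sentinel value N.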
    return np.where((_1 < r) * (_2 > r), np.arange(N), N)
Code example #12
 def log_prob(self, value):
     probs = jnp.where((self.probs == 1) & (value == 0), 0, self.probs)
     return value * jnp.log1p(-probs) + jnp.log(probs)
Code example #13
 def log_prob(self, value):
     log_prob = jnp.log1p(-self.gate) + self.base_dist.log_prob(value)
     return jnp.where(value == 0, jnp.log(self.gate + jnp.exp(log_prob)),
                      log_prob)
Code example #14
def eval_step(model, batch):
    """Calculate evaluation metrics on a batch."""
    inputs, targets = batch['inputs'], batch['targets']
    weights = jnp.where(targets > 0, 1.0, 0.0)
    logits = model(inputs, train=False)
    return compute_metrics(logits, targets, weights)
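The weights zero out padding positions (target id 0); `compute_metrics` is not shown in this snippet, but a typical downstream use is a weighted mean, as in this hedged sketch:

import jax.numpy as jnp

targets = jnp.array([[5, 2, 0, 0],
                     [7, 3, 8, 0]])
per_token_loss = jnp.full(targets.shape, 2.0)        # stand-in per-token loss values
weights = jnp.where(targets > 0, 1.0, 0.0)
mean_loss = jnp.sum(per_token_loss * weights) / jnp.sum(weights)   # 2.0, ignoring padding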
Code example #15
File: nonlinearities.py Project: berenslab/RFEst
def elu(x):
    return jnp.where(x > 0, x, jnp.exp(x) - 1)
Code example #16
File: continuous.py Project: dsheldon/numpyro
 def mean(self):
     # mean is inf for alpha <= 1
     a = self.rate / (self.concentration - 1)
     return np.where(self.concentration <= 1, np.inf, a)
Code example #17
 def _convert(x):
   return jnp.where(jnp.isfinite(x), x, jnp.nan)
Code example #18
File: continuous.py Project: dsheldon/numpyro
 def variance(self):
     # var is inf for alpha <= 2
     a = (self.rate /
          (self.concentration - 1))**2 / (self.concentration - 2)
     return np.where(self.concentration <= 2, np.inf, a)
Code example #19
    def solve_max(
        self,
        inner_dual_vars: Any,
        opt_instance: InnerVerifInstance,
        key: jnp.array,
        step: int,
    ) -> jnp.array:
        """Solve maximization problem of opt_instance in closed form.

    Args:
      inner_dual_vars: Dual variables for the inner maximisation.
      opt_instance: Verification instance that defines optimization problem to
        be solved.
      key: Jax PRNG key.
      step: Outer optimization iteration number.

    Returns:
      max_value: final value of the objective function found.
    """
        if opt_instance.affine_before_relu:
            raise ValueError(
                'LpStrategy requires affine_before_relu to be False.')

        if not opt_instance.same_lagrangian_form_pre_post:
            raise ValueError(
                'Different lagrangian forms on inputs and outputs not '
                'supported.')

        if (isinstance(opt_instance.lagrangian_form_pre, lag_form.Linear)
                or isinstance(opt_instance.lagrangian_form_post,
                              lag_form.Linear)):
            pass
        else:
            raise ValueError('LpStrategy cannot use Lagrangian form of type '
                             f'{type(opt_instance.lagrangian_form_pre)}.')

        # some renaming to simplify variable names
        affine_fn, = opt_instance.affine_fns
        bounds = opt_instance.bounds
        duals_pre = opt_instance.lagrange_params_pre
        if (opt_instance.is_last and opt_instance.spec_type
                == verify_utils.SpecType.ADVERSARIAL):
            # No duals_post for last layer, and objective folded in.
            batch_size = bounds[0].lb.shape[0]
            duals_post = jnp.ones([batch_size])
        else:
            duals_post = opt_instance.lagrange_params_post

        if opt_instance.is_first:
            # no "pre-activation" for input of first layer
            lb = bounds[0].lb
            ub = bounds[0].ub
        else:
            lb = bounds[0].lb_pre
            ub = bounds[0].ub_pre

        zero_inputs = jnp.zeros_like(lb)
        affine_constant = affine_fn(zero_inputs)
        duals_post = jnp.reshape(duals_post, affine_constant.shape)

        post_slope_x = jax.grad(lambda x: jnp.sum(affine_fn(x) * duals_post))(
            zero_inputs)

        if opt_instance.is_first:
            # find max element-wise (separable problem): either at lower bound or
            # upper bound -- no duals_pre for first layer
            max_per_element = jnp.maximum(
                post_slope_x * lb,
                post_slope_x * ub,
            )
        else:
            # find max element-wise (separable problem): either at lower bound, 0 or
            # upper bound
            duals_pre = jnp.reshape(duals_pre, lb.shape)
            max_per_element_bounds = jnp.maximum(
                post_slope_x * jax.nn.relu(lb) - duals_pre * lb,
                post_slope_x * jax.nn.relu(ub) - duals_pre * ub)
            max_per_element = jnp.where(
                jnp.logical_and(lb <= 0, ub >= 0),
                jax.nn.relu(
                    max_per_element_bounds),  # include zero where feasible
                max_per_element_bounds)  # otherwise only at boundaries
        # sum over coordinates and add constant term (does not change max choice)
        max_value = jnp.sum(max_per_element,
                            axis=tuple(range(1, max_per_element.ndim)))
        constant_per_element = affine_constant * duals_post
        constant = jnp.sum(constant_per_element,
                           axis=tuple(range(1, constant_per_element.ndim)))
        return max_value + constant
Code example #20
File: continuous.py Project: dsheldon/numpyro
 def mean(self):
     # mean is inf for alpha <= 1
     a = lax.div(self.alpha * self.scale, (self.alpha - 1))
     return np.where(self.alpha <= 1, np.inf, a)
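A small numeric sketch of why the guard is needed: for alpha <= 1 the Pareto mean does not exist, and the raw formula would silently return a negative value (toy values below):

import jax.numpy as jnp

alpha = jnp.array([0.5, 1.5, 3.0])
scale = 1.0
raw_mean = alpha * scale / (alpha - 1)            # [-1. ,  3. ,  1.5]
mean = jnp.where(alpha <= 1, jnp.inf, raw_mean)   # [inf,  3. ,  1.5]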
Code example #21
File: _glm.py Project: berenslab/RFEst
 def relu(_x):
     return jnp.where(_x > 0., _x, 1e-7)
Code example #22
File: continuous.py Project: dsheldon/numpyro
 def variance(self):
     # var is inf for alpha <= 2
     a = lax.div((self.scale**2) * self.alpha,
                 (self.alpha - 1)**2 * (self.alpha - 2))
     return np.where(self.alpha <= 2, np.inf, a)
Code example #23
def generate_triplets(key,
                      inputs,
                      n_inliers,
                      n_outliers,
                      n_random,
                      weight_temp=0.5,
                      distance='euclidean',
                      verbose=False):
    """Generate triplets.

  Args:
    key: Random key.
    inputs: Input points.
    n_inliers: Number of inliers.
    n_outliers: Number of outliers.
    n_random: Number of random triplets per point.
    weight_temp: Temperature of the log transformation on the weights.
    distance: Distance type.
    verbose: Whether to print progress.

  Returns:
    triplets and weights
  """
    n_points = inputs.shape[0]
    n_extra = min(n_inliers + 50, n_points)
    index = pynndescent.NNDescent(inputs, metric=distance)
    index.prepare()
    neighbors = index.query(inputs, n_extra)[0]
    neighbors = np.concatenate(
        (np.arange(n_points).reshape([-1, 1]), neighbors), 1)
    if verbose:
        logging.info('found nearest neighbors')
    distance_fn = get_distance_fn(distance)
    # Compute scaled neighbors and the scale parameter.
    knn_distances, neighbors, sig = find_scaled_neighbors(
        inputs, neighbors, distance_fn)
    neighbors = neighbors[:, :n_inliers + 1]
    knn_distances = knn_distances[:, :n_inliers + 1]
    key, use_key = random.split(key)
    triplets = sample_knn_triplets(use_key, neighbors, n_inliers, n_outliers)
    weights = find_triplet_weights(inputs,
                                   triplets,
                                   neighbors[:, 1:n_inliers + 1],
                                   distance_fn,
                                   sig,
                                   distances=knn_distances[:, 1:n_inliers + 1])
    flip = weights < 0
    anchors, pairs = triplets[:, 0].reshape([-1, 1]), triplets[:, 1:]
    pairs = jnp.where(jnp.tile(flip.reshape([-1, 1]), [1, 2]),
                      jnp.fliplr(pairs), pairs)
    triplets = jnp.concatenate((anchors, pairs), 1)

    if n_random > 0:
        key, use_key = random.split(key)
        rand_triplets, rand_weights = sample_random_triplets(
            use_key, inputs, n_random, distance_fn, sig)

        triplets = jnp.concatenate((triplets, rand_triplets), 0)
        weights = jnp.concatenate((weights, 0.1 * rand_weights))

    weights -= jnp.min(weights)
    weights = tempered_log(1. + weights, weight_temp)
    return triplets, weights
Code example #24
File: nonlinearities.py Project: berenslab/RFEst
def relu(x):
    return jnp.where(x > 0., x, 0.)
Code example #25
    def sample_kernel(sa_state, model_args=(), model_kwargs=None):
        pe_fn = potential_fn
        if potential_fn_gen:
            pe_fn = potential_fn_gen(*model_args, **model_kwargs)
        zs, pes, loc, scale = sa_state.adapt_state
        # We recompute loc/scale after each iteration to avoid precision loss.
        # XXX: consider exposing a setting to do this job periodically
        # to save some computation.
        loc = jnp.mean(zs, 0)
        if scale.ndim == 2:
            cov = jnp.cov(zs, rowvar=False, bias=True)
            if cov.shape == ():  # JAX returns scalar for 1D input
                cov = cov.reshape((1, 1))
            cholesky = jnp.linalg.cholesky(cov)
            scale = jnp.where(jnp.any(jnp.isnan(cholesky)), scale, cholesky)
        else:
            scale = jnp.std(zs, 0)

        rng_key, rng_key_z, rng_key_reject, rng_key_accept = random.split(
            sa_state.rng_key, 4)
        _, unravel_fn = ravel_pytree(sa_state.z)

        z = loc + _sample_proposal(scale, rng_key_z)
        pe = pe_fn(unravel_fn(z))
        pe = jnp.where(jnp.isnan(pe), jnp.inf, pe)
        diverging = (pe - sa_state.potential_energy) > max_delta_energy

        # NB: all terms having the pattern *s will have shape N x ...
        # and all terms having the pattern *s_ will have shape (N + 1) x ...
        locs, scales = _get_proposal_loc_and_scale(zs, loc, scale, z)
        zs_ = jnp.concatenate([zs, z[None, :]])
        pes_ = jnp.concatenate([pes, pe[None]])
        locs_ = jnp.concatenate([locs, loc[None, :]])
        scales_ = jnp.concatenate([scales, scale[None, ...]])
        if scale.ndim == 2:  # dense_mass
            log_weights_ = dist.MultivariateNormal(
                locs_, scale_tril=scales_).log_prob(zs_) + pes_
        else:
            log_weights_ = dist.Normal(locs_,
                                       scales_).log_prob(zs_).sum(-1) + pes_
        log_weights_ = jnp.where(jnp.isnan(log_weights_), -jnp.inf,
                                 log_weights_)
        # get rejecting index
        j = random.categorical(rng_key_reject, log_weights_)
        zs = _numpy_delete(zs_, j)
        pes = _numpy_delete(pes_, j)
        loc = locs_[j]
        scale = scales_[j]
        adapt_state = SAAdaptState(zs, pes, loc, scale)

        # NB: weights[-1] / sum(weights) is the probability of rejecting the new sample `z`.
        accept_prob = 1 - jnp.exp(log_weights_[-1] - logsumexp(log_weights_))
        itr = sa_state.i + 1
        n = jnp.where(sa_state.i < wa_steps, itr, itr - wa_steps)
        mean_accept_prob = sa_state.mean_accept_prob + (
            accept_prob - sa_state.mean_accept_prob) / n

        # XXX: we make a modification of SA sampler in [1]
        # in [1], each MCMC state contains N points `zs`
        # here we do resampling to pick randomly a point from those N points
        k = random.categorical(rng_key_accept, jnp.zeros(zs.shape[0]))
        z = unravel_fn(zs[k])
        pe = pes[k]
        return SAState(itr, z, pe, accept_prob, mean_accept_prob, diverging,
                       adapt_state, rng_key)
Code example #26
File: nonlinearities.py Project: berenslab/RFEst
def leaky_relu(x):
    return jnp.where(x > 0., x, x * 0.01)
Code example #27
File: discrete.py Project: tbsexton/numpyro
 def log_prob(self, value):
     log_prob = jnp.log(self.rate) * value - gammaln(value + 1) + (jnp.log1p(-self.gate) - self.rate)
     return jnp.where(value == 0, jnp.logaddexp(jnp.log(self.gate), log_prob), log_prob)
Code example #28
File: nonlinearities.py Project: berenslab/RFEst
def selu(x):
    return 1.0507 * jnp.where(x > 0., x, 1.6733 * jnp.exp(x) - 1.6733)
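A hedged sanity check that the jnp.where-based nonlinearities in this file match their jax.nn counterparts (selu only approximately, since 1.0507 and 1.6733 are rounded versions of the exact SELU constants):

import jax
import jax.numpy as jnp

x = jnp.linspace(-3.0, 3.0, 7)
assert jnp.allclose(jnp.where(x > 0., x, 0.), jax.nn.relu(x))
assert jnp.allclose(jnp.where(x > 0., x, x * 0.01), jax.nn.leaky_relu(x))
assert jnp.allclose(jnp.where(x > 0., x, jnp.exp(x) - 1), jax.nn.elu(x))
assert jnp.allclose(1.0507 * jnp.where(x > 0., x, 1.6733 * jnp.exp(x) - 1.6733),
                    jax.nn.selu(x), atol=1e-3)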
Code example #29
File: metrics.py Project: alshedivat/fedjax
def apply_mask(mask: jnp.ndarray, a: jnp.ndarray,
               b: jnp.ndarray) -> jnp.ndarray:
    """Applies mask on the leading dimension."""
    rank = max(len(a.shape), len(b.shape))
    return jnp.where(jnp.expand_dims(mask, tuple(range(1, rank))), a, b)
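A short usage sketch of the broadcasting pattern above: a [batch] mask expanded over the trailing axes selects whole rows from `a` or `b` (toy values below):

import jax.numpy as jnp

mask = jnp.array([True, False, True])
a = jnp.arange(6.0).reshape(3, 2)
b = jnp.zeros((3, 2))
out = jnp.where(jnp.expand_dims(mask, tuple(range(1, a.ndim))), a, b)
# out -> [[0., 1.], [0., 0.], [4., 5.]]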
Code example #30
 def dist_sq(R):
     dR = R[:, np.newaxis, :] - R[np.newaxis, :, :]
     zero = np.zeros_like(dR)
     dR = dR - np.where(np.abs(dR) < 0.5, zero, 0.5 * np.sign(dR))
     return np.sum(dR**2, axis=2)