Example #1
 def loop_body(val):
   """From ids at `step`, update output ids at `step + 1`."""
   step = val.step
   decoder_state, logits = extend_step_fn(val.state, val.output_ids[:, step])
   logprobs = jax.nn.log_softmax(logits.astype(jnp.float32))
   val.state = decoder_state
   # When step becomes prefix_length - 1, the new output has index beyond
   # the known prefix.
   # If prefix_length is 0, the condition is always False, so we take the
   # decoded output rather than the prefix.
   new_ids = jnp.where(step < prefix_lengths - 1, target_ids[:, step + 1],
                       jnp.argmax(logits, axis=1))
   prev_done = val.done
   new_ids = jnp.where(prev_done, jnp.zeros_like(new_ids), new_ids)
   if eos_id is not None:
     val.done = jnp.logical_or(prev_done, jnp.equal(new_ids, eos_id))
   max_decoding_steps_reached = (jnp.ones_like(prefix_lengths) * (step + 2) -
                                 prefix_lengths) >= max_decode_steps
   val.done = jnp.logical_or(val.done, max_decoding_steps_reached)
   done_at_this_step = jnp.logical_and(jnp.logical_not(prev_done), val.done)
   val.decode_lengths = jnp.where(
       done_at_this_step,
       jnp.ones_like(val.decode_lengths) * (step + 2), val.decode_lengths)
   val.output_ids = val.output_ids.at[:, step + 1].set(new_ids)
   logprobs_at_new_ids = logprobs.at[jnp.arange(batch_size), new_ids].get()
   logprobs_at_new_ids = jnp.where(prev_done,
                                   jnp.ones_like(logprobs_at_new_ids),
                                   logprobs_at_new_ids)
   val.logprobs = val.logprobs.at[:, step + 1].set(logprobs_at_new_ids)
   val.step += 1
   return val
Example #2
def norm_logbicop_approx(u, v, rho):
    # pu/pv are arbitrary nonzero placeholders when u/v == 0.5 (ndtri(0.5) is 0,
    # which would divide by zero below); those cases are overridden at the end.
    pu = jnp.where(u == 0.5, 0.5, ndtri_(u))
    pv = jnp.where(v == 0.5, 0.5, ndtri_(v))
    alpha_u = jnp.where(u == 0.5, 0., (1 / jnp.sqrt(1 - rho**2)) * ((pv / pu) - rho))
    alpha_v = jnp.where(v == 0.5, 0., (1 / jnp.sqrt(1 - rho**2)) * ((pu / pv) - rho))
    rho_u = -alpha_u / jnp.sqrt(1 + alpha_u**2)
    rho_v = -alpha_v / jnp.sqrt(1 + alpha_v**2)

    C_uu = jnp.exp(norm_logbicop_diag_approx(jnp.log(u), 1 - 2 * rho_u**2))
    C_vv = jnp.exp(norm_logbicop_diag_approx(jnp.log(v), 1 - 2 * rho_v**2))

    ind_rho_u = jnp.where(rho_u < 0, 1., 0.)
    C_half_u = 0.5 * ind_rho_u * C_uu + (1 - ind_rho_u) * (u - 0.5 * C_uu)

    ind_rho_v = jnp.where(rho_v < 0, 1., 0.)
    C_half_v = 0.5 * ind_rho_v * C_vv + (1 - ind_rho_v) * (v - 0.5 * C_vv)

    delta_uv = jnp.where(
        jnp.logical_or(jnp.logical_and(u < 0.5, v >= 0.5),
                       jnp.logical_and(u >= 0.5, v < 0.5)), 0.5, 0.)
    C_uv = C_half_u + C_half_v - delta_uv

    # Handle the special cases u == 0.5 and/or v == 0.5 separately.
    ind_v_half = jnp.where(u == 0.5, 1., 0.)
    ind_u_half = jnp.where(v == 0.5, 1., 0.)
    ind_uv_half_or = jnp.logical_or(u == 0.5, v == 0.5)
    ind_uv_half_and = jnp.logical_and(u == 0.5, v == 0.5)
    C_uv = (1 - ind_uv_half_and) * ind_u_half * \
               (u - 0.5 * jnp.exp(norm_logbicop_diag_approx(jnp.log(u), 1 - 2 * rho**2))) + \
           (1 - ind_uv_half_and) * ind_v_half * \
               (v - 0.5 * jnp.exp(norm_logbicop_diag_approx(jnp.log(v), 1 - 2 * rho**2))) + \
           ind_uv_half_and * jnp.exp(norm_logbicop_diag_approx(jnp.log(u), rho)) + \
           (1 - ind_uv_half_or) * C_uv
    return jnp.log(C_uv)
Example #3
def has_converged(x: Tensor, grad: Tensor, l: Tensor, u: Tensor):
    stuck_at_lower = jnp.logical_and(x == l, grad >= 0)
    stuck_at_upper = jnp.logical_and(x == u, grad <= 0)
    zero_grad = grad == 0

    stuck_at_border = jnp.logical_or(stuck_at_lower, stuck_at_upper)
    converged = jnp.logical_or(stuck_at_border, zero_grad)
    return jnp.all(converged)
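A minimal usage sketch (array values invented for illustration): every coordinate is either pinned at a bound where the descent direction points out of the feasible box, or has zero gradient, so the check returns True.

x = jnp.array([0.0, 0.5, 1.0])      # current iterate
grad = jnp.array([1.0, 0.0, -2.0])  # gradient at x
l = jnp.array([0.0, 0.0, 0.0])      # lower bounds
u = jnp.array([1.0, 1.0, 1.0])      # upper bounds
has_converged(x, grad, l, u)        # -> True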
Example #4
 def cond_func(val):
     q, i, norm_delta_q, error = val
     diverged = np.logical_or(error > divergence_tol,
                              np.isnan(error))
     converged = np.logical_and(error < convergence_tol,
                                norm_delta_q < position_tol)
     return np.logical_not(
         np.logical_or((i >= max_iters),
                       np.logical_or(diverged, converged)))
Example #5
 def compute_intersection_point(denom):
     t1 = np.cross(v2, v1) / denom
     t2 = (v1 @ v3) / denom
     condition = np.logical_or(np.logical_or(t1 < 0.0, t2 < 0.0), t2 > 1.0)
     return jax.lax.cond(
         condition,
         true_fun=lambda t: np.array([np.inf, np.inf]),
         false_fun=lambda t: ray_origin + t1 * ray_direction,
         operand=t1,
     )
Example #6
    def adaptive_hmc_update(state,
                            log_prob,
                            state_grad,
                            key,
                            step_size,
                            trajectory_len,
                            target_accept_rate=0.8,
                            step_size_adaptation_speed=0.05,
                            max_n_leapfrog=1000,
                            jitter_amt=0.2):

        normal_key, uniform_key, jitter_key = jax.random.split(key, 3)

        n_leapfrog = jnp.array(jnp.ceil(trajectory_len / step_size), jnp.int32)
        n_leapfrog = jnp.minimum(n_leapfrog, max_n_leapfrog)
        jittered_step_size = step_size * jnp.exp(
            jnp.where(
                jnp.logical_or(step_size_adaptation_speed <= 0,
                               target_accept_rate <= 0),
                jnp.log(1. + jitter_amt) *
                (2 * jax.random.uniform(jitter_key, ()) - 1.), 0.))

        # jax.tree_multimap and the top-level tree aliases were removed from
        # JAX; the jax.tree_util equivalents below behave identically.
        num_leaves = len(jax.tree_util.tree_leaves(state))
        normal_keys = list(jax.random.split(normal_key, num_leaves))
        treedef = jax.tree_util.tree_structure(state)
        normal_keys = jax.tree_util.tree_unflatten(treedef, normal_keys)
        momentum = jax.tree_util.tree_map(
            lambda s, key: jax.random.normal(key, s.shape), state, normal_keys)

        initial_energy = get_kinetic_energy(momentum) - log_prob
        new_state, new_momentum, new_grad, new_log_prob = leapfrog(
            jittered_step_size, n_leapfrog, state, momentum, state_grad)
        new_energy = _nan_to_inf(
            get_kinetic_energy(new_momentum) - new_log_prob)

        energy_diff = initial_energy - new_energy
        accept_prob = jnp.minimum(1., jnp.exp(energy_diff))
        # TODO(izmailovpavel): check why the second condition is needed.
        accepted = jnp.logical_and(
            jax.random.uniform(uniform_key, log_prob.shape) < accept_prob,
            jnp.isfinite(energy_diff))

        step_size = step_size * jnp.exp(
            jnp.where(
                jnp.logical_or(target_accept_rate <= 0,
                               step_size_adaptation_speed <= 0), 0.,
                step_size_adaptation_speed *
                (jnp.mean(accept_prob) - target_accept_rate)))

        state = jax.lax.cond(accepted, _first, _second, (new_state, state))
        log_prob = jnp.where(accepted, new_log_prob, log_prob)
        state_grad = jax.lax.cond(accepted, _first, _second,
                                  (new_grad, state_grad))
        return state, log_prob, state_grad, step_size, accept_prob
Example #7
def cond_fun(maxiter, bound, feasStop, state):
    logging.info('compiling cond_fun')
    counter = state[1]
    dual_objective, primal_dual_gap, maxfeasible = state[7:10]

    cond1 = counter <= maxiter
    cond2 = dual_objective < bound
    cond3 = np.logical_or(
        np.absolute(primal_dual_gap) > 1e-6,
        np.logical_or(maxfeasible > feasStop,
                      np.logical_and(counter < 200, maxfeasible >= 0)))
    return np.logical_and(cond1, np.logical_and(cond2, cond3))
Example #8
    def check_convergence(
        self,
        state_prev: _NonlinearSolverState,
        cost_updated: hints.Scalar,
        local_delta_assignments: VariableAssignments,
        negative_gradient: hints.Array,
    ) -> bool:
        """Check for convergence!"""

        # Cost tolerance
        converged_cost = (
            jnp.abs(cost_updated - state_prev.cost) / state_prev.cost
            < self.cost_tolerance
        )

        # Gradient tolerance
        converged_gradient = jnp.where(
            state_prev.iterations >= self.gradient_tolerance_start_step,
            jnp.max(
                state_prev.assignments.storage
                - state_prev.assignments.manifold_retract(
                    VariableAssignments(
                        storage=negative_gradient,
                        storage_metadata=local_delta_assignments.storage_metadata,
                    ),
                ).storage
            )
            < self.gradient_tolerance,
            False,
        )

        # Parameter tolerance
        converged_parameters = (
            jnp.linalg.norm(jnp.abs(local_delta_assignments.storage))
            < (
                jnp.linalg.norm(state_prev.assignments.storage)
                + self.parameter_tolerance
            )
            * self.parameter_tolerance
        )

        return jnp.logical_or(
            converged_cost,
            jnp.logical_or(
                converged_gradient,
                converged_parameters,
            ),
        )
Example #9
def _build_sliding_window_mask(window_size, global_mask):
    """Builds mask for sliding window pattern.

  Args:
    window_size: int, size of sliding window.
    global_mask: boolean jax array of shape `[batch_size, seq_len]`.

  Returns:
    mask, boolean jax array of shape `[batch_size, 1 (n_heads), seq_len,
    seq_len]`.

  If `window_size` is odd, both left and right sides have the same receptive
  field. Otherwise, the left side gets one more. Note: we need the global mask
  because, due to the symmetry requirement, non-global positions can still
  attend to global positions.
  """
    seq_len = global_mask.shape[1]
    right_size = window_size // 2
    left_size = window_size - right_size
    left_mask = sum(np.eye(seq_len, k=-i) for i in range(left_size))
    right_mask = sum(np.eye(seq_len, k=i) for i in range(1, right_size + 1))
    mask = left_mask + right_mask
    mask = jnp.array(mask[np.newaxis, np.newaxis, :, :]).astype(jnp.bool_)
    return jnp.logical_or(mask, _build_global_mask(global_mask))
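A small shape sketch, assuming `_build_global_mask` from Example #29 below is in scope: with `window_size=3`, each position attends to itself and one neighbor on each side, and global positions attend everywhere.

global_mask = jnp.array([[True, False, False, False, False]])  # position 0 is global
mask = _build_sliding_window_mask(window_size=3, global_mask=global_mask)
mask.shape  # (1, 1, 5, 5): [batch, heads (broadcast), query, key]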
Example #10
  def cond_fn(iteration, const, state):
    threshold = const[-1]
    errors = state[0]
    err = errors[iteration // inner_iterations - 1, 0]

    return jnp.logical_or(iteration == 0,
                          jnp.logical_and(jnp.isfinite(err), err > threshold))
Example #11
    def __init__(self, n: jnp.ndarray, p: jnp.ndarray):
        """Initializes a multinomial distribution with n trials and probabilities p.

        n may be multidimensional, in which case it represents
        multiple multinomial distributions.

        p has to have the shape of n plus one extra trailing dimension
        representing the probabilities of each event. The probabilities in
        the last dimension have to sum to 1.

        Args:
            n: Number of trials. Has to be an integer and non-negative.
            p: Probabilities of trial successes. Must have the same shape as
                n, plus one additional trailing dimension holding the
                probabilities, which have to sum to 1.
        """
        super().__init__()

        if n.shape != p.shape[:len(n.shape)] or \
                len(n.shape) + 1 != len(p.shape):
            raise ValueError('Shapes of n and p not compatible')

        # we cannot raise a ValueError here since we get problems with
        # ConcretizationError during Metropolis-Hastings
        nans = jnp.full(p.shape, jnp.nan)
        self.p = jnp.where(jnp.logical_or(p < 0, p > 1), nans, p)
        self.p = jnp.where(jnp.isclose(jnp.sum(self.p, -1, keepdims=True), 1),
                           self.p, nans)

        nans = jnp.full(n.shape, jnp.nan)
        self.n = jnp.where(n <= 0, nans, n)
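A shape sketch for the contract in the docstring; the enclosing class name `Multinomial` is assumed here, since the snippet only shows its `__init__`.

n = jnp.array([10, 20])           # shape (2,): two distributions
p = jnp.array([[0.2, 0.3, 0.5],
               [0.1, 0.1, 0.8]])  # shape (2, 3) = n.shape + (num_events,)
dist = Multinomial(n, p)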
Example #12
def adapt_step_size(step_size, target_accept_rate, accept_prob,
                    step_size_adaptation_speed):
    log_factor = jnp.where(
        jnp.logical_or(target_accept_rate <= 0,
                       step_size_adaptation_speed <= 0), 0.,
        step_size_adaptation_speed * (accept_prob - target_accept_rate))
    return step_size * jnp.exp(log_factor)
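A quick numeric check of the rule: an acceptance probability above the target grows the step size, one below it shrinks it, and a non-positive target or adaptation speed disables the update.

adapt_step_size(0.1, target_accept_rate=0.8, accept_prob=0.9,
                step_size_adaptation_speed=0.05)
# -> 0.1 * exp(0.05 * (0.9 - 0.8)) ~= 0.1005
adapt_step_size(0.1, target_accept_rate=-1.0, accept_prob=0.9,
                step_size_adaptation_speed=0.05)
# -> 0.1 (adaptation disabled)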
Example #13
    def update(updates, state, params=None):
        inner_state = state.inner_state
        flat_updates = tree_flatten(updates)[0]
        isfinite = jnp.all(
            jnp.array([jnp.all(jnp.isfinite(p)) for p in flat_updates]))
        notfinite_count = jnp.where(isfinite, jnp.zeros([], jnp.int64),
                                    1 + state.notfinite_count)

        def do_update(_):
            return inner.update(updates, inner_state, params)

        def reject_update(_):
            return (tree_map(jnp.zeros_like, updates), inner_state)

        updates, new_inner_state = lax.cond(jnp.logical_or(
            isfinite, notfinite_count > max_consecutive_errors),
                                            do_update,
                                            reject_update,
                                            operand=None)

        return updates, ApplyIfFiniteState(
            notfinite_count=notfinite_count,
            last_finite=isfinite,
            total_notfinite=jnp.logical_not(isfinite) + state.total_notfinite,
            inner_state=new_inner_state)
Example #14
    def apply_fun(params, inputs, **kwargs):
        def inner_loop_body(i, inputs_and_counter):
            inputs, counter = inputs_and_counter
            val = inputs[j, i]
            condition = jnp.logical_or(val < -0.5, val > 0.5)
            # jax.ops.index_update was removed from JAX; .at[].set() is the
            # equivalent functional update.
            inputs, counter = jax.lax.cond(
                condition,
                lambda _: (inputs.at[j, i].set(counter * val), counter * -1),
                lambda _: (inputs, counter),
                None)
            return inputs, counter

        # Flips every second nonzero spin.
        for j in range(inputs.shape[0]):
            counter = +1
            for i in range(inputs.shape[1]):
                val = inputs[j, i]
                condition = jnp.logical_or(val < -0.5, val > 0.5)
                inputs, counter = jax.lax.cond(
                    condition,
                    lambda _: (inputs.at[j, i].set(counter * val), counter * -1),
                    lambda _: (inputs, counter),
                    None)
            # Alternative: jax.lax.fori_loop(0, inputs.shape[1],
            #                                inner_loop_body, (inputs, counter))
        return inputs
Example #15
 def sample_search_cond_fn(state):
     """state termination condition fn."""
     has_reached_max_length = state.cur_len == max_length
     all_sequence_finished = jnp.all(state.is_sent_finished)
     finish_generation = jnp.logical_or(has_reached_max_length,
                                        all_sequence_finished)
     return ~finish_generation
Example #16
    def apply_param_gradient(self, step, hyper_params, param, state, grad):
        del step
        assert hyper_params.learning_rate is not None, 'no learning rate provided.'

        param_norm = jnp.linalg.norm(param)
        grad_norm = jnp.linalg.norm(grad)
        trust_ratio = hyper_params.trust_coefficient * param_norm / (
            grad_norm + hyper_params.weight_decay * param_norm +
            hyper_params.eps)
        clipped_trust_ratio = jnp.where(
            jnp.logical_or(grad_norm == 0., param_norm == 0.), 1., trust_ratio)
        scaled_lr = hyper_params.learning_rate * clipped_trust_ratio
        if hyper_params.weight_decay != 0:
            grad += hyper_params.weight_decay * param

        scaled_grad = scaled_lr * grad
        momentum = state.momentum
        new_momentum = hyper_params.beta * momentum + scaled_grad
        if hyper_params.nesterov:
            d_p = scaled_grad + hyper_params.beta * new_momentum
        else:
            d_p = new_momentum
        new_param = param - d_p
        new_state = _LARSParamState(new_momentum)
        return new_param, new_state
Example #17
def _rational_quadratic_spline_fwd(x: Array, x_pos: Array, y_pos: Array,
                                   knot_slopes: Array) -> Tuple[Array, Array]:
    """Applies a rational-quadratic spline to a scalar.

  Args:
    x: a scalar (0-dimensional array). The scalar `x` can be any real number; it
      will be transformed by the spline if it's in the closed interval
      `[x_pos[0], x_pos[-1]]`, and it will be transformed linearly if it's
      outside that interval.
    x_pos: array of shape [num_bins + 1], the bin boundaries on the x axis.
    y_pos: array of shape [num_bins + 1], the bin boundaries on the y axis.
    knot_slopes: array of shape [num_bins + 1], the slopes at the knot points.
  Returns:
    A tuple of two scalars: the output of the transformation and the log of the
    absolute first derivative at `x`.
  """
    # Search to find the right bin. NOTE: The bins are sorted, so we could use
    # binary search, but this is more GPU/TPU friendly.
    i = jnp.sum(x_pos < x) - 1
    below_range = x <= x_pos[0]
    above_range = x >= x_pos[-1]
    outside_range = jnp.logical_or(below_range, above_range)
    # Avoid NaNs which might propagate to the gradient despite jnp.where
    # later by setting i to 0 when outside of range.
    i = jnp.where(outside_range, 0, i)
    bin_width = x_pos[i + 1] - x_pos[i]
    bin_height = y_pos[i + 1] - y_pos[i]
    bin_slope = bin_height / bin_width
    z = (x - x_pos[i]) / bin_width
    # `z` should be in [0, 1] to avoid NaNs later, but it can fall slightly
    # outside because of small floating point issues, or when x is outside the
    # range and `i` was set to 0. To avoid all problems, we restrict z to [0, 1].
    z = jnp.clip(z, 0., 1.)
    sq_z = z * z
    z1mz = z - sq_z  # z(1-z)
    sq_1mz = (1. - z)**2
    slopes_term = knot_slopes[i + 1] + knot_slopes[i] - 2. * bin_slope
    numerator = bin_height * (bin_slope * sq_z + knot_slopes[i] * z1mz)
    denominator = bin_slope + slopes_term * z1mz
    y = y_pos[i] + numerator / denominator
    # Compute log det Jacobian.
    # The logdet is a sum of 3 logs. It is easy to see that the inputs of the
    # first two logs are guaranteed to be positive because we ensured that z is in
    # [0, 1]. This is also true of the log(denominator) because:
    # denominator
    # == bin_slope + (knot_slopes[i+1] + knot_slopes[i] - 2 * bin_slope) * z*(1-z)
    # >= bin_slope - 2 * bin_slope * z * (1-z)
    # >= bin_slope - 2 * bin_slope * (1/4)
    # == bin_slope / 2
    logdet = 2. * jnp.log(bin_slope) + jnp.log(
        knot_slopes[i + 1] * sq_z + 2. * bin_slope * z1mz +
        knot_slopes[i] * sq_1mz) - 2. * jnp.log(denominator)
    # If x is outside the spline range, we default to a linear transformation.
    y = jnp.where(below_range, (x - x_pos[0]) * knot_slopes[0] + y_pos[0], y)
    y = jnp.where(above_range, (x - x_pos[-1]) * knot_slopes[-1] + y_pos[-1],
                  y)
    logdet = jnp.where(below_range, jnp.log(knot_slopes[0]), logdet)
    logdet = jnp.where(above_range, jnp.log(knot_slopes[-1]), logdet)
    return y, logdet
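A sanity-check sketch with invented knots: two equal bins on [0, 1] with unit slopes make the spline the identity map, so the log-determinant is zero.

x_pos = jnp.array([0.0, 0.5, 1.0])
y_pos = jnp.array([0.0, 0.5, 1.0])
knot_slopes = jnp.array([1.0, 1.0, 1.0])
_rational_quadratic_spline_fwd(jnp.array(0.25), x_pos, y_pos, knot_slopes)
# -> (0.25, 0.0)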
Example #18
 def take(tensor, idx, fill=0):
     # Non jit-friendly implementation
     # illegal = jnp.logical_or(idx > p,idx < 1)
     # return tensor[..., idx-1].at[..., illegal].set(fill)
     legalized_idx = jnp.clip(idx, 1, p)
     illegal_mask = jnp.logical_or(idx > p, idx < 1)
     return (tensor[..., legalized_idx - 1] * (1 - 1 * illegal_mask) +
             illegal_mask * fill)
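A usage sketch, assuming the closed-over upper bound is `p = 3`: indices are 1-based, and out-of-range entries are replaced by `fill`.

p = 3                                # closed-over upper bound assumed by `take`
tensor = jnp.array([10., 20., 30.])
idx = jnp.array([1, 3, 4])           # 4 is out of range
take(tensor, idx)                    # -> [10., 30., 0.]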
Example #19
File: randaugment.py Project: 4rtemi5/imax
def _randaugment_inner_for_loop(_, in_args):
    """
    Loop body for randaugment.

    Args:
        _: loop iteration (unused)
        in_args: loop body arguments

    Returns:
        updated loop arguments
    """
    (image, geometric_transforms, random_key, available_ops, op_probs,
     magnitude, cutout_const, translate_const, join_transforms,
     default_replace_value) = in_args
    random_keys = random.split(random_key, num=8)
    random_key = random_keys[0]  # keep for next iteration
    op_to_select = random.choice(random_keys[1], available_ops, p=op_probs)
    mask_value = jnp.where(default_replace_value > 0,
                           jnp.ones([image.shape[-1]]) * default_replace_value,
                           random.randint(random_keys[2],
                                          [image.shape[-1]],
                                          minval=-1, maxval=256))
    random_magnitude = random.uniform(random_keys[3], [], minval=0.,
                                      maxval=magnitude)
    cutout_mask = color_transforms.get_random_cutout_mask(
        random_keys[4],
        image.shape,
        cutout_const)

    translate_vals = (random.uniform(random_keys[5], [], minval=0.0,
                                     maxval=1.0) * translate_const,
                      random.uniform(random_keys[6], [], minval=0.0,
                                     maxval=1.0) * translate_const)
    negate = random.randint(random_keys[7], [], minval=0,
                            maxval=2).astype('bool')

    args = level_to_arg(cutout_mask, translate_vals, negate,
                        random_magnitude, mask_value)

    if DEBUG:
        print(op_to_select, args[op_to_select])

    image, geometric_transform = _apply_ops(image, args, op_to_select)

    image, geometric_transform = jax.lax.cond(
        jnp.logical_or(join_transforms, jnp.all(
            jnp.not_equal(geometric_transform, jnp.identity(4)))),
        lambda op: (op[0], op[1]),
        lambda op: (transforms.apply_transform(op[0],
                                               op[1],
                                               mask_value=mask_value),
                    jnp.identity(4)),
        (image, geometric_transform)
    )

    geometric_transforms = jnp.matmul(geometric_transforms, geometric_transform)
    return (image, geometric_transforms, random_key, available_ops, op_probs,
            magnitude, cutout_const, translate_const, join_transforms,
            default_replace_value)
Example #20
 def inner_loop_body(i, inputs_and_counter):
     inputs, counter = inputs_and_counter
     val = inputs[j, i]
     condition = jnp.logical_or(val < -0.5, val > 0.5)
     # jax.ops.index_update was removed from JAX; .at[].set() is the
     # equivalent functional update.
     inputs, counter = jax.lax.cond(
         condition,
         lambda _: (inputs.at[j, i].set(counter * val), counter * -1),
         lambda _: (inputs, counter),
         None)
     return inputs, counter
Example #21
def apply_cutout(key, image, cutoutwidth, cutoutheight):
  channels, width, height = image.shape
  # Sample the cutout corner: x is bounded by width, y by height.
  x0, y0 = jax.random.randint(
      key, (2,), minval=0,
      maxval=jnp.array([width + 1 - cutoutwidth, height + 1 - cutoutheight]))

  # Construct a mask that is 0 inside the cutout rectangle and 1 elsewhere.
  # indexing='ij' makes xx, yy have shape (width, height) to match the image.
  xx, yy = jnp.meshgrid(jnp.arange(width), jnp.arange(height), indexing='ij')
  xmask = jnp.where(jnp.logical_and(xx >= x0, xx < x0 + cutoutwidth), 0, 1)
  ymask = jnp.where(jnp.logical_and(yy >= y0, yy < y0 + cutoutheight), 0, 1)
  mask = jnp.logical_or(xmask, ymask)
  return image * mask
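A usage sketch with invented sizes: the result equals the input image except for a single zeroed 8x8 rectangle at a random location.

key = jax.random.PRNGKey(0)
image = jnp.ones((3, 32, 32))  # (channels, width, height)
out = apply_cutout(key, image, cutoutwidth=8, cutoutheight=8)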
Example #22
def i_stimulus(t, params_and_data):
    return lax.cond(
        np.logical_or(
            np.less(t, params_and_data["stimulus_start_time"]),
            np.greater(t, params_and_data["stimulus_end_time"]),
        ),
        lambda _: 0.0,
        lambda _: params_and_data["i_stimulus"],
        None,
    )
Example #23
def categorical_sample(key, probs):
    """Sample from a set of discrete probabilities."""
    probs = probs / probs.sum(axis=-1, keepdims=True)
    cpi = jnp.cumsum(probs, axis=-1)
    eps = jnp.finfo(probs.dtype).eps
    rnds = jax.random.uniform(key=key,
                              shape=probs.shape[:-1] + (1, ),
                              dtype=probs.dtype,
                              minval=eps)
    return jnp.argmin(jnp.logical_or(rnds > cpi, probs < eps), axis=-1)
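A usage sketch: `rnds > cpi` is True exactly for the cumulative bins below the uniform draw, so the argmin picks the first bin the draw falls into, i.e. inverse-CDF sampling.

key = jax.random.PRNGKey(0)
probs = jnp.array([0.1, 0.6, 0.3])
categorical_sample(key, probs)  # -> an index in {0, 1, 2}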
Example #24
    def _step(
        self,
        graph: "StackedFactorGraph",
        state_prev: NonlinearSolverState,
    ) -> NonlinearSolverState:
        """Linearize, solve linear subproblem, and update on manifold."""

        self._hcb_print(
            lambda i, max_i, cost: f"Iteration #{i}/{max_i}: cost={str(cost)}",
            i=state_prev.iterations,
            max_i=self.max_iterations,
            cost=state_prev.cost,
        )

        # Linearize graph
        A: sparse.SparseCooMatrix = graph.compute_whitened_residual_jacobian(
            assignments=state_prev.assignments,
            residual_vector=state_prev.residual_vector,
        )
        ATb = -(A.T @ state_prev.residual_vector)

        # Solve linear subproblem
        local_delta_assignments = VariableAssignments(
            storage=self.linear_solver.solve_subproblem(
                A=A,
                ATb=ATb,
                lambd=0.0,
                iteration=state_prev.iterations,
            ),
            storage_layout=graph.local_storage_layout,
        )

        # On-manifold retraction
        assignments = state_prev.assignments.manifold_retract(
            local_delta_assignments=local_delta_assignments)

        # Check for convergence
        cost, residual_vector = graph.compute_cost(assignments)
        done = jnp.logical_or(
            self.check_exceeded_max_iterations(state_prev=state_prev),
            self.check_convergence(
                state_prev=state_prev,
                cost_updated=cost,
                local_delta_assignments=local_delta_assignments,
                negative_gradient=ATb,
            ),
        )

        return NonlinearSolverState(
            iterations=state_prev.iterations + 1,
            assignments=assignments,
            cost=cost,
            residual_vector=residual_vector,
            done=done,
        )
Example #25
    def _scale_update(update, param):
      param_norm = jnp.linalg.norm(param)
      update_norm = jnp.linalg.norm(update)
      trust_ratio = param_norm / update_norm

      # Set trust_ratio to 1 in case where parameters would never be updated.
      zero_norm = jnp.logical_or(param_norm == 0., update_norm == 0.)
      safe_trust_ratio = jnp.where(
          zero_norm, jnp.array(1.0, dtype=param.dtype), trust_ratio)

      return update * safe_trust_ratio
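A numeric sketch (values invented): the update is rescaled by ||param|| / ||update||, and a zero norm on either side falls back to a ratio of 1.

param = jnp.array([3.0, 4.0])   # norm 5
update = jnp.array([0.0, 1.0])  # norm 1
_scale_update(update, param)    # -> [0., 5.]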
Example #26
 def body_fun(carry):
     key, samples, _ = carry
     key, use_key = random.split(key)
     new_samples = random.randint(use_key,
                                  shape=shape,
                                  minval=0,
                                  maxval=maxval)
     discard = jnp.logical_or(in1dvec(new_samples, samples),
                              in1dvec(new_samples, rejects))
     samples = jnp.where(discard, samples, new_samples)
     return key, samples, in1dvec(samples, rejects)
Example #27
def create_mask(x, x_var, indices, val=0.0):
    mask = onp.ones(x.shape)
    # mask[x == 0] = val
    ind = np.where(np.logical_or(x_var <= 0, x_var == mask_std_val**2.0))
    mask[ind] = val
    fullindices = onp.asarray(indices)[:, None] + onp.arange(
        x.shape[1])[None, :]
    mask[fullindices < 0] = val
    # offs = - np.maximum(indices, np.zeros_like(indices))
    # for io, off in enumerate(offs):
    #    mask[io, 0:off] = 0
    return np.asarray(mask)
Example #28
File: train.py Project: us/flax
 def encode_step_fn(carry, x):
   lstm_state, is_eos = carry
   new_lstm_state, y = lstm_cell(lstm_state, x)
   # Pass forward the previous state if EOS has already been reached.
   def select_carried_state(new_state, old_state):
     return jnp.where(is_eos[:, np.newaxis], old_state, new_state)
   # LSTM state is a tuple (c, h).
   carried_lstm_state = tuple(
       select_carried_state(*s) for s in zip(new_lstm_state, lstm_state))
   # Update `is_eos`.
   is_eos = jnp.logical_or(is_eos, x[:, eos_id])
   return (carried_lstm_state, is_eos), y
Example #29
def _build_global_mask(mask):
    """Builds mask for global attention pattern.

  Args:
    mask: boolean jax array of shape `[batch_size, seq_len]`.

  Returns:
    mask, boolean jax array of shape `[batch_size, 1 (n_heads), seq_len,
    seq_len]`.
  """
    return jnp.logical_or(mask[:, jnp.newaxis, :, jnp.newaxis],
                          mask[:, jnp.newaxis, jnp.newaxis, :])
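A shape sketch: broadcasting the query-side mask against the key-side mask marks entry (q, k) as attendable when either the query q or the key k is global.

mask = jnp.array([[True, False, False]])
_build_global_mask(mask)  # shape (1, 1, 3, 3); row 0 and column 0 are all True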
Example #30
def _rational_quadratic_spline_inv(y: Array, x_pos: Array, y_pos: Array,
                                   knot_slopes: Array) -> Tuple[Array, Array]:
    """Applies the inverse of a rational-quadratic spline to a scalar.

  Args:
    y: a scalar (0-dimensional array). The scalar `y` can be any real number; it
      will be transformed by the spline if it's in the closed interval
      `[y_pos[0], y_pos[-1]]`, and it will be transformed linearly if it's
      outside that interval.
    x_pos: array of shape [num_bins + 1], the bin boundaries on the x axis.
    y_pos: array of shape [num_bins + 1], the bin boundaries on the y axis.
    knot_slopes: array of shape [num_bins + 1], the slopes at the knot points.
  Returns:
    A tuple of two scalars: the output of the inverse transformation and the log
    of the absolute first derivative of the inverse at `y`.
  """
    # Search to find the right bin. NOTE: The bins are sorted, so we could use
    # binary search, but this is more GPU/TPU friendly.
    i = jnp.sum(y_pos < y) - 1
    below_range = y <= y_pos[0]
    above_range = y >= y_pos[-1]
    outside_range = jnp.logical_or(below_range, above_range)
    # Set i = 0 when y is out of range, to avoid NaNs later.
    i = jnp.where(outside_range, 0, i)
    bin_width = x_pos[i + 1] - x_pos[i]
    bin_height = y_pos[i + 1] - y_pos[i]
    bin_slope = bin_height / bin_width
    w = (y - y_pos[i]) / bin_height
    w = jnp.clip(w, 0., 1.)  # Ensure w is in [0, 1].
    # Compute quadratic coefficients: az^2 + bz + c = 0
    slopes_term = knot_slopes[i + 1] + knot_slopes[i] - 2. * bin_slope
    c = -bin_slope * w
    b = knot_slopes[i] - slopes_term * w
    a = bin_slope - b
    # Solve quadratic to obtain z and then x.
    z = -2. * c / (b + jnp.sqrt(b**2 - 4. * a * c))
    z = jnp.clip(z, 0., 1.)  # Ensure z is in [0, 1].
    x = bin_width * z + x_pos[i]
    # Compute log det Jacobian.
    sq_z = z * z
    z1mz = z - sq_z  # z(1-z)
    sq_1mz = (1. - z)**2
    denominator = bin_slope + slopes_term * z1mz
    logdet = -2. * jnp.log(bin_slope) - jnp.log(
        knot_slopes[i + 1] * sq_z + 2. * bin_slope * z1mz +
        knot_slopes[i] * sq_1mz) + 2. * jnp.log(denominator)
    # If y is outside the spline range, we default to a linear transformation.
    x = jnp.where(below_range, (y - y_pos[0]) / knot_slopes[0] + x_pos[0], x)
    x = jnp.where(above_range, (y - y_pos[-1]) / knot_slopes[-1] + x_pos[-1],
                  x)
    logdet = jnp.where(below_range, -jnp.log(knot_slopes[0]), logdet)
    logdet = jnp.where(above_range, -jnp.log(knot_slopes[-1]), logdet)
    return x, logdet
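A round-trip sketch using the forward spline from Example #17 (knot values invented): the inverse recovers the input and the two log-determinants cancel, up to float error.

x_pos = jnp.array([0.0, 0.5, 1.0])
y_pos = jnp.array([0.0, 0.4, 1.0])
knot_slopes = jnp.array([0.8, 1.0, 1.2])
y, logdet_fwd = _rational_quadratic_spline_fwd(jnp.array(0.3), x_pos, y_pos,
                                               knot_slopes)
x, logdet_inv = _rational_quadratic_spline_inv(y, x_pos, y_pos, knot_slopes)
# x ~= 0.3 and logdet_fwd + logdet_inv ~= 0.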