def relative_per_iteration_progress_test(step, rs_norm, obj_val, obj_arr, tol):
    """Relative per-iteration progress test proposed by Martens (2010).

    CG is terminated once all three of the following hold, with
    ``k = max(10, step // 10)``:
      * ``step > k``,
      * ``f_value(step) < 0``, and
      * ``(f_value(step) - f_value(step - k)) / f_value(step) < k * eps``.

    Because ``f_value(step)`` is negative, the last inequality is evaluated in
    its multiplied-out form (the comparison direction flips), which avoids a
    division.

    For more information, see Section 4.4 of
    https://www.cs.toronto.edu/~jmartens/docs/Deep_HessianFree.pdf.

    Args:
      step: An integer value of the iteration step counter.
      rs_norm: A residual norm (ignored by this test).
      obj_val: A current objective value.
      obj_arr: A jax.numpy array of objective values in recent steps. It is
        indexed modulo its length (presumably filled as a circular buffer by
        the caller).
      tol: The convergence tolerance.

    Returns:
      A bool value indicating if the test is satisfied.
    """
    del rs_norm
    window = jnp.maximum(step // 10, 10)
    buffer_len = len(obj_arr)
    # Objective value recorded `window` steps ago.
    previous_val = obj_arr[(step + buffer_len - window) % buffer_len]

    enough_steps = jnp.greater(step, window)
    is_negative = jnp.less(obj_val, 0.)
    small_progress = jnp.greater(obj_val - previous_val,
                                 window * obj_val * tol)
    return enough_steps & is_negative & small_progress
def calculate_best_position(objective_values, best_particle_cost,
                            particles_position, best_particle_position,
                            particles, dimensions):
    """Select, per particle, the better of its current and best-known position.

    Args:
      objective_values: objective value of each particle's current position
        (``particles`` entries).
      best_particle_cost: best objective value seen so far by each particle.
      particles_position: current positions, broadcastable to
        ``(particles, dimensions)``.
      best_particle_position: best-known positions, broadcastable to
        ``(particles, dimensions)``.
      particles: number of particles (a static Python int).
      dimensions: number of coordinates per particle (a static Python int).

    Returns:
      Array of shape ``(particles, dimensions)`` whose row ``p`` is the current
      position when its objective value is strictly smaller than the best cost,
      and the previous best position otherwise.
    """
    improved = npj.less(objective_values, best_particle_cost)
    # Shapes are passed as static tuples rather than `npj.array(...)`: inside
    # `jax.jit` an array-valued shape is a tracer and reshape would fail.
    improved = npj.reshape(improved, (particles, 1))
    improved = npj.broadcast_to(improved, (particles, dimensions))
    return npj.where(improved, particles_position, best_particle_position)
def loop_body(inputs):
    """Simulate one batch and write it into the preallocated result buffers.

    Args:
      inputs: loop carry ``(rng, parameters, summaries, distances, n_accepted,
        iteration)``.

    Returns:
      The updated carry with the same structure; ``n_accepted`` is recomputed
      from all stored distances and ``iteration`` is incremented by one.
    """
    rng, parameters, summaries, distances, n_accepted, iteration = inputs
    rng, key = jax.random.split(rng)
    parameter_samples = self.prior.sample(n_simulations, seed=key)
    rng, key = jax.random.split(rng)
    summary_samples = self.compressor(self.simulator(key, parameter_samples))
    distance_samples = jax.vmap(
        lambda target, F: self.distance_measure(summary_samples, target, F))(
            self.target_summaries, self.F)
    # Slots of the flat buffers that belong to this iteration's batch.
    indices = jax.lax.dynamic_slice(
        np.arange(n_simulations * max_iterations),
        [n_simulations * iteration],
        [n_simulations])
    # `jax.ops.index_update` was removed from JAX; `.at[...].set(...)` is its
    # functional replacement with identical semantics.
    parameters = parameters.at[indices].set(parameter_samples)
    summaries = summaries.at[indices].set(summary_samples)
    distances = distances.at[:, indices].set(distance_samples)
    # Count, per row of `distances`, how many stored entries are below ϵ.
    n_accepted = np.int32(np.less(distances, ϵ).sum(1))
    return (rng, parameters, summaries, distances, n_accepted,
            iteration + np.int32(1))
def single_acceptance_condition(args):
    """Decide whether to keep trying to get a proposal accepted.

    Parameters
    ----------
    args : tuple
        see loop variable in `single_iteration`; the second-to-last entry is
        the number of accepted proposals and the last entry the number of
        attempts made so far

    Returns
    -------
    bool:
        True while no proposal has been accepted and the number of attempts
        is still below ``max_acceptance``
    """
    *_, n_accepted, n_attempts = args
    nothing_accepted = np.less(n_accepted, 1)
    attempts_remaining = np.less(n_attempts, max_acceptance)
    return np.logical_and(nothing_accepted, attempts_remaining)
def bi_tempered_logistic_loss_fwd(activations,
                                  labels,
                                  t1,
                                  t2,
                                  label_smoothing=0.0,
                                  num_iters=5):
    """Forward pass function for bi-tempered logistic loss.

    Args:
      activations: A multi-dimensional array with last dimension `num_classes`.
      labels: An array with shape and dtype as activations.
      t1: Temperature 1 (< 1.0 for boundedness).
      t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
      label_smoothing: Label smoothing parameter between [0, 1).
      num_iters: Number of iterations to run the method.

    Returns:
      A loss array (summed over the last axis), and the residuals
      ``(labels, t1, t2, probabilities)`` where ``labels`` are the labels after
      optional smoothing.
    """
    num_classes = jnp.int32(labels.shape[-1])

    def _smooth_labels(one_hot):
        # For one-hot input the true class becomes 1 - label_smoothing and
        # every other class label_smoothing / (num_classes - 1).
        on_scale = 1 - num_classes / (num_classes - 1) * label_smoothing
        return on_scale * one_hot + label_smoothing / (num_classes - 1)

    def _identity(one_hot):
        return one_hot

    labels = cond(label_smoothing > 0.0, _smooth_labels, _identity, labels)
    probabilities = tempered_softmax(activations, t2, num_iters)

    def _tempered_cross_entropy_loss(unused_activations):
        # The 1e-10 offset presumably keeps log_t finite for zero labels.
        cross_term = jnp.multiply(
            labels, log_t(labels + 1e-10, t1) - log_t(probabilities, t1))
        power_term = 1.0 / (2.0 - t1) * (
            jnp.power(labels, 2.0 - t1) - jnp.power(probabilities, 2.0 - t1))
        return jnp.sum(cross_term - power_term, -1)

    def _is_one(temperature):
        return jnp.less(jnp.abs(temperature - 1.0), 1e-15)

    # When both temperatures are (numerically) 1, `_cross_entropy_loss` is used
    # instead of the tempered formula.
    use_plain_ce = jnp.logical_and(_is_one(t1), _is_one(t2))
    loss_values = cond(
        use_plain_ce,
        functools.partial(_cross_entropy_loss, labels=labels),
        _tempered_cross_entropy_loss,
        activations)
    return loss_values, (labels, t1, t2, probabilities)
def termination_condition(state):
    """Return True while the CG loop should keep iterating.

    The loop continues while ``step < max_iter`` and ``termination_criterion_fn``
    (evaluated at ``step - 1``) reports that it is not yet satisfied.
    """
    *_, step, rs_norm, obj_val, obj_arr = state
    below_max_iter = jnp.less(step, max_iter)
    criterion_met = termination_criterion_fn(
        rs_norm=rs_norm,
        tol=tol,
        step=step - 1,
        obj_val=obj_val,
        obj_arr=obj_arr)
    return jnp.logical_and(below_max_iter, jnp.equal(criterion_met, False))
def keep_step(grad_norm):
    """Return whether a training step should be applied.

    The step is kept when every entry of ``grad_norm`` is finite and, if
    ``p.skip_step_gradient_norm_value`` is truthy, every entry is also strictly
    below that threshold.
    """
    threshold = p.skip_step_gradient_norm_value
    all_finite = jnp.all(jnp.isfinite(grad_norm))
    if not threshold:
        # A falsy threshold (e.g. None or 0) disables the norm check.
        return all_finite
    return jnp.logical_and(all_finite, jnp.all(jnp.less(grad_norm, threshold)))
def insert(m, r, i):
    """Return a copy of ``m`` with row ``r`` inserted at position ``i``.

    ``r`` is appended after the rows of ``m``, and a gather index is built so
    that output row ``j`` reads ``m[j]`` for ``j < i``, ``r`` for ``j == i``
    and ``m[j - 1]`` for ``j > i``. Only array arithmetic is used (no Python
    control flow on ``i``).
    """
    n = m.shape[0]
    stacked = np.concatenate([m, r[np.newaxis, :]], axis=0)
    positions = np.arange(n + 1)
    gather = np.where(positions < i, positions,
                      np.where(positions > i, positions - 1, n))
    return stacked[gather]
def i_stimulus(t, params_and_data):
    """Stimulus current at time ``t``.

    Returns ``0.0`` when ``t`` is before ``stimulus_start_time`` or after
    ``stimulus_end_time``, and ``params_and_data["i_stimulus"]`` otherwise.
    """
    start = params_and_data["stimulus_start_time"]
    end = params_and_data["stimulus_end_time"]
    before_start = np.less(t, start)
    after_end = np.greater(t, end)
    outside_window = np.logical_or(before_start, after_end)
    return lax.cond(
        outside_window,
        lambda _: 0.0,
        lambda _: params_and_data["i_stimulus"],
        None,
    )
def _sample_n(self, key: PRNGKey, n: int) -> Array:
    """See `Distribution._sample_n`.

    Draws ``n`` samples by comparing uniform noise in [0, 1) against
    ``self.probs`` and casting the boolean outcome to ``self._dtype``.
    """
    probs = self.probs
    noise = jax.random.uniform(
        key=key,
        shape=(n,) + probs.shape,
        dtype=probs.dtype,
        minval=0.,
        maxval=1.)
    return (noise < probs).astype(self._dtype)
def get_accepted(distances, ϵ):
    """ Returns a boolean array marking which distances are below `ϵ`

    Parameters
    ----------
    distances : float(any)
        The distances between the summary of a target and the summaries of the
        run simulations
    ϵ : float
        The acceptance distance between summaries from simulations and the
        summary of the target data

    Returns
    -------
    bool(any):
        elementwise ``distances < ϵ``
    """
    within_threshold = np.less(distances, ϵ)
    return within_threshold
def residual_norm_test(step, rs_norm, obj_val, obj_arr, tol):
    """Residual norm test: terminates CG once sqrt(rs_norm) < tol.

    Args:
      step: An integer value of the iteration step counter (ignored).
      rs_norm: A residual norm.
      obj_val: A current objective value (ignored).
      obj_arr: A jax.numpy array of objective values in recent steps (ignored).
      tol: The convergence tolerance.

    Returns:
      A bool value indicating if the test is satisfied.
    """
    del step, obj_val, obj_arr  # only the residual norm matters here
    residual = jnp.sqrt(rs_norm)
    return residual < tol
def single_iteration_condition(args):
    """Checks whether another iteration should be run

    Parameters
    ----------
    args : tuple
        loop variables (described in `single_iteration`); the third-to-last
        entry is compared against ``acceptance_ratio`` and the second-to-last
        against ``max_iteration`` (presumably the current acceptance ratio and
        the iteration counter -- confirm against `single_iteration`)

    Returns
    -------
    bool:
        True while ``args[-3]`` is still greater than ``acceptance_ratio`` and
        ``args[-2]`` is below ``max_iteration``; False as soon as either
        stopping criterion is reached
    """
    return np.logical_and(np.greater(args[-3], acceptance_ratio),
                          np.less(args[-2], max_iteration))
def less(x1, x2):
    """Element-wise ``x1 < x2`` accepting raw arrays or `JaxArray` wrappers.

    `JaxArray` inputs are unwrapped through their ``.value`` attribute; the
    result is always wrapped in a new `JaxArray`.
    """
    lhs = x1.value if isinstance(x1, JaxArray) else x1
    rhs = x2.value if isinstance(x2, JaxArray) else x2
    return JaxArray(jnp.less(lhs, rhs))
def top2_gating_on_logits(paddings,
                          logits,
                          experts_dim,
                          expert_capacity_dim,
                          fprop_dtype,
                          prng_key,
                          second_expert_policy='all',
                          second_expert_threshold=0.0,
                          legacy_mtf_behavior=True,
                          capacity_factor=None,
                          importance=None,
                          mask_dtype=jnp.int32):
    """Computes Top-2 gating for Mixture-of-Experts.

    This function takes gating logits, potentially sharded across tpu cores as
    inputs. We rely on sharding propagation to work universally with 1D and 2D
    sharding cases. Dispatch and combine tensors should be explicitly annotated
    with jax.with_sharding_constraint by the caller.

    We perform dispatch/combine via einsum.

    Dimensions:

      G: group dim
      S: group size dim
      E: number of experts
      C: capacity per expert
      M: model_dim (same as input_dim and output_dim as in FF layer)
      B: original batch dim
      L: original seq len dim

    Note that for local_dispatch, the original batch BLM is reshaped to GSM, each
    group `g = 0..G-1` is being dispatched independently.

    Args:
      paddings: GS tensor, or None.
      logits: GSE tensor.
      experts_dim: number of experts
      expert_capacity_dim: number of examples per minibatch/group per expert.
        Each example is typically a vector of size input_dim, representing
        embedded token or an element of Transformer layer output. May be None
        or 0 when `capacity_factor` is given; the capacity is then derived from
        `capacity_factor`.
      fprop_dtype: activation dtype
      prng_key: jax.random.PRNGKey used for randomness.
      second_expert_policy: 'all', 'sampling' or 'random'
        - 'all': we greedily pick the 2nd expert
        - 'sampling': we sample the 2nd expert from the softmax
        - 'random': we optionally randomize dispatch to second-best expert with
          probability proportional to (weight / second_expert_threshold).
      second_expert_threshold: threshold for probability normalization when
        second_expert_policy == 'random'
      legacy_mtf_behavior: bool, True if to match legacy mtf behavior exactly.
      capacity_factor: if set and expert_capacity_dim is smaller than
        int(group_size * capacity_factor / experts_dim) (or is None), the
        capacity is replaced by that value rounded up to a multiple of 4.
      importance: input importance weights for routing (GS tensor or None)
      mask_dtype: using bfloat16 for fprop_dtype could be problematic for mask
        tensors, mask_dtype overrides dtype for such tensors. If None,
        fprop_dtype is used (and must not be bfloat16).

    Returns:
      A tuple (aux_loss, combine_tensor, dispatch_tensor, over_capacity ratios).

      - aux_loss: auxiliary loss, for equalizing the expert assignment ratios.
      - combine_tensor: a GSEC tensor for combining expert outputs.
      - dispatch_tensor: a GSEC tensor, scattering/dispatching inputs to
        experts.
      - over_capacity ratios: tuple that represents the ratio of tokens that
        were not dispatched due to lack of capacity for top_1 and top_2 expert
        respectively, e.g. (over_capacity_1, over_capacity_2)
    """
    assert (capacity_factor or expert_capacity_dim)
    if mask_dtype is None:
        assert fprop_dtype != jnp.bfloat16, 'Using bfloat16 for mask is an error.'
        mask_dtype = fprop_dtype

    raw_gates = jax.nn.softmax(logits, axis=-1)  # along E dim
    if raw_gates.dtype != fprop_dtype:
        raw_gates = raw_gates.astype(fprop_dtype)

    if capacity_factor is not None:
        # Determine expert capacity automatically depending on the input size
        group_size_dim = logits.shape[1]
        auto_expert_capacity = int(group_size_dim * capacity_factor / experts_dim)
        # The assert above allows expert_capacity_dim to be None when
        # capacity_factor is set; comparing None with an int would raise.
        if (expert_capacity_dim is None or
                expert_capacity_dim < auto_expert_capacity):
            expert_capacity_dim = auto_expert_capacity
            # Round up to a multiple of 4 to avoid possible padding.
            while expert_capacity_dim % 4:
                expert_capacity_dim += 1
            logging.info(
                'Setting expert_capacity_dim=%r (capacity_factor=%r '
                'group_size_dim=%r experts_dim=%r)', expert_capacity_dim,
                capacity_factor, group_size_dim, experts_dim)

    capacity = jnp.array(expert_capacity_dim, dtype=jnp.int32)

    # top-1 index: GS tensor
    index_1 = jnp.argmax(raw_gates, axis=-1)

    # GSE
    mask_1 = jax.nn.one_hot(index_1, experts_dim, dtype=mask_dtype)
    density_1_proxy = raw_gates

    if importance is not None:
        # Only examples with importance exactly 1.0 take part in top-1 routing.
        importance_is_one = jnp.equal(importance, 1.0)
        mask_1 *= jnp.expand_dims(importance_is_one.astype(mask_1.dtype), -1)
        density_1_proxy *= jnp.expand_dims(
            importance_is_one.astype(density_1_proxy.dtype), -1)
    else:
        assert len(mask_1.shape) == 3
        importance = jnp.ones_like(mask_1[:, :, 0]).astype(fprop_dtype)
        if paddings is not None:
            nonpaddings = 1.0 - paddings
            mask_1 *= jnp.expand_dims(nonpaddings.astype(mask_1.dtype), -1)
            density_1_proxy *= jnp.expand_dims(
                nonpaddings.astype(density_1_proxy.dtype), -1)
            importance = nonpaddings

    gate_1 = jnp.einsum('GSE,GSE->GS', raw_gates, mask_1.astype(raw_gates.dtype))
    gates_without_top_1 = raw_gates * (1.0 - mask_1.astype(raw_gates.dtype))

    if second_expert_policy == 'sampling':
        # We directly sample the 2nd expert index from the softmax over the
        # remaining experts by getting rid of the 1st expert already selected
        # above. To do so, we set a very negative value to the logit
        # corresponding to the 1st expert. Then we sample from the softmax
        # distribution using the Gumbel max trick.
        prng_key, subkey = jax.random.split(prng_key)
        noise = jax.random.uniform(subkey, logits.shape, dtype=logits.dtype)
        # Generates standard Gumbel(0, 1) noise, GSE tensor.
        noise = -jnp.log(-jnp.log(noise))
        very_negative_logits = jnp.ones_like(logits) * (-0.7) * np.finfo(
            logits.dtype).max
        # Get rid of the first expert by setting its logit to be very negative.
        updated_logits = jnp.where(mask_1 > 0.0, very_negative_logits, logits)
        # Add Gumbel noise to the updated logits.
        noised_logits = updated_logits + noise
        # Pick the index of the largest noised logits as the 2nd expert. This is
        # equivalent to sampling from the softmax over the 2nd expert.
        index_2 = jnp.argmax(noised_logits, axis=-1)
    else:
        # Greedily pick the 2nd expert.
        index_2 = jnp.argmax(gates_without_top_1, axis=-1)

    mask_2 = jax.nn.one_hot(index_2, experts_dim, dtype=mask_dtype)
    if paddings is not None:
        importance_is_nonzero = importance > 0.0
        mask_2 *= jnp.expand_dims(importance_is_nonzero.astype(mask_2.dtype), -1)
    gate_2 = jnp.einsum('GSE,GSE->GS', gates_without_top_1,
                        mask_2.astype(gates_without_top_1.dtype))

    # See notes in lingvo/core/gshard_layers.py.
    if legacy_mtf_behavior:
        # Renormalize.
        denom = gate_1 + gate_2 + 1e-9
        gate_1 /= denom
        gate_2 /= denom

    # We reshape the mask as [X*S, E], and compute cumulative sums of assignment
    # indicators for each expert index e \in 0..E-1 independently.
    # First occurrence of assignment indicator is excluded, see exclusive=True
    # flag below.
    # cumsum over S dim: mask_1 is GSE tensor.
    position_in_expert_1 = cum_sum(mask_1, exclusive=True, axis=-2)

    # GE tensor (reduce S out of GSE tensor mask_1).
    # density_1[:, e] represents assignment ratio (num assigned / total) to
    # expert e as top_1 expert without taking capacity into account.
    assert importance.dtype == fprop_dtype
    if legacy_mtf_behavior:
        density_denom = 1.0
    else:
        density_denom = jnp.mean(importance, axis=1)[:, jnp.newaxis] + 1e-6
    density_1 = jnp.mean(mask_1.astype(fprop_dtype), axis=-2) / density_denom
    # density_1_proxy[:, e] represents mean of raw_gates for expert e, including
    # those of examples not assigned to e with top_k
    density_1_proxy = jnp.mean(density_1_proxy, axis=-2) / density_denom

    # Compute aux_loss
    aux_loss = jnp.mean(density_1_proxy * density_1)  # element-wise
    aux_loss *= (experts_dim * experts_dim)  # const coefficients

    # Add the over capacity ratio for expert 1
    over_capacity_1 = _create_over_capacity_ratio_summary(
        mask_1, position_in_expert_1, capacity, 'over_capacity_1')

    # Drop top-1 assignments whose slot index does not fit in the capacity.
    mask_1 *= jnp.less(position_in_expert_1,
                       expert_capacity_dim).astype(mask_1.dtype)
    position_in_expert_1 = jnp.einsum('GSE,GSE->GS', position_in_expert_1, mask_1)

    # How many examples in this sequence go to this expert?
    mask_1_count = jnp.einsum('GSE->GE', mask_1)
    # [batch, group] - mostly ones, but zeros where something didn't fit.
    mask_1_flat = jnp.sum(mask_1, axis=-1)
    assert mask_1_count.dtype == mask_dtype
    assert mask_1_flat.dtype == mask_dtype

    if second_expert_policy == 'all' or second_expert_policy == 'sampling':
        pass
    else:
        assert second_expert_policy == 'random'
        # gate_2 is between 0 and 1, reminder:
        #
        #   raw_gates = jax.nn.softmax(logits)
        #   index_1 = jnp.argmax(raw_gates, axis=-1)
        #   mask_1 = jax.nn.one_hot(index_1, experts_dim, dtype=fprop_dtype)
        #   gate_1 = jnp.einsum(`GSE,GSE->GS', raw_gates, mask_1)
        #
        # e.g., if gate_2 exceeds second_expert_threshold, then we definitely
        # dispatch to second-best expert. Otherwise, we dispatch with probability
        # proportional to (gate_2 / threshold).
        #
        prng_key, subkey = jax.random.split(prng_key)
        sampled_2 = jnp.less(
            jax.random.uniform(subkey, gate_2.shape, dtype=gate_2.dtype),
            gate_2 / max(second_expert_threshold, 1e-9))
        gate_2 *= sampled_2.astype(gate_2.dtype)
        mask_2 *= jnp.expand_dims(sampled_2, -1).astype(mask_2.dtype)

    # Top-2 slots are numbered after all top-1 assignments of the same expert.
    position_in_expert_2 = cum_sum(
        mask_2, exclusive=True, axis=-2) + jnp.expand_dims(mask_1_count, -2)
    over_capacity_2 = _create_over_capacity_ratio_summary(
        mask_2, position_in_expert_2, capacity, 'over_capacity_2')

    mask_2 *= jnp.less(position_in_expert_2,
                       expert_capacity_dim).astype(mask_2.dtype)
    position_in_expert_2 = jnp.einsum('GSE,GSE->GS', position_in_expert_2, mask_2)
    mask_2_flat = jnp.sum(mask_2, axis=-1)

    gate_1 *= mask_1_flat.astype(gate_1.dtype)
    gate_2 *= mask_2_flat.astype(gate_2.dtype)

    if not legacy_mtf_behavior:
        denom = gate_1 + gate_2
        # To avoid divide by 0.
        denom = jnp.where(denom > 0, denom, jnp.ones_like(denom))
        gate_1 /= denom
        gate_2 /= denom

    # GSC tensor
    b = jax.nn.one_hot(
        position_in_expert_1.astype(np.int32),
        expert_capacity_dim,
        dtype=fprop_dtype)
    # GSE tensor
    a = jnp.expand_dims(
        gate_1 * mask_1_flat.astype(fprop_dtype), axis=-1) * jax.nn.one_hot(
            index_1, experts_dim, dtype=fprop_dtype)
    # GSEC tensor
    first_part_of_combine_tensor = jnp.einsum('GSE,GSC->GSEC', a, b)

    # GSC tensor
    b = jax.nn.one_hot(
        position_in_expert_2.astype(np.int32),
        expert_capacity_dim,
        dtype=fprop_dtype)
    # GSE tensor
    a = jnp.expand_dims(
        gate_2 * mask_2_flat.astype(fprop_dtype), axis=-1) * jax.nn.one_hot(
            index_2, experts_dim, dtype=fprop_dtype)
    second_part_of_combine_tensor = jnp.einsum('GSE,GSC->GSEC', a, b)

    # GSEC tensor
    combine_tensor = first_part_of_combine_tensor + second_part_of_combine_tensor

    # GSEC tensor
    dispatch_tensor = combine_tensor.astype(bool).astype(fprop_dtype)

    return aux_loss, combine_tensor, dispatch_tensor, (over_capacity_1,
                                                       over_capacity_2)
def lt(a: Numeric, b: Numeric):
    """Element-wise ``a < b`` computed with `jnp.less`."""
    result = jnp.less(a, b)
    return result
def cond_func(args):
    """Loop predicate: continue while the counter is below ``total_count``.

    ``args`` must be a 3-tuple whose first entry is the counter.
    """
    counter, _unused_a, _unused_b = args
    return jnp.less(counter, total_count)
def read_img_mask(filename):
    """Load an image file and return a binary mask of its positive pixels.

    The pixel array is transposed with ``np.transpose`` (all axes reversed)
    before thresholding.

    Args:
      filename: path of an image readable by ``Image.open``.

    Returns:
      ``np.uint8`` array that is 1 where the transposed pixel value is greater
      than 0 and 0 elsewhere.
    """
    # The context manager closes the file handle that Image.open keeps open;
    # np.array(img) forces the pixel data to load while it is still open.
    with Image.open(filename) as img:
        pixels = np.transpose(np.array(img))
    return np.array(np.less(0, pixels), dtype=np.uint8)
def spaced_mean_cond(state):
    """Return True while ``sample`` is closer than ``min_distance`` to a mean.

    Distances from ``sample`` to every row of ``means`` are computed with
    ``dist`` and multiplied by ``mask`` before comparison.

    NOTE(review): entries where ``mask`` is 0 yield distance 0, which counts as
    "too close" whenever ``min_distance > 0`` -- confirm the mask convention.
    """
    _key, means, mask, sample = state
    distances = jax.vmap(dist, in_axes=(0, None))(means, sample)
    masked_distances = mask * distances
    return jnp.any(masked_distances < min_distance)
def loop_cond(inputs):
    """Loop predicate for the batched simulation loop.

    Continues while the smallest entry of ``inputs[-2]`` is below ``accepted``
    and the counter ``inputs[-1]`` is below ``max_iterations``.
    """
    *_, n_accepted, iteration = inputs
    fewest_accepted = np.min(n_accepted)
    need_more = np.less(fewest_accepted, accepted)
    iterations_left = np.less(iteration, max_iterations)
    return np.logical_and(need_more, iterations_left)
is_non_decreasing = utils.copy_docstring( tf.math.is_non_decreasing, lambda x, name=None: np.all(x[1:] >= x[:-1])) is_strictly_increasing = utils.copy_docstring( tf.math.is_strictly_increasing, lambda x, name=None: np.all(x[1:] > x[:-1])) l2_normalize = utils.copy_docstring( tf.math.l2_normalize, lambda x, axis=None, epsilon=1e-12, name=None: ( # pylint: disable=g-long-lambda np.linalg.norm(x, ord=2, axis=axis, keepdims=True))) lbeta = utils.copy_docstring(tf.math.lbeta, _lbeta) less = utils.copy_docstring(tf.math.less, lambda x, y, name=None: np.less(x, y)) less_equal = utils.copy_docstring(tf.math.less_equal, lambda x, y, name=None: np.less_equal(x, y)) lgamma = utils.copy_docstring(tf.math.lgamma, lambda x, name=None: scipy_special.gammaln(x)) log = utils.copy_docstring(tf.math.log, lambda x, name=None: np.log(x)) log1p = utils.copy_docstring(tf.math.log1p, lambda x, name=None: np.log1p(x)) log_sigmoid = utils.copy_docstring(tf.math.log_sigmoid, lambda x, name=None: -np.log1p(np.exp(-x))) log_softmax = utils.copy_docstring(
def w_cond(self, args):
    """Loop predicate for redrawing out-of-bounds samples.

    Returns True while any coordinate of ``loc`` lies above ``self.high`` or
    below ``self.low`` and the counter is still below ``self.max_counter``.
    """
    _, loc, counter = args
    above_high = np.any(np.greater(loc, self.high))
    below_low = np.any(np.less(loc, self.low))
    out_of_bounds = np.logical_or(above_high, below_low)
    return np.logical_and(out_of_bounds, np.less(counter, self.max_counter))
def _less(a, b): return jnp.less(a, b)
def is_u_turning(q_i, q_f, p_f):
    """Return whether the trajectory has started to turn back on itself.

    True when ``p_f['z']`` (presumably the momentum at the end point) has a
    negative dot product with the displacement ``q_f['z'] - q_i['z']``.
    """
    displacement = q_f['z'] - q_i['z']
    alignment = np.dot(displacement, p_f['z'])
    return np.less(alignment, 0)
def map_body(k):
    """Return ``inner_map_body(k)`` when ``k <= K``, otherwise ``-inf``."""
    in_range = np.less(k, K + 1)
    return jax.lax.cond(in_range, inner_map_body, lambda _: -np.inf, k)
def single_acceptance(args):
    """Draws a proposal, simulates and compresses, checks distance

    A new proposal is drawn from a truncated multivariate normal distribution
    whose mean is centred on the parameter to move and the covariance is set
    by the population. From this proposed parameter value a simulation is made
    and compressed and the distance from the target is calculated. If this
    distance is less than the smallest distance currently stored in ``dis``
    then the proposal is accepted.

    Parameters
    ----------
    args : tuple
        see loop variable in `single_iteration`; unpacked as
        ``(rng, loc, scale, summ, dis, draws, accepted, acceptance_counter)``

    Returns
    -------
    tuple:
        the updated loop variables in the same order: ``loc``, ``summ`` and
        ``dis`` are replaced by the proposal's values when it is accepted,
        ``draws`` grows by the number of simulations with a finite distance,
        ``accepted`` is ``closer.sum()`` for this attempt and
        ``acceptance_counter`` is incremented by one

    Todo
    ----
    Parallel sampling (``n_parallel_simulations``) is not implemented
    """
    (rng, loc, scale, summ, dis, draws, accepted, acceptance_counter) = args
    rng, key = jax.random.split(rng)
    proposed, summaries = self.get_samples(
        key, None,
        dist=tmvn(loc, scale, self.prior.low, self.prior.high,
                  max_counter=max_samples))
    distances = np.squeeze(
        self.distance_measure(np.expand_dims(summaries, 0), target, F))
    closer = np.less(distances, np.min(dis))
    loc = jax.lax.cond(closer, lambda _: proposed, lambda _: loc, None)
    summ = jax.lax.cond(closer, lambda _: summaries, lambda _: summ, None)
    dis = jax.lax.cond(closer, lambda _: distances, lambda _: dis, None)
    # A simulation with infinite distance is not counted as a draw.
    iteration_draws = 1 - np.isinf(distances).sum()
    draws += iteration_draws
    accepted = closer.sum()
    return (rng, loc, scale, summ, dis, draws, accepted,
            acceptance_counter + 1)
def delete(m, i):
    """Return a copy of ``m`` with row ``i`` removed.

    Output row ``j`` reads ``m[j]`` for ``j < i`` and ``m[j + 1]`` otherwise.
    Only array arithmetic is used (no Python control flow on ``i``).
    """
    n = m.shape[0]
    positions = np.arange(n - 1)
    gather = np.where(positions < i, positions, positions + 1)
    return m[gather]
def born_radii(conf, atomic_radii, scaled_radius_factor, dielectric_offset,
               alpha_obc, beta_obc, gamma_obc):
    """
    Compute the adjusted born radii of each atom. This is the first part of
    the GBSA calculation.

    Parameters
    ----------
    conf: np.array shape Nx3
        matrix of geometric coordinates
    atomic_radii: np.array shape [N,]
        array of radius of each atom
    scaled_radius_factor: np.array shape [N,]
        array of adjusted shape factors for each atom.
    dielectric_offset: float
        subtracted from ``atomic_radii`` to obtain the offset radii
    alpha_obc, beta_obc, gamma_obc: float
        coefficients of the polynomial in the pairwise sum that is passed
        through ``tanh``

    Returns
    -------
    np.array shape [N,]
        np.array of atomic radii
    """
    r_i = np.expand_dims(conf, axis=0)
    r_j = np.expand_dims(conf, axis=1)
    d_ij = distance(r_i, r_j)

    oR = atomic_radii - dielectric_offset
    oRI = np.expand_dims(oR, axis=1)  # rows
    oRJ = np.expand_dims(oR, axis=0)  # columns
    sRJ = oRJ * scaled_radius_factor
    rSRJ = d_ij + sRJ

    # Along the diagonal rSRJ = oR * scaled_radius_factor, which is < oRI when
    # the factor is < 1 (assumed -- TODO confirm), so the diagonal is masked out.
    mask_final = np.less(oRI, rSRJ)

    # d_ij is 0 on the diagonal, so 1 / d_ij would be inf there. Divide by a
    # safe denominator first (double-where): a single outer `where` still
    # evaluates 1 / 0, which warns in numpy and gives NaN gradients under JAX.
    keep_mask = 1 - np.eye(conf.shape[0])
    safe_d_ij = np.where(keep_mask, d_ij, np.ones_like(d_ij))
    d_ij_inv = np.where(keep_mask, 1 / safe_d_ij, np.zeros_like(d_ij))

    rfs = np.abs(d_ij - sRJ)
    l_ij = np.maximum(oRI, rfs)
    l_ij = 1 / l_ij
    u_ij = 1 / rSRJ
    l_ij2 = l_ij * l_ij
    u_ij2 = u_ij * u_ij
    ratio = np.log(u_ij / l_ij)
    term = l_ij - u_ij + 0.25 * d_ij * (u_ij2 - l_ij2) + (
        0.5 * d_ij_inv * ratio) + (0.25 * sRJ * sRJ * d_ij_inv) * (l_ij2 - u_ij2)

    term_masked = np.where(mask_final, term, np.zeros_like(term))
    summ = np.sum(term_masked, axis=-1)
    summ *= 0.5 * oR
    sum2 = summ * summ
    sum3 = summ * sum2
    tanhSum = np.tanh(alpha_obc * summ - beta_obc * sum2 + gamma_obc * sum3)

    return 1.0 / (1.0 / oR - tanhSum / atomic_radii)