def loop_body(val):
  """From ids at `step`, update output ids at `step + 1`."""
  step = val.step
  decoder_state, logits = extend_step_fn(val.state, val.output_ids[:, step])
  logprobs = jax.nn.log_softmax(logits.astype(jnp.float32))
  val.state = decoder_state
  # When step becomes prefix_length - 1, the new output has index beyond
  # the known prefix.
  # If prefix_length is 0, the condition is always False, so we take the
  # decoded output rather than the prefix.
  new_ids = jnp.where(step < prefix_lengths - 1, target_ids[:, step + 1],
                      jnp.argmax(logits, axis=1))
  prev_done = val.done
  new_ids = jnp.where(prev_done, jnp.zeros_like(new_ids), new_ids)
  if eos_id is not None:
    val.done = jnp.logical_or(prev_done, jnp.equal(new_ids, eos_id))
  max_decoding_steps_reached = (
      jnp.ones_like(prefix_lengths) * (step + 2) -
      prefix_lengths) >= max_decode_steps
  val.done = jnp.logical_or(val.done, max_decoding_steps_reached)
  done_at_this_step = jnp.logical_and(jnp.logical_not(prev_done), val.done)
  val.decode_lengths = jnp.where(
      done_at_this_step,
      jnp.ones_like(val.decode_lengths) * (step + 2),
      val.decode_lengths)
  val.output_ids = val.output_ids.at[:, step + 1].set(new_ids)
  logprobs_at_new_ids = logprobs.at[jnp.arange(batch_size), new_ids].get()
  logprobs_at_new_ids = jnp.where(prev_done,
                                  jnp.ones_like(logprobs_at_new_ids),
                                  logprobs_at_new_ids)
  val.logprobs = val.logprobs.at[:, step + 1].set(logprobs_at_new_ids)
  val.step += 1
  return val
def norm_logbicop_approx(u, v, rho):
    pu = jnp.where(u == 0.5, x=0.5, y=ndtri_(u))
    pv = jnp.where(v == 0.5, x=0.5, y=ndtri_(v))
    alpha_u = jnp.where(u == 0.5, x=0,
                        y=(1 / jnp.sqrt(1 - rho**2)) * ((pv / pu) - rho))
    alpha_v = jnp.where(v == 0.5, x=0,
                        y=(1 / jnp.sqrt(1 - rho**2)) * ((pu / pv) - rho))
    rho_u = -alpha_u / jnp.sqrt(1 + alpha_u**2)
    rho_v = -alpha_v / jnp.sqrt(1 + alpha_v**2)
    C_uu = jnp.exp(norm_logbicop_diag_approx(jnp.log(u), 1 - 2 * rho_u**2))
    C_vv = jnp.exp(norm_logbicop_diag_approx(jnp.log(v), 1 - 2 * rho_v**2))
    ind_rho_u = jnp.where(rho_u < 0, x=1, y=0)
    C_half_u = 0.5 * ind_rho_u * C_uu + (1 - ind_rho_u) * (u - 0.5 * C_uu)
    ind_rho_v = jnp.where(rho_v < 0, x=1, y=0)
    C_half_v = 0.5 * ind_rho_v * C_vv + (1 - ind_rho_v) * (v - 0.5 * C_vv)
    delta_uv = jnp.where(
        jnp.logical_or(jnp.logical_and(u < 0.5, v >= 0.5),
                       jnp.logical_and(u >= 0.5, v < 0.5)),
        x=0.5, y=0)
    C_uv = C_half_u + C_half_v - delta_uv
    # Handle the special cases u == 0.5 and/or v == 0.5 separately, where the
    # diagonal approximation applies directly.
    ind_v_half = jnp.where(u == 0.5, x=1, y=0)
    ind_u_half = jnp.where(v == 0.5, x=1, y=0)
    ind_uv_half_or = jnp.logical_or(u == 0.5, v == 0.5)
    ind_uv_half_and = jnp.logical_and(u == 0.5, v == 0.5)
    C_uv = ((1 - ind_uv_half_and) * ind_u_half *
            (u - 0.5 * jnp.exp(
                norm_logbicop_diag_approx(jnp.log(u), 1 - 2 * rho**2))) +
            (1 - ind_uv_half_and) * ind_v_half *
            (v - 0.5 * jnp.exp(
                norm_logbicop_diag_approx(jnp.log(v), 1 - 2 * rho**2))) +
            ind_uv_half_and *
            jnp.exp(norm_logbicop_diag_approx(jnp.log(u), rho)) +
            (1 - ind_uv_half_or) * C_uv)
    return jnp.log(C_uv)
def has_converged(x: Tensor, grad: Tensor, l: Tensor, u: Tensor):
  stuck_at_lower = jnp.logical_and(x == l, grad >= 0)
  stuck_at_upper = jnp.logical_and(x == u, grad <= 0)
  zero_grad = grad == 0
  stuck_at_border = jnp.logical_or(stuck_at_lower, stuck_at_upper)
  converged = jnp.logical_or(stuck_at_border, zero_grad)
  return jnp.all(converged)
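# Usage sketch for `has_converged` (not from the source; assumes `Tensor` is
# `jnp.ndarray`). A coordinate counts as converged when it sits at a bound the
# descent direction would cross, or when its gradient is exactly zero.
import jax.numpy as jnp

xs = jnp.array([0.0, 0.5, 1.0])
gs = jnp.array([2.0, 0.0, -3.0])  # pinned at lower bound, flat, pinned at upper
lo, hi = jnp.zeros(3), jnp.ones(3)
print(has_converged(xs, gs, lo, hi))  # True: every coordinate satisfies a condition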
def cond_func(val):
  q, i, norm_delta_q, error = val
  diverged = np.logical_or(error > divergence_tol, np.isnan(error))
  converged = np.logical_and(error < convergence_tol,
                             norm_delta_q < position_tol)
  return np.logical_not(
      np.logical_or(i >= max_iters, np.logical_or(diverged, converged)))
def compute_intersection_point(denom):
  t1 = np.cross(v2, v1) / denom
  t2 = (v1 @ v3) / denom
  # No hit when the intersection lies behind the ray (t1 < 0) or off the
  # segment (t2 outside [0, 1]); signal this with an infinite point.
  condition = np.logical_or(np.logical_or(t1 < 0.0, t2 < 0.0), t2 > 1.0)
  return jax.lax.cond(
      condition,
      true_fun=lambda t: np.array([np.inf, np.inf]),
      false_fun=lambda t: ray_origin + t1 * ray_direction,
      operand=t1,
  )
def adaptive_hmc_update(state,
                        log_prob,
                        state_grad,
                        key,
                        step_size,
                        trajectory_len,
                        target_accept_rate=0.8,
                        step_size_adaptation_speed=0.05,
                        max_n_leapfrog=1000,
                        jitter_amt=0.2):
  normal_key, uniform_key, jitter_key = jax.random.split(key, 3)
  n_leapfrog = jnp.array(jnp.ceil(trajectory_len / step_size), jnp.int32)
  n_leapfrog = jnp.minimum(n_leapfrog, max_n_leapfrog)
  # Jitter the step size only when step size adaptation is disabled.
  jittered_step_size = step_size * jnp.exp(
      jnp.where(
          jnp.logical_or(step_size_adaptation_speed <= 0,
                         target_accept_rate <= 0),
          jnp.log(1. + jitter_amt) *
          (2 * jax.random.uniform(jitter_key, ()) - 1.),
          0.))
  num_leaves = len(jax.tree_leaves(state))
  normal_keys = list(jax.random.split(normal_key, num_leaves))
  treedef = jax.tree_structure(state)
  normal_keys = jax.tree_unflatten(treedef, normal_keys)
  momentum = jax.tree_multimap(lambda s, key: jax.random.normal(key, s.shape),
                               state, normal_keys)
  initial_energy = get_kinetic_energy(momentum) - log_prob
  new_state, new_momentum, new_grad, new_log_prob = leapfrog(
      jittered_step_size, n_leapfrog, state, momentum, state_grad)
  new_energy = _nan_to_inf(get_kinetic_energy(new_momentum) - new_log_prob)
  energy_diff = initial_energy - new_energy
  accept_prob = jnp.minimum(1., jnp.exp(energy_diff))
  # TODO(izmailovpavel): check why the second condition is needed.
  accepted = jnp.logical_and(
      jax.random.uniform(uniform_key, log_prob.shape) < accept_prob,
      jnp.isfinite(energy_diff))
  step_size = step_size * jnp.exp(
      jnp.where(
          jnp.logical_or(target_accept_rate <= 0,
                         step_size_adaptation_speed <= 0),
          0.,
          step_size_adaptation_speed *
          (jnp.mean(accept_prob) - target_accept_rate)))
  state = jax.lax.cond(accepted, _first, _second, (new_state, state))
  log_prob = jnp.where(accepted, new_log_prob, log_prob)
  state_grad = jax.lax.cond(accepted, _first, _second, (new_grad, state_grad))
  return state, log_prob, state_grad, step_size, accept_prob
def cond_fun(maxiter, bound, feasStop, state):
  logging.info('compiling cond_fun')
  counter = state[1]
  dual_objective, primal_dual_gap, maxfeasible = state[7:10]
  cond1 = counter <= maxiter
  cond2 = dual_objective < bound
  cond3 = np.logical_or(
      np.absolute(primal_dual_gap) > 1e-6,
      np.logical_or(maxfeasible > feasStop,
                    np.logical_and(counter < 200, maxfeasible >= 0)))
  return np.logical_and(cond1, np.logical_and(cond2, cond3))
def check_convergence(
    self,
    state_prev: _NonlinearSolverState,
    cost_updated: hints.Scalar,
    local_delta_assignments: VariableAssignments,
    negative_gradient: hints.Array,
) -> bool:
    """Check for convergence!"""

    # Cost tolerance
    converged_cost = (
        jnp.abs(cost_updated - state_prev.cost) / state_prev.cost
        < self.cost_tolerance
    )

    # Gradient tolerance
    converged_gradient = jnp.where(
        state_prev.iterations >= self.gradient_tolerance_start_step,
        jnp.max(
            state_prev.assignments.storage
            - state_prev.assignments.manifold_retract(
                VariableAssignments(
                    storage=negative_gradient,
                    storage_metadata=local_delta_assignments.storage_metadata,
                ),
            ).storage
        )
        < self.gradient_tolerance,
        False,
    )

    # Parameter tolerance
    converged_parameters = (
        jnp.linalg.norm(jnp.abs(local_delta_assignments.storage))
        < (
            jnp.linalg.norm(state_prev.assignments.storage)
            + self.parameter_tolerance
        )
        * self.parameter_tolerance
    )

    return jnp.logical_or(
        converged_cost,
        jnp.logical_or(
            converged_gradient,
            converged_parameters,
        ),
    )
def _build_sliding_window_mask(window_size, global_mask):
  """Builds mask for sliding window pattern.

  Args:
    window_size: int, size of sliding window.
    global_mask: boolean jax array of shape `[batch_size, seq_len]`.

  Returns:
    mask, boolean jax array of shape
    `[batch_size, 1 (n_heads), seq_len, seq_len]`.

  If `window_size` is odd, both left and right sides have the same receptive
  field. Otherwise, the left side gets one more.

  Note - we need the global mask because, due to the symmetry requirement,
  non-global positions can still attend to global positions.
  """
  seq_len = global_mask.shape[1]
  right_size = window_size // 2
  left_size = window_size - right_size
  left_mask = sum(np.eye(seq_len, k=-i) for i in range(left_size))
  right_mask = sum(np.eye(seq_len, k=i) for i in range(1, right_size + 1))
  mask = left_mask + right_mask
  mask = jnp.array(mask[np.newaxis, np.newaxis, :, :]).astype(jnp.bool_)
  return jnp.logical_or(mask, _build_global_mask(global_mask))
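# Shape sketch (illustration only, not from the source). With no global tokens
# the result reduces to the pure band mask; `_build_global_mask` is the
# companion function defined later in this collection.
import jax.numpy as jnp

no_globals = jnp.zeros((1, 6), dtype=jnp.bool_)  # batch of 1, seq_len of 6
band = _build_sliding_window_mask(window_size=3, global_mask=no_globals)
print(band.shape)     # (1, 1, 6, 6)
print(band[0, 0, 2])  # [False True True True False False]: window around position 2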
def cond_fn(iteration, const, state):
  threshold = const[-1]
  errors = state[0]
  # Error recorded at the end of the previous outer iteration.
  err = errors[iteration // inner_iterations - 1, 0]
  return jnp.logical_or(
      iteration == 0,
      jnp.logical_and(jnp.isfinite(err), err > threshold))
def __init__(self, n: jnp.ndarray, p: jnp.ndarray):
  """Initializes a multinomial distribution with n trials and probabilities p.

  n may be multidimensional, in which case it represents multiple multinomial
  distributions. p has to have the shape of n plus one extra dimension
  representing the probabilities of each event. The probabilities in the last
  dimension have to sum to 1.

  Args:
    n: Number of trials. Has to be an integer and non-negative.
    p: Probabilities of trial successes. Must have the same shape as n plus
      one additional dimension representing the probabilities. Probabilities
      have to sum to 1.
  """
  super().__init__()
  if n.shape != p.shape[:len(n.shape)] or len(n.shape) + 1 != len(p.shape):
    raise ValueError('Shapes of n and p not compatible')
  # We cannot raise a ValueError here since we get problems with
  # ConcretizationError during Metropolis-Hastings.
  nans = jnp.full(p.shape, jnp.nan)
  self.p = jnp.where(jnp.logical_or(p < 0, p > 1), nans, p)
  self.p = jnp.where(jnp.isclose(jnp.sum(self.p, -1, keepdims=True), 1),
                     self.p, nans)
  nans = jnp.full(n.shape, jnp.nan)
  self.n = jnp.where(n <= 0, nans, n)
def adapt_step_size(step_size, target_accept_rate, accept_prob,
                    step_size_adaptation_speed):
  log_factor = jnp.where(
      jnp.logical_or(target_accept_rate <= 0,
                     step_size_adaptation_speed <= 0),
      0.,
      step_size_adaptation_speed * (accept_prob - target_accept_rate))
  return step_size * jnp.exp(log_factor)
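# Usage sketch (illustrative values): acceptance above the target grows the
# step size, and a non-positive target or speed disables adaptation entirely.
import jax.numpy as jnp

print(adapt_step_size(0.1, 0.8, jnp.asarray(0.9), 0.05))   # 0.1 * exp(0.005)
print(adapt_step_size(0.1, -1.0, jnp.asarray(0.9), 0.05))  # 0.1, adaptation off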
def update(updates, state, params=None):
  inner_state = state.inner_state
  flat_updates = tree_flatten(updates)[0]
  isfinite = jnp.all(
      jnp.array([jnp.all(jnp.isfinite(p)) for p in flat_updates]))
  notfinite_count = jnp.where(isfinite, jnp.zeros([], jnp.int64),
                              1 + state.notfinite_count)

  def do_update(_):
    return inner.update(updates, inner_state, params)

  def reject_update(_):
    return (tree_map(jnp.zeros_like, updates), inner_state)

  updates, new_inner_state = lax.cond(
      jnp.logical_or(isfinite, notfinite_count > max_consecutive_errors),
      do_update, reject_update, operand=None)

  return updates, ApplyIfFiniteState(
      notfinite_count=notfinite_count,
      last_finite=isfinite,
      total_notfinite=jnp.logical_not(isfinite) + state.total_notfinite,
      inner_state=new_inner_state)
def apply_fun(params, inputs, **kwargs):

    def inner_loop_body(i, inputs_and_counter):
        # fori_loop-compatible version of the unrolled loop body below; its
        # call site is currently commented out.
        inputs, counter = inputs_and_counter
        val = inputs[j, i]
        condition = jnp.logical_or(val < -0.5, val > 0.5)
        inputs, counter = jax.lax.cond(
            condition,
            lambda xTrue: (inputs.at[j, i].set(counter * val), counter * -1),
            lambda xFalse: (inputs, counter),
            None)
        return inputs, counter

    # Flips every second nonzero spin: `counter` alternates between +1 and -1
    # each time an entry with |val| > 0.5 is encountered.
    for j in range(inputs.shape[0]):
        counter = +1
        for i in range(inputs.shape[1]):
            val = inputs[j, i]
            condition = jnp.logical_or(val < -0.5, val > 0.5)
            inputs, counter = jax.lax.cond(
                condition,
                lambda xTrue: (inputs.at[j, i].set(counter * val),
                               counter * -1),
                lambda xFalse: (inputs, counter),
                None)
        # inputs, counter = jax.lax.fori_loop(
        #     0, inputs.shape[1], inner_loop_body, (inputs, counter))
    return inputs
def sample_search_cond_fn(state):
  """state termination condition fn."""
  has_reached_max_length = state.cur_len == max_length
  all_sequence_finished = jnp.all(state.is_sent_finished)
  finish_generation = jnp.logical_or(has_reached_max_length,
                                     all_sequence_finished)
  return ~finish_generation
def apply_param_gradient(self, step, hyper_params, param, state, grad):
  del step
  assert hyper_params.learning_rate is not None, 'no learning rate provided.'
  param_norm = jnp.linalg.norm(param)
  grad_norm = jnp.linalg.norm(grad)
  trust_ratio = hyper_params.trust_coefficient * param_norm / (
      grad_norm + hyper_params.weight_decay * param_norm + hyper_params.eps)
  clipped_trust_ratio = jnp.where(
      jnp.logical_or(grad_norm == 0., param_norm == 0.), 1., trust_ratio)
  scaled_lr = hyper_params.learning_rate * clipped_trust_ratio
  if hyper_params.weight_decay != 0:
    grad += hyper_params.weight_decay * param
  scaled_grad = scaled_lr * grad
  momentum = state.momentum
  new_momentum = hyper_params.beta * momentum + scaled_grad
  if hyper_params.nesterov:
    d_p = scaled_grad + hyper_params.beta * new_momentum
  else:
    d_p = new_momentum
  new_param = param - d_p
  new_state = _LARSParamState(new_momentum)
  return new_param, new_state
def _rational_quadratic_spline_fwd(x: Array, x_pos: Array, y_pos: Array,
                                   knot_slopes: Array) -> Tuple[Array, Array]:
  """Applies a rational-quadratic spline to a scalar.

  Args:
    x: a scalar (0-dimensional array). The scalar `x` can be any real number;
      it will be transformed by the spline if it's in the closed interval
      `[x_pos[0], x_pos[-1]]`, and it will be transformed linearly if it's
      outside that interval.
    x_pos: array of shape [num_bins + 1], the bin boundaries on the x axis.
    y_pos: array of shape [num_bins + 1], the bin boundaries on the y axis.
    knot_slopes: array of shape [num_bins + 1], the slopes at the knot points.

  Returns:
    A tuple of two scalars: the output of the transformation and the log of the
    absolute first derivative at `x`.
  """
  # Search to find the right bin. NOTE: The bins are sorted, so we could use
  # binary search, but this is more GPU/TPU friendly.
  i = jnp.sum(x_pos < x) - 1
  below_range = x <= x_pos[0]
  above_range = x >= x_pos[-1]
  outside_range = jnp.logical_or(below_range, above_range)
  # Avoid NaNs which might propagate to the gradient despite jnp.where
  # later by setting i to 0 when outside of range.
  i = jnp.where(outside_range, 0, i)
  bin_width = x_pos[i + 1] - x_pos[i]
  bin_height = y_pos[i + 1] - y_pos[i]
  bin_slope = bin_height / bin_width

  z = (x - x_pos[i]) / bin_width
  # `z` should be in range [0, 1] to avoid NaNs later. This can happen because
  # of small floating point issues or when x is outside of the range and `i`
  # was set to 0. To avoid all problems, we restrict z in [0, 1].
  z = jnp.clip(z, 0., 1.)
  sq_z = z * z
  z1mz = z - sq_z  # z(1-z)
  sq_1mz = (1. - z)**2
  slopes_term = knot_slopes[i + 1] + knot_slopes[i] - 2. * bin_slope
  numerator = bin_height * (bin_slope * sq_z + knot_slopes[i] * z1mz)
  denominator = bin_slope + slopes_term * z1mz
  y = y_pos[i] + numerator / denominator

  # Compute log det Jacobian.
  # The logdet is a sum of 3 logs. It is easy to see that the inputs of the
  # first two logs are guaranteed to be positive because we ensured that z is
  # in [0, 1]. This is also true of the log(denominator) because:
  # denominator
  # == bin_slope + (knot_slopes[i+1] + knot_slopes[i] - 2 * bin_slope) * z*(1-z)
  # >= bin_slope - 2 * bin_slope * z * (1-z)
  # >= bin_slope - 2 * bin_slope * (1/4)
  # == bin_slope / 2
  logdet = 2. * jnp.log(bin_slope) + jnp.log(
      knot_slopes[i + 1] * sq_z + 2. * bin_slope * z1mz +
      knot_slopes[i] * sq_1mz) - 2. * jnp.log(denominator)

  # If x is outside the spline range, we default to a linear transformation.
  y = jnp.where(below_range, (x - x_pos[0]) * knot_slopes[0] + x_pos[0], y)
  y = jnp.where(above_range, (x - x_pos[-1]) * knot_slopes[-1] + x_pos[-1], y)
  logdet = jnp.where(below_range, jnp.log(knot_slopes[0]), logdet)
  logdet = jnp.where(above_range, jnp.log(knot_slopes[-1]), logdet)
  return y, logdet
def take(tensor, idx, fill=0):
  # Non jit-friendly implementation:
  # illegal = jnp.logical_or(idx > p, idx < 1)
  # return tensor[..., idx - 1].at[..., illegal].set(fill)
  legalized_idx = jnp.clip(idx, a_min=1, a_max=p)
  illegal_mask = jnp.logical_or(idx > p, idx < 1)
  return (tensor[..., legalized_idx - 1] * (1 - 1 * illegal_mask) +
          illegal_mask * fill)
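# Usage sketch (an assumption: `p`, the largest legal 1-based index, is a
# module-level constant in the original; it is set to 4 here purely for
# illustration).
import jax.numpy as jnp

p = 4
t = jnp.array([10., 20., 30., 40.])
print(take(t, jnp.array([1, 4])))            # [10. 40.]: 1-based indexing
print(take(t, jnp.array([0, 5]), fill=-1.))  # [-1. -1.]: out-of-range -> fill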
def _randaugment_inner_for_loop(_, in_args):
  """Loop body for RandAugment.

  Args:
    _: loop iteration (unused).
    in_args: loop body arguments.

  Returns:
    Updated loop arguments.
  """
  (image, geometric_transforms, random_key, available_ops, op_probs,
   magnitude, cutout_const, translate_const, join_transforms,
   default_replace_value) = in_args
  random_keys = random.split(random_key, num=8)
  random_key = random_keys[0]  # keep for next iteration
  op_to_select = random.choice(random_keys[1], available_ops, p=op_probs)
  mask_value = jnp.where(
      default_replace_value > 0,
      jnp.ones([image.shape[-1]]) * default_replace_value,
      random.randint(random_keys[2], [image.shape[-1]], minval=-1,
                     maxval=256))
  random_magnitude = random.uniform(random_keys[3], [], minval=0.,
                                    maxval=magnitude)
  cutout_mask = color_transforms.get_random_cutout_mask(
      random_keys[4], image.shape, cutout_const)
  translate_vals = (
      random.uniform(random_keys[5], [], minval=0.0, maxval=1.0) *
      translate_const,
      random.uniform(random_keys[6], [], minval=0.0, maxval=1.0) *
      translate_const)
  negate = random.randint(random_keys[7], [], minval=0,
                          maxval=2).astype('bool')
  args = level_to_arg(cutout_mask, translate_vals, negate, random_magnitude,
                      mask_value)
  if DEBUG:
    print(op_to_select, args[op_to_select])
  image, geometric_transform = _apply_ops(image, args, op_to_select)
  image, geometric_transform = jax.lax.cond(
      jnp.logical_or(
          join_transforms,
          jnp.all(jnp.not_equal(geometric_transform, jnp.identity(4)))),
      lambda op: (op[0], op[1]),
      lambda op: (transforms.apply_transform(op[0], op[1],
                                             mask_value=mask_value),
                  jnp.identity(4)),
      (image, geometric_transform))
  geometric_transforms = jnp.matmul(geometric_transforms, geometric_transform)
  return (image, geometric_transforms, random_key, available_ops, op_probs,
          magnitude, cutout_const, translate_const, join_transforms,
          default_replace_value)
def apply_cutout(key, image, cutoutwidth, cutoutheight):
  channels, width, height = image.shape
  # Top-left corner of the cutout; the x bound depends on the cutout width
  # and the y bound on the cutout height.
  x0, y0 = jax.random.randint(
      key, (2,), minval=0,
      maxval=jnp.array([width + 1 - cutoutwidth, height + 1 - cutoutheight]))

  # Construct a (width, height) mask that is 0 inside the cutout rectangle
  # and 1 elsewhere; `indexing="ij"` aligns it with the image's trailing axes.
  xx, yy = jnp.meshgrid(jnp.arange(width), jnp.arange(height), indexing="ij")
  xmask = jnp.where(jnp.logical_and(xx >= x0, xx < x0 + cutoutwidth), 0, 1)
  ymask = jnp.where(jnp.logical_and(yy >= y0, yy < y0 + cutoutheight), 0, 1)
  mask = jnp.logical_or(xmask, ymask)
  return image * mask
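# Usage sketch: cutting an 8x8 patch out of an all-ones CHW image removes
# exactly 3 * 8 * 8 entries, wherever the patch lands.
import jax
import jax.numpy as jnp

img = jnp.ones((3, 32, 32))
cut = apply_cutout(jax.random.PRNGKey(0), img, cutoutwidth=8, cutoutheight=8)
print(img.sum() - cut.sum())  # 192.0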
def i_stimulus(t, params_and_data):
  return lax.cond(
      np.logical_or(
          np.less(t, params_and_data["stimulus_start_time"]),
          np.greater(t, params_and_data["stimulus_end_time"]),
      ),
      lambda _: 0.0,
      lambda _: params_and_data["i_stimulus"],
      None,
  )
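# Usage sketch (assumed dictionary keys match the ones read above): the
# stimulus current is zero outside the [start, end] window and constant inside.
params_and_data = {
    "stimulus_start_time": 10.0,
    "stimulus_end_time": 50.0,
    "i_stimulus": 2.5,
}
print(i_stimulus(5.0, params_and_data))   # 0.0
print(i_stimulus(30.0, params_and_data))  # 2.5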
def categorical_sample(key, probs):
  """Sample from a set of discrete probabilities."""
  probs = probs / probs.sum(axis=-1, keepdims=True)
  cpi = jnp.cumsum(probs, axis=-1)
  eps = jnp.finfo(probs.dtype).eps
  rnds = jax.random.uniform(
      key=key, shape=probs.shape[:-1] + (1,), dtype=probs.dtype, minval=eps)
  return jnp.argmin(jnp.logical_or(rnds > cpi, probs < eps), axis=-1)
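# Usage sketch: inverse-CDF sampling. With nearly all mass on index 1, the
# draw is index 1 with overwhelming probability; index 0 can never be drawn
# because its probability falls below machine epsilon.
import jax
import jax.numpy as jnp

probs = jnp.array([0.0, 0.999, 0.001])
print(categorical_sample(jax.random.PRNGKey(0), probs))  # 1 (almost surely)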
def _step(
    self,
    graph: "StackedFactorGraph",
    state_prev: NonlinearSolverState,
) -> NonlinearSolverState:
    """Linearize, solve linear subproblem, and update on manifold."""
    self._hcb_print(
        lambda i, max_i, cost: f"Iteration #{i}/{max_i}: cost={str(cost)}",
        i=state_prev.iterations,
        max_i=self.max_iterations,
        cost=state_prev.cost,
    )

    # Linearize graph
    A: sparse.SparseCooMatrix = graph.compute_whitened_residual_jacobian(
        assignments=state_prev.assignments,
        residual_vector=state_prev.residual_vector,
    )
    ATb = -(A.T @ state_prev.residual_vector)

    # Solve linear subproblem
    local_delta_assignments = VariableAssignments(
        storage=self.linear_solver.solve_subproblem(
            A=A,
            ATb=ATb,
            lambd=0.0,
            iteration=state_prev.iterations,
        ),
        storage_layout=graph.local_storage_layout,
    )

    # On-manifold retraction
    assignments = state_prev.assignments.manifold_retract(
        local_delta_assignments=local_delta_assignments,
    )

    # Check for convergence
    cost, residual_vector = graph.compute_cost(assignments)
    done = jnp.logical_or(
        self.check_exceeded_max_iterations(state_prev=state_prev),
        self.check_convergence(
            state_prev=state_prev,
            cost_updated=cost,
            local_delta_assignments=local_delta_assignments,
            negative_gradient=ATb,
        ),
    )

    return NonlinearSolverState(
        iterations=state_prev.iterations + 1,
        assignments=assignments,
        cost=cost,
        residual_vector=residual_vector,
        done=done,
    )
def _scale_update(update, param):
  param_norm = jnp.linalg.norm(param)
  update_norm = jnp.linalg.norm(update)
  trust_ratio = param_norm / update_norm
  # Set trust_ratio to 1 in case where parameters would never be updated.
  zero_norm = jnp.logical_or(param_norm == 0., update_norm == 0.)
  safe_trust_ratio = jnp.where(
      zero_norm, jnp.array(1.0, dtype=param.dtype), trust_ratio)
  return update * safe_trust_ratio
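# Usage sketch: the update is rescaled by ||param|| / ||update||, and a zero
# norm on either side falls back to a trust ratio of 1 instead of dividing by
# zero.
import jax.numpy as jnp

param = jnp.array([3.0, 4.0])    # norm 5
update = jnp.array([0.0, 10.0])  # norm 10
print(_scale_update(update, param))        # [0. 5.]: scaled by 5 / 10
print(_scale_update(jnp.zeros(2), param))  # [0. 0.]: ratio forced to 1, no NaNs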
def body_fun(carry):
  key, samples, _ = carry
  key, use_key = random.split(key)
  new_samples = random.randint(use_key, shape=shape, minval=0, maxval=maxval)
  # Keep the old value wherever the fresh draw collides with an existing
  # sample or a rejected value; such entries get re-drawn on the next round.
  discard = jnp.logical_or(in1dvec(new_samples, samples),
                           in1dvec(new_samples, rejects))
  samples = jnp.where(discard, samples, new_samples)
  return key, samples, in1dvec(samples, rejects)
def create_mask(x, x_var, indices, val=0.0):
  mask = onp.ones(x.shape)
  # mask[x == 0] = val
  ind = np.where(np.logical_or(x_var <= 0, x_var == mask_std_val**2.0))
  mask[ind] = val
  fullindices = onp.asarray(indices)[:, None] + onp.arange(
      x.shape[1])[None, :]
  mask[fullindices < 0] = val
  # offs = -np.maximum(indices, np.zeros_like(indices))
  # for io, off in enumerate(offs):
  #   mask[io, 0:off] = 0
  return np.asarray(mask)
def encode_step_fn(carry, x):
  lstm_state, is_eos = carry
  new_lstm_state, y = lstm_cell(lstm_state, x)

  # Pass forward the previous state if EOS has already been reached.
  def select_carried_state(new_state, old_state):
    return jnp.where(is_eos[:, np.newaxis], old_state, new_state)

  # LSTM state is a tuple (c, h).
  carried_lstm_state = tuple(
      select_carried_state(*s) for s in zip(new_lstm_state, lstm_state))
  # Update `is_eos`.
  is_eos = jnp.logical_or(is_eos, x[:, eos_id])
  return (carried_lstm_state, is_eos), y
def _build_global_mask(mask):
  """Builds mask for global attention pattern.

  Args:
    mask: boolean jax array of shape `[batch_size, seq_len]`.

  Returns:
    mask, boolean jax array of shape
    `[batch_size, 1 (n_heads), seq_len, seq_len]`.
  """
  return jnp.logical_or(mask[:, jnp.newaxis, :, jnp.newaxis],
                        mask[:, jnp.newaxis, jnp.newaxis, :])
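# Shape sketch: a position pair attends whenever either endpoint is global,
# which is what the two broadcasts above encode.
import jax.numpy as jnp

g = jnp.array([[True, False, False]])  # token 0 is global
m = _build_global_mask(g)
print(m.shape)  # (1, 1, 3, 3)
print(m[0, 0])  # row 0 and column 0 are True, everything else False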
def _rational_quadratic_spline_inv(y: Array, x_pos: Array, y_pos: Array,
                                   knot_slopes: Array) -> Tuple[Array, Array]:
  """Applies the inverse of a rational-quadratic spline to a scalar.

  Args:
    y: a scalar (0-dimensional array). The scalar `y` can be any real number;
      it will be transformed by the spline if it's in the closed interval
      `[y_pos[0], y_pos[-1]]`, and it will be transformed linearly if it's
      outside that interval.
    x_pos: array of shape [num_bins + 1], the bin boundaries on the x axis.
    y_pos: array of shape [num_bins + 1], the bin boundaries on the y axis.
    knot_slopes: array of shape [num_bins + 1], the slopes at the knot points.

  Returns:
    A tuple of two scalars: the output of the inverse transformation and the
    log of the absolute first derivative of the inverse at `y`.
  """
  # Search to find the right bin. NOTE: The bins are sorted, so we could use
  # binary search, but this is more GPU/TPU friendly.
  i = jnp.sum(y_pos < y) - 1
  below_range = y <= y_pos[0]
  above_range = y >= y_pos[-1]
  outside_range = jnp.logical_or(below_range, above_range)
  # Set i = 0 when y is out of range, to avoid NaNs later.
  i = jnp.where(outside_range, 0, i)
  bin_width = x_pos[i + 1] - x_pos[i]
  bin_height = y_pos[i + 1] - y_pos[i]
  bin_slope = bin_height / bin_width

  w = (y - y_pos[i]) / bin_height
  w = jnp.clip(w, 0., 1.)  # Ensure w is in [0, 1].

  # Compute quadratic coefficients: az^2 + bz + c = 0
  slopes_term = knot_slopes[i + 1] + knot_slopes[i] - 2. * bin_slope
  c = -bin_slope * w
  b = knot_slopes[i] - slopes_term * w
  a = bin_slope - b

  # Solve quadratic to obtain z and then x.
  z = -2. * c / (b + jnp.sqrt(b**2 - 4. * a * c))
  z = jnp.clip(z, 0., 1.)  # Ensure z is in [0, 1].
  x = bin_width * z + x_pos[i]

  # Compute log det Jacobian.
  sq_z = z * z
  z1mz = z - sq_z  # z(1-z)
  sq_1mz = (1. - z)**2
  denominator = bin_slope + slopes_term * z1mz
  logdet = -2. * jnp.log(bin_slope) - jnp.log(
      knot_slopes[i + 1] * sq_z + 2. * bin_slope * z1mz +
      knot_slopes[i] * sq_1mz) + 2. * jnp.log(denominator)

  # If y is outside the spline range, we default to a linear transformation.
  x = jnp.where(below_range, (y - y_pos[0]) / knot_slopes[0] + y_pos[0], x)
  x = jnp.where(above_range, (y - y_pos[-1]) / knot_slopes[-1] + y_pos[-1], x)
  logdet = jnp.where(below_range, -jnp.log(knot_slopes[0]), logdet)
  logdet = jnp.where(above_range, -jnp.log(knot_slopes[-1]), logdet)
  return x, logdet
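# Round-trip sketch (made-up knots, purely illustrative): applying the forward
# spline and then its inverse recovers x, and the log-determinants cancel.
import jax.numpy as jnp

x_pos = jnp.array([-1.0, -0.2, 0.4, 1.0])      # increasing bin boundaries
y_pos = jnp.array([-1.0, 0.1, 0.3, 1.0])
knot_slopes = jnp.array([0.5, 1.2, 0.8, 1.5])  # positive slopes at the knots
y, ld_fwd = _rational_quadratic_spline_fwd(
    jnp.asarray(0.1), x_pos, y_pos, knot_slopes)
x_rec, ld_inv = _rational_quadratic_spline_inv(y, x_pos, y_pos, knot_slopes)
print(x_rec, ld_fwd + ld_inv)  # ~0.1 and ~0.0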