def rejection_stitch_proposal_all(
        ssm_scenario: StateSpaceModel,
        x0_all: jnp.ndarray,
        t: float,
        x1_all: jnp.ndarray,
        tplus1: float,
        x1_log_weight: jnp.ndarray,
        bound_inflation: float,
        not_yet_accepted_arr: jnp.ndarray,
        x1_all_sampled_inds: jnp.ndarray,
        bound: float,
        random_keys: jnp.ndarray,
        rejection_iter: int,
        num_transition_evals: int,
) -> Tuple[jnp.ndarray, jnp.ndarray, float, jnp.ndarray, int, int]:
    """Run one rejection-sampling round of the stitching proposal for every particle.

    `rejection_stitch_proposal_single_cond` is applied to each index
    `0..len(x1_all) - 1`. Each call receives that particle's acceptance flag,
    sampled index, `x0_all` row and random key, together with the shared
    `x1_all`, `x1_log_weight`, times `t` / `tplus1` and the current `bound`.

    If the largest density evaluation returned this round exceeds `bound`,
    the bound is replaced by `max_density * bound_inflation`. Every particle
    is then flagged as not yet accepted, so sampling starts again.

    Returns:
        A tuple `(not_yet_accepted, sampled_inds, bound, random_keys,
        rejection_iter + 1, num_transition_evals + k)`. Here `k` is the number
        of entries of the *input* `not_yet_accepted_arr` that are set.
    """
    num_particles = len(x1_all)

    def _propose_for(idx):
        # Per-particle proposal; shared arrays are passed through unchanged.
        return rejection_stitch_proposal_single_cond(not_yet_accepted_arr[idx],
                                                     x1_all_sampled_inds[idx],
                                                     ssm_scenario,
                                                     x0_all[idx],
                                                     t,
                                                     x1_all,
                                                     tplus1,
                                                     x1_log_weight,
                                                     bound,
                                                     random_keys[idx])

    # NOTE(review): unpacking into four arrays assumes `map` is a stacking
    # map (e.g. `jax.lax.map` imported at module level), not the builtin --
    # confirm against the module imports.
    sampled_inds, dens_evals, still_pending, new_keys = map(
        _propose_for, jnp.arange(num_particles))

    # If any observed density broke the envelope, inflate the bound and
    # restart every particle.
    largest_density = jnp.max(dens_evals)
    must_restart = largest_density > bound
    new_bound = jnp.where(must_restart, largest_density * bound_inflation, bound)
    restart_flags = jnp.ones(num_particles, dtype='bool')
    still_pending = jnp.where(must_restart, restart_flags, still_pending)

    # Count one transition evaluation per particle pending at round start.
    evals_this_round = not_yet_accepted_arr.sum()
    return (still_pending, sampled_inds, new_bound, new_keys,
            rejection_iter + 1, num_transition_evals + evals_this_round)
def bgrl_loss(
    first_online_predictions: jnp.ndarray,
    second_target_projections: jnp.ndarray,
    second_online_predictions: jnp.ndarray,
    first_target_projections: jnp.ndarray,
    symmetrize: bool,
    valid_mask: jnp.ndarray,
) -> Tuple[jnp.ndarray, LogsDict]:
    """Implements BGRL loss.

    For each side, the online predictions and the target projections are each
    passed through `_l2_normalize` along the last axis. The squared
    difference is then summed over that axis. The first side always compares
    `first_online_predictions` with `second_target_projections`. When
    `symmetrize` is true, the second side compares `second_online_predictions`
    with `first_target_projections` and is added to the first.

    The per-node loss is weighted by `valid_mask` and divided by
    `valid_mask.sum() + 1e-6`; the epsilon avoids division by zero for an
    all-invalid mask.

    Returns:
        `(loss, logs)`, where `logs` maps `'bgrl_loss'` to the same loss.
    """

    def _side_loss(online_predictions, target_projections):
        # Squared distance between normalized vectors, summed over last axis.
        diff = (_l2_normalize(online_predictions, axis=-1)
                - _l2_normalize(target_projections, axis=-1))
        return jnp.sum(jnp.square(diff), axis=-1)

    node_loss = _side_loss(first_online_predictions, second_target_projections)
    if symmetrize:
        node_loss = node_loss + _side_loss(second_online_predictions,
                                           first_target_projections)

    masked_total = (node_loss * valid_mask).sum()
    normalizer = valid_mask.sum() + 1e-6
    loss = masked_total / normalizer
    return loss, {'bgrl_loss': loss}
def _mean_with_mask(array: jnp.ndarray, mask: jnp.ndarray) -> jnp.ndarray:
    """Average `array` over the entries selected by `mask`.

    The numerator is `sum_with_mask(array, mask)`. The denominator is
    `mask.sum(0)`, the count of valid entries along axis 0. This presumably
    matches the axis that `sum_with_mask` reduces over; verify against its
    definition.

    Where the valid count is zero, the denominator is clamped to 1. That
    position then returns the masked sum, presumably 0, instead of NaN/inf.
    Positions with a non-zero count are unaffected.
    """
    num_valid_rows = mask.sum(0)
    # Dividing by a zero count would produce NaN/inf that propagates into
    # losses and gradients. Only the zero-count case is replaced, so results
    # for any non-empty selection are identical to an unguarded division.
    safe_num_valid_rows = jnp.where(num_valid_rows > 0, num_valid_rows, 1)
    return sum_with_mask(array, mask) / safe_num_valid_rows