コード例 #1
0
ファイル: online_smoothing.py プロジェクト: SamDuffield/mocat
def rejection_stitch_proposal_all(ssm_scenario: StateSpaceModel,
                                  x0_all: jnp.ndarray,
                                  t: float,
                                  x1_all: jnp.ndarray,
                                  tplus1: float,
                                  x1_log_weight: jnp.ndarray,
                                  bound_inflation: float,
                                  not_yet_accepted_arr: jnp.ndarray,
                                  x1_all_sampled_inds: jnp.ndarray,
                                  bound: float,
                                  random_keys: jnp.ndarray,
                                  rejection_iter: int,
                                  num_transition_evals: int) \
        -> Tuple[jnp.ndarray, jnp.ndarray, float, jnp.ndarray, int, int]:
    """Run one rejection-stitching proposal round over every index.

    ``rejection_stitch_proposal_single_cond`` is applied once per index
    ``i`` in ``range(len(x1_all))``, receiving the i-th entries of
    ``not_yet_accepted_arr``, ``x1_all_sampled_inds``, ``x0_all`` and
    ``random_keys`` together with the shared arguments.

    If the largest density evaluation returned by that round exceeds
    ``bound``, the bound is replaced by that maximum multiplied by
    ``bound_inflation`` and every index is flagged as not yet accepted again.

    NOTE(review): ``map`` is unpacked into four values here, so it is
    presumably ``jax.lax.map`` (stacked outputs) rather than the builtin --
    confirm against the file's imports.

    Returns:
        A tuple ``(not_yet_accepted, sampled_inds, bound, random_keys,
        rejection_iter + 1, num_transition_evals + k)`` where ``k`` is the sum
        of the *incoming* ``not_yet_accepted_arr``.
    """
    num_particles = len(x1_all)

    def propose_for_index(i):
        # Per-index slices first, then the arguments shared by all indices.
        return rejection_stitch_proposal_single_cond(not_yet_accepted_arr[i],
                                                     x1_all_sampled_inds[i],
                                                     ssm_scenario,
                                                     x0_all[i],
                                                     t,
                                                     x1_all,
                                                     tplus1,
                                                     x1_log_weight,
                                                     bound,
                                                     random_keys[i])

    new_sampled_inds, dens_evals, still_pending, new_keys = map(propose_for_index,
                                                                jnp.arange(num_particles))

    # A density evaluation above the current bound triggers a restart:
    # inflate the bound and mark every index as pending again.
    highest_dens = jnp.max(dens_evals)
    bound_exceeded = highest_dens > bound
    updated_bound = jnp.where(bound_exceeded, highest_dens * bound_inflation, bound)
    still_pending = jnp.where(bound_exceeded,
                              jnp.ones(num_particles, dtype='bool'),
                              still_pending)

    # The evaluation counter uses the flags as they were on entry.
    evals_so_far = num_transition_evals + not_yet_accepted_arr.sum()
    return still_pending, new_sampled_inds, updated_bound, new_keys, rejection_iter + 1, evals_so_far
コード例 #2
0
ファイル: losses.py プロジェクト: yynst2/deepmind-research
def bgrl_loss(
    first_online_predictions: jnp.ndarray,
    second_target_projections: jnp.ndarray,
    second_online_predictions: jnp.ndarray,
    first_target_projections: jnp.ndarray,
    symmetrize: bool,
    valid_mask: jnp.ndarray,
) -> Tuple[jnp.ndarray, LogsDict]:
    """Compute the BGRL loss as a masked mean of per-node losses.

    Each side's per-node loss is the squared difference between the
    ``_l2_normalize``-d online predictions and target projections, summed
    over the last axis. The first side pairs ``first_online_predictions``
    with ``second_target_projections``; when ``symmetrize`` is true, the
    second side (``second_online_predictions`` against
    ``first_target_projections``) is added to it.

    Args:
        first_online_predictions: Online predictions for the first side.
        second_target_projections: Target projections matched to the first
            side's predictions.
        second_online_predictions: Online predictions for the second side;
            only used when ``symmetrize`` is true.
        first_target_projections: Target projections matched to the second
            side's predictions; only used when ``symmetrize`` is true.
        symmetrize: Whether to add the second side's loss.
        valid_mask: Weights multiplied into the per-node loss before summing.

    Returns:
        ``(loss, logs)`` where ``loss`` is the masked sum divided by
        ``valid_mask.sum() + 1e-6`` and ``logs`` holds it under the key
        ``bgrl_loss``.
    """

    def _side_loss(predictions, projections):
        # Squared distance between normalised vectors, reduced over the
        # last axis.
        difference = (_l2_normalize(predictions, axis=-1) -
                      _l2_normalize(projections, axis=-1))
        return jnp.sum(jnp.square(difference), axis=-1)

    node_loss = _side_loss(first_online_predictions, second_target_projections)
    if symmetrize:
        node_loss = node_loss + _side_loss(second_online_predictions,
                                           first_target_projections)

    # The small constant keeps the denominator non-zero for an all-zero mask.
    masked_total = (node_loss * valid_mask).sum()
    normaliser = valid_mask.sum() + 1e-6
    loss = masked_total / normaliser
    return loss, {'bgrl_loss': loss}
コード例 #3
0
def _mean_with_mask(array: jnp.ndarray, mask: jnp.ndarray) -> jnp.ndarray:
    """Return the masked sum of ``array`` divided by the mask's sum over axis 0.

    The numerator comes from ``sum_with_mask(array, mask)``. The denominator
    is ``mask.sum(0)`` with no guard, so a mask that sums to zero results in
    a division by zero.
    """
    masked_total = sum_with_mask(array, mask)
    return masked_total / mask.sum(0)