Example #1
0
def job(
    random_seed: int,
    base_dir: Path,
    theta_min: float,
    theta_max: float,
    theta_dot_min: float,
    theta_dot_max: float,
):
    """Run one DDPG training job for ``random_seed`` and pickle its history.

    All randomness is derived from ``random_seed``:
    - one key is handed to ``run_ddpg.train``;
    - one key per episode is used by the episode callback to call
      ``run_ddpg.eval_policy``.

    The predicate passed to ``run_ddpg.train`` is true for ``(t, s)`` when any
    of the following holds:
    - ``t >= config.episode_length``;
    - ``s[0] <= theta_min`` or ``s[0] >= theta_max``;
    - ``s[1] <= theta_dot_min`` or ``s[1] >= theta_dot_max``.

    NOTE(review): presumably this is the episode-termination test; its exact
    meaning is defined by ``run_ddpg.train`` — confirm there.

    After training, the last-seen optimizer value and tracking params are
    pickled to ``base_dir / "seed=<random_seed>.pkl"``. Per-episode lists of
    train reward, evaluated policy value, episode length and elapsed time are
    pickled alongside them.

    NOTE(review): ``num_episodes``, ``config`` and ``run_ddpg`` are read from
    module scope rather than passed in.
    """
    root_key = random.PRNGKey(random_seed)
    root_key, train_key = random.split(root_key)
    # One evaluation key per episode, indexed by info['episode'] below.
    eval_keys = random.split(root_key, num_episodes)

    # Overwritten on every callback, so they end up holding the last episode's
    # values (or None if the callback never fires).
    final_params = None
    final_tracking_params = None

    train_rewards = []
    policy_values = []
    lengths = []
    elapsed_times = []

    def on_episode_end(info):
        nonlocal final_params, final_tracking_params
        episode_idx = info['episode']
        optimizer_value = info["optimizer"].value
        final_params = optimizer_value
        final_tracking_params = info["tracking_params"]

        # Evaluate with the first component of the optimizer value.
        policy_value = run_ddpg.eval_policy(eval_keys[episode_idx],
                                            optimizer_value[0])

        train_rewards.append(info['reward'])
        policy_values.append(policy_value)
        lengths.append(info["episode_length"])
        elapsed_times.append(info["elapsed"])

    def is_terminal(t, s):
        theta = s[0]
        theta_dot = s[1]
        hit_time_limit = lax.ge(t, config.episode_length)
        below_theta = lax.le(theta, theta_min)
        above_theta = lax.ge(theta, theta_max)
        below_theta_dot = lax.le(theta_dot, theta_dot_min)
        above_theta_dot = lax.ge(theta_dot, theta_dot_max)
        # Touching any bound counts as leaving the allowed box.
        out_of_bounds = lax.bitwise_or(below_theta_dot, above_theta_dot)
        out_of_bounds = lax.bitwise_or(above_theta, out_of_bounds)
        out_of_bounds = lax.bitwise_or(below_theta, out_of_bounds)
        return lax.bitwise_or(hit_time_limit, out_of_bounds)

    run_ddpg.train(train_key, num_episodes, is_terminal, on_episode_end)

    output_path = base_dir / f"seed={random_seed}.pkl"
    with output_path.open(mode="wb") as f:
        pickle.dump(
            {
                "final_params": final_params,
                "final_tracking_params": final_tracking_params,
                "train_reward_per_episode": train_rewards,
                "policy_value_per_episode": policy_values,
                "episode_lengths": lengths,
                "elapsed_per_episode": elapsed_times,
            }, f)
Example #2
0
def nanvar(a,
           axis: Optional[Union[int, Tuple[int, ...]]] = None,
           dtype=None,
           out=None,
           ddof=0,
           keepdims=False,
           where=None):
    """Variance of ``a`` along ``axis``, ignoring NaN entries.

    NaN entries are zeroed before squaring and are excluded from the element
    count. The count is reduced by ``ddof``. Wherever the resulting divisor
    is ``<= 0``, the result is NaN. For complex input, each squared deviation
    is the real part of ``d * conj(d)``.

    Raises:
      NotImplementedError: if ``out`` is not None.
    """
    _check_arraylike("nanvar", a)
    lax_internal._check_user_dtype_supported(dtype, "nanvar")
    if out is not None:
        raise NotImplementedError(
            "The 'out' argument to jnp.nanvar is not supported.")

    computation_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype)
    mean = nanmean(a, axis, dtype=computation_dtype, keepdims=True, where=where)

    is_nan = lax_internal._isnan(a)
    # NaN slots are replaced via a where rather than arithmetic
    # (double-where trick for gradients).
    deviation = _where(is_nan, 0, a - mean)
    if dtypes.issubdtype(deviation.dtype, np.complexfloating):
        squared_dev = lax.real(lax.mul(deviation, lax.conj(deviation)))
    else:
        squared_dev = lax.square(deviation)

    non_nan_count = sum(lax_internal.bitwise_not(is_nan),
                        axis=axis,
                        keepdims=keepdims,
                        where=where)
    dof = non_nan_count - ddof
    no_dof = lax.le(dof, 0)
    # Where there are no degrees of freedom, emit NaN and divide by 1 instead.
    total = _where(no_dof, np.nan,
                   sum(squared_dev, axis, keepdims=keepdims, where=where))
    safe_dof = _where(no_dof, 1, dof)
    variance = lax.div(total, lax.convert_element_type(safe_dof, total.dtype))
    return lax.convert_element_type(variance, dtype)
Example #3
0
def logpmf(k, p, loc=0):
    """Log probability mass function of the geometric distribution.

    With ``x = k - loc`` this returns ``(x - 1) * log(1 - p) + log(p)``
    for ``x > 0``, and ``-inf`` for ``x <= 0``. Integrality of ``k`` is
    not checked here.
    """
    k, p, loc = jnp._promote_args_inexact("geom.logpmf", k, p, loc)
    trials = lax.sub(k, loc)
    failures = lax.sub(trials, lax._const(k, 1))
    # xlog1py(u, v) == u * log1p(v), i.e. (x - 1) * log(1 - p) here.
    log_probs = xlog1py(failures, -p) + lax.log(p)
    outside_support = lax.le(trials, lax._const(k, 0))
    return jnp.where(outside_support, -jnp.inf, log_probs)
Example #4
0
def cdf(x, loc=0, scale=1):
    """Cumulative distribution function of the Laplace distribution.

    With ``z = (x - loc) / scale`` this returns ``0.5 * exp(z)`` for
    ``z <= 0`` and ``1 - 0.5 * exp(-z)`` for ``z > 0``.
    """
    x, loc, scale = _promote_args_inexact("laplace.cdf", x, loc, scale)
    half = _constant_like(x, 0.5)
    one = _constant_like(x, 1)
    zero = _constant_like(x, 0)
    diff = lax.div(lax.sub(x, loc), scale)
    on_left = lax.le(diff, zero)
    # lax.select evaluates (and differentiates through) both branches. If the
    # raw ``diff`` were fed to both, the discarded branch would compute
    # exp(+large) = inf in the far tails. Its zero cotangent times inf then
    # makes the gradient NaN. Feed each branch 0 where it is not selected.
    # Selected positions still see the exact ``diff``, so forward values are
    # unchanged.
    zeros = lax.full_like(diff, 0)
    left_diff = lax.select(on_left, diff, zeros)
    right_diff = lax.select(on_left, zeros, diff)
    left = lax.mul(half, lax.exp(left_diff))
    right = lax.sub(one, lax.mul(half, lax.exp(lax.neg(right_diff))))
    return lax.select(on_left, left, right)
Example #5
0
 def cond_fun(state):
     """Loop-continuation predicate: true while ``j <= i`` for state ``(i, j, _)``."""
     upper, current, _unused = state
     return lax.ge(upper, current)
Example #6
0
def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
    """Elementwise test of whether ``a`` and ``b`` are equal within a tolerance.

    ``a`` and ``b`` are first promoted to a common dtype. An element is close
    when ``|a - b| <= atol + rtol * |b|``. The test is asymmetric: the
    relative tolerance scales with ``|b|``, as in NumPy.

    For floating-point and complex inputs, non-finite values follow NumPy:
    - an infinity is close only to an identical infinity;
    - NaN is close to nothing, unless ``equal_nan`` is true, in which case
      NaN is close to NaN.

    For complex inputs, the tolerances are cast to the matching real dtype,
    because ``|a - b|`` and ``|b|`` are real.

    Args:
      a, b: values to compare.
      rtol: relative tolerance, cast to the comparison dtype.
      atol: absolute tolerance, cast to the comparison dtype.
      equal_nan: if true, a pair of NaNs counts as close. Defaults to False.

    Returns:
      A boolean array.
    """
    import numpy as np

    a, b = _promote_args("isclose", a, b)
    dtype = _dtype(a)
    is_complex = np.issubdtype(dtype, np.complexfloating)
    # lax.mul refuses mixed complex/real operands, and |b| is real, so the
    # tolerances for complex input must live in the real component dtype.
    tol_dtype = np.finfo(dtype).dtype if is_complex else dtype
    rtol = lax.convert_element_type(rtol, tol_dtype)
    atol = lax.convert_element_type(atol, tol_dtype)
    within_tol = lax.le(lax.abs(lax.sub(a, b)),
                        lax.add(atol, lax.mul(rtol, lax.abs(b))))
    if np.issubdtype(dtype, np.integer) or np.issubdtype(dtype, np.bool_):
        # No inf/NaN to correct for; keep the plain tolerance test.
        return within_tol

    def _is_finite(v):
        # lax.is_finite only accepts real floats; complex is finite iff
        # both components are.
        if is_complex:
            return lax.bitwise_and(lax.is_finite(lax.real(v)),
                                   lax.is_finite(lax.imag(v)))
        return lax.is_finite(v)

    both_finite = lax.bitwise_and(_is_finite(a), _is_finite(b))
    # The tolerance test is wrong off the finite reals in two ways:
    # - inf - inf is NaN, so equal infinities would compare "not close";
    # - rtol * inf is inf, so any finite value would compare "close" to an
    #   infinite b.
    # Where either side is non-finite, use exact equality instead. That is
    # true only for identical infinities, since NaN != NaN.
    out = lax.bitwise_or(
        lax.bitwise_and(both_finite, within_tol),
        lax.bitwise_and(lax.bitwise_not(both_finite), lax.eq(a, b)))
    if equal_nan:
        # v != v is true exactly for NaN entries.
        both_nan = lax.bitwise_and(lax.ne(a, a), lax.ne(b, b))
        out = lax.bitwise_or(out, both_nan)
    return out