コード例 #1
0
ファイル: loss.py プロジェクト: ismaeelnawaz/objax
def mean_squared_log_error(
    y_true: JaxArray,
    y_pred: JaxArray,
    keep_axis: Optional[Iterable[int]] = (0, )) -> JaxArray:
    """Computes the mean squared logarithmic error between y_true and y_pred.

    The element-wise loss is ``(log1p(y_true) - log1p(y_pred)) ** 2``; it is then
    averaged over every axis that is not listed in ``keep_axis``.

    Args:
        y_true: a tensor of shape (d0, .. dN-1).
        y_pred: a tensor of shape (d0, .. dN-1).
        keep_axis: a sequence of the dimensions to keep, use `None` to return a scalar value.
            Negative entries count from the last dimension, as in NumPy.

    Returns:
        tensor of shape (d_i, ..., for i in keep_axis) containing the mean squared logarithmic error.
    """
    loss = (jn.log1p(y_true) - jn.log1p(y_pred))**2
    # Materialize keep_axis once (a one-shot iterator would otherwise be consumed by
    # the first membership test) and normalize negative axes so that e.g. (-1,)
    # keeps the last axis instead of silently matching nothing.
    kept = {k + loss.ndim if k < 0 else k for k in (keep_axis or ())}
    axis = tuple(i for i in range(loss.ndim) if i not in kept)
    return loss.mean(axis)
コード例 #2
0
 def to_exp(self) -> LogarithmicEP:
     """Convert to the expectation parametrization.

     With p = exp(self.log_probability), ``chi = -p / ((1 - p) * log(1 - p))``.
     The limits are substituted explicitly: chi = 1 when log p < -50 and
     chi = +inf when log p > -1e-7.
     """
     log_p = self.log_probability
     probability = jnp.exp(log_p)
     # log(1 - p) evaluated stably on both sides of p = 1/2 (Maechler's log1mexp):
     # log(-expm1(log p)) avoids cancellation when p is close to 1, while
     # log1p(-p) keeps full precision when p is small.
     log_one_minus_p = jnp.where(log_p > -jnp.log(2.0),
                                 jnp.log(-jnp.expm1(log_p)),
                                 jnp.log1p(-probability))
     chi = jnp.where(
         log_p < -50.0, 1.0,
         jnp.where(
             log_p > -1e-7, jnp.inf,
             probability / (jnp.expm1(log_p) * log_one_minus_p)))
     return LogarithmicEP(chi)
コード例 #3
0
 def log_prob(self, value):
     # Log density of a Cauchy(loc, scale) restricted to [low, +inf); the upper
     # limit is presumably infinite (pi/2 = arctan(+inf)) — confirm if `high` exists.
     if self._validate_args:
         self._validate_sample(value)
     z_low = (self.low - self.loc) / self.scale
     # Mass above `low` of the standard Cauchy (times pi), plus the change of
     # variables from the standardized to the original scale.
     log_norm = np.log(np.pi / 2 - np.arctan(z_low)) + np.log(self.scale)
     z = (value - self.loc) / self.scale
     return -np.log1p(z**2) - log_norm
コード例 #4
0
def sigmoid_cross_entropy_with_logits(*, labels: Array,
                                      logits: Array) -> Array:
    """Sigmoid cross entropy loss.

    Evaluated in the numerically stable form
    ``max(logits, 0) - logits * labels + log(1 + exp(-|logits|))`` so that
    ``exp`` only ever sees non-positive arguments.
    """
    zero = jnp.zeros_like(logits, dtype=logits.dtype)
    nonneg = logits >= zero
    positive_part = jnp.where(nonneg, logits, zero)
    minus_abs = jnp.where(nonneg, -logits, logits)
    softplus_term = jnp.log1p(jnp.exp(minus_abs))
    return positive_part - logits * labels + softplus_term
コード例 #5
0
 def log_prob(self, value):
     # -betaln(a, k + 1) - log(a + k) == lgamma(a + k) - lgamma(a) - lgamma(k + 1),
     # so this is log[ C(k + a - 1, k) * rate^a / (1 + rate)^(a + k) ],
     # i.e. the Gamma-Poisson (negative binomial) pmf with a = concentration.
     alpha = self.concentration
     total = alpha + value
     log_coeff = -betaln(alpha, value + 1) - jnp.log(total)
     return log_coeff + alpha * jnp.log(self.rate) - total * jnp.log1p(self.rate)
コード例 #6
0
ファイル: utils.py プロジェクト: fehiepsi/jaxns
def signed_logaddexp(log_abs_val1, sign1, log_abs_val2, sign2):
    """Add two numbers given in signed-log form.

    Each operand is ``sign * exp(log_abs_val)``. Returns ``(log|sum|, sign)``; the
    sign is taken from the operand of larger magnitude (``sign2`` on ties).
    """
    larger = jnp.maximum(log_abs_val1, log_abs_val2)
    sign_of_larger = jnp.where(log_abs_val1 > log_abs_val2, sign1, sign2)
    # Non-positive gap between magnitudes; NaN when an input is NaN or both are
    # the same infinity (inf - inf).
    gap = -jnp.abs(log_abs_val2 - log_abs_val1)
    relative_sign = sign1 * sign2
    log_abs_sum = jnp.where(
        jnp.isnan(gap),
        log_abs_val1 + log_abs_val2,  # NaNs or infinities of the same sign.
        larger + jnp.log1p(relative_sign * jnp.exp(gap)),
    )
    return log_abs_sum, sign_of_larger
コード例 #7
0
ファイル: logit.py プロジェクト: IPL-UV/rbig_jax
    def forward_and_log_det(self, inputs: Array, **kwargs) -> Tuple[Array, Array]:
        # Keep inputs strictly inside (0, 1) so both logarithms below stay finite.
        clipped = jnp.clip(inputs, EPS, 1 - EPS)
        # Temperature-scaled logit: (log(x) - log(1 - x)) / temperature.
        logits = (1.0 / self.temperature) * (safe_log(clipped) - jnp.log1p(-clipped))
        # Forward log-det is the negated inverse log-det evaluated at the outputs.
        return logits, -self.inverse_log_det_jacobian(logits, **kwargs)
コード例 #8
0
def lossfun(x, alpha, scale):
    r"""Implements the general form of the loss.

    This implements the rho(x, \alpha, c) function described in "A General and
    Adaptive Robust Loss Function", Jonathan T. Barron,
    https://arxiv.org/abs/1701.03077.

    Args:
      x: The residual for which the loss is being computed. x can have any shape,
        and alpha and scale will be broadcasted to match x's shape if necessary.
      alpha: The shape parameter of the loss (\alpha in the paper), where more
        negative values produce a loss with more robust behavior (outliers "cost"
        less), and more positive values produce a loss with less robust behavior
        (outliers are penalized more heavily). Alpha can be any value in
        [-infinity, infinity], but the gradient of the loss with respect to alpha
        is 0 at -infinity, infinity, 0, and 2. Varying alpha allows for smooth
        interpolation between several discrete robust losses:
          alpha=-Infinity: Welsch/Leclerc Loss.
          alpha=-2: Geman-McClure loss.
          alpha=0: Cauchy/Lorentzian loss.
          alpha=1: Charbonnier/pseudo-Huber loss.
          alpha=2: L2 loss.
      scale: The scale parameter of the loss. When |x| < scale, the loss is an
        L2-like quadratic bowl, and when |x| > scale the loss function takes on a
        different shape according to alpha. Values below float32 machine epsilon
        are raised to epsilon.

    Returns:
      The losses for each element of x, in the same shape as x.
    """
    eps = jnp.finfo(jnp.float32).eps

    def log1p_safe(v):
        # log1p with its argument capped at 3e37 so huge inputs cannot NaN-out.
        return jnp.log1p(jnp.minimum(v, 3e37))

    def expm1_safe(v):
        # expm1 with its argument capped at 87.5, just below float32 exp overflow.
        return jnp.expm1(jnp.minimum(v, 87.5))

    # `scale` must be > 0.
    scale = jnp.maximum(eps, scale)
    # The alpha == 2 (L2) loss; every other branch is expressed in terms of it.
    squared = 0.5 * (x / scale)**2

    # General case. |alpha| and |alpha - 2| are floored at machine epsilon so they
    # are safe divisors; `a` keeps the sign of alpha (alpha >= 0 counts as +).
    sign = jnp.where(alpha >= 0, jnp.ones_like(alpha), -jnp.ones_like(alpha))
    a = sign * jnp.maximum(eps, jnp.abs(alpha))
    b = jnp.maximum(eps, jnp.abs(alpha - 2))
    general = (b / a) * ((squared / (0.5 * b) + 1)**(0.5 * alpha) - 1)

    # Layer the closed-form special cases over the general formula, innermost
    # first, so the precedence matches a single nested where-chain.
    out = jnp.where(alpha == jnp.inf, expm1_safe(squared), general)
    out = jnp.where(alpha == 2, squared, out)
    out = jnp.where(alpha == 0, log1p_safe(squared), out)
    return jnp.where(alpha == -jnp.inf, -expm1_safe(-squared), out)
コード例 #9
0
ファイル: logit.py プロジェクト: alexhepburn/rbig_jax
    def __call__(self, inputs):
        # Temperature-scaled logit transform; returns (outputs, per-row log|det J|).
        eps = self.eps.value
        temp = self.temperature.value
        clipped = np.clip(inputs, eps, 1 - eps)
        outputs = (1 / temp) * (np.log(clipped) - np.log1p(-clipped))
        # log|d outputs / d inputs| = -(log T + log x + log(1 - x)); with y = outputs,
        # log x = -softplus(-T*y) and log(1 - x) = -softplus(T*y).
        logabsdet = -(np.log(temp) -
                      softplus(-temp * outputs) -
                      softplus(temp * outputs))
        return outputs, logabsdet.sum(axis=1)
コード例 #10
0
    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        base = self.base_dist
        if base._validate_args:
            base._validate_sample(value)

        # Eq. 2.1 in [1]: the base density is scaled by (1 + sum_i skew_i * sin(...)).
        centered = (value - base.mean) % (2 * jnp.pi)
        skew_term = (self.skewness * jnp.sin(centered)).sum(-1)
        return base.log_prob(value) + jnp.log1p(skew_term)
コード例 #11
0
ファイル: utils.py プロジェクト: fehiepsi/jaxns
def logaddexp(x1, x2):
    """Return log(exp(x1) + exp(x2)), with a hand-written path for complex inputs.

    Real inputs are delegated to ``jnp.logaddexp``.
    """
    if not (is_complex(x1) or is_complex(x2)):
        return jnp.logaddexp(x1, x2)
    # Factor out the operand with the larger real part so exp() cannot overflow.
    first_larger = x1.real > x2.real
    larger = jnp.where(first_larger, x1, x2)
    diff = jnp.where(first_larger, x2 - x1, x1 - x2)
    return jnp.where(jnp.isnan(diff),
                     x1 + x2,  # NaNs or infinities of the same sign.
                     larger + jnp.log1p(jnp.exp(diff)))
コード例 #12
0
def log_expbig_minus_expsmall(big: Array, small: Array) -> Array:
    """Stable implementation of `log(exp(big) - exp(small))`.

  Computed as ``big + log(1 - exp(small - big))``. The ``log(1 - exp(d))``
  factor (``d <= 0``) uses the two-branch scheme of Maechler (2012):
  ``log(-expm1(d))`` when ``d > -log(2)`` and ``log1p(-exp(d))`` otherwise.
  This keeps full relative precision when ``small`` is close to ``big``.

  Args:
    big: First input.
    small: Second input. It must be `small <= big`.

  Returns:
    The resulting `log(exp(big) - exp(small))`; -inf when `small == big`.
  """
    delta = small - big  # <= 0 by contract
    # Near 0, 1 - exp(delta) cancels catastrophically, so use -expm1 there; far
    # from 0, exp(delta) is small and log1p(-exp(delta)) is the accurate form.
    log1m_exp = jnp.where(delta > -jnp.log(2.0),
                          jnp.log(-jnp.expm1(delta)),
                          jnp.log1p(-jnp.exp(delta)))
    return big + log1m_exp
コード例 #13
0
 def _sample_n(self, key: PRNGKey, n: int) -> Array:
     """See `Distribution._sample_n`."""
     out_shape = (n, ) + self.batch_shape
     dtype = jnp.result_type(self._loc, self._scale)
     uniform = jax.random.uniform(key,
                                  shape=out_shape,
                                  dtype=dtype,
                                  minval=jnp.finfo(dtype).tiny,
                                  maxval=1.)
     rnd = jnp.log(uniform) - jnp.log1p(-uniform)
     return self._scale * rnd + self._loc
コード例 #14
0
def plot_gradients(loss_history, opt_state, get_params, net_params, net_apply):
    """Plot the loss history and the network's score field at two zoom levels.

    The ``net_params`` argument is not used: parameters are re-read from
    ``opt_state`` via ``get_params`` before each evaluation.
    """

    def _log1p_scaled_scores(lo, hi):
        # Evaluate the network on a 50x50 grid over [lo, hi]^2 and rescale each
        # score vector to length log1p(|score|) so long arrows stay readable.
        params = get_params(opt_state)
        grid = np.stack(np.meshgrid(np.linspace(lo, hi, 50),
                                    np.linspace(lo, hi, 50)),
                        axis=-1).reshape(-1, 2)
        scores = net_apply(params, grid)
        norms = np.linalg.norm(scores, axis=-1, ord=2, keepdims=True)
        return grid, scores / (norms + 1e-9) * np.log1p(norms)

    # Left panel: loss per step, titled with the mean of the last 32 entries.
    clear_output(True)
    plt.figure(figsize=[16, 8])
    plt.subplot(1, 2, 1)
    plt.title("mean loss = %.3f" % np.mean(np.array(loss_history[-32:])))
    plt.scatter(np.arange(len(loss_history)), loss_history)
    plt.grid()

    # Right panel: score field on [-1.5, 2.0]^2.
    plt.subplot(1, 2, 2)
    grid, arrows = _log1p_scaled_scores(-1.5, 2.0)
    clear_output(True)
    plt.quiver(*grid.T, *arrows.T, width=0.002, color='green')
    plt.xlim(-1.5, 2.0)
    plt.ylim(-1.5, 2.0)
    plt.show()

    # Large view on [-1.5, 1.5]^2 overlaid with points from sample_batch(10_000).
    print("displaying gradients...")
    plt.figure(figsize=[16, 16])
    grid, arrows = _log1p_scaled_scores(-1.5, 1.5)
    plt.quiver(*grid.T, *arrows.T, width=0.002, color='green')
    plt.scatter(*sample_batch(10_000).T, alpha=0.25)
    plt.show()
コード例 #15
0
ファイル: loss.py プロジェクト: google/objax
def mean_squared_log_error(
    y_true: JaxArray,
    y_pred: JaxArray,
    keep_axis: Optional[Iterable[int]] = (0, )) -> JaxArray:
    """Computes the mean squared logarithmic error between y_true and y_pred.

    The element-wise loss is ``(log1p(y_true) - log1p(y_pred)) ** 2``; it is then
    averaged over every axis that is not listed in ``keep_axis``. A warning is
    emitted (but computation proceeds) when the input shapes differ.

    Args:
        y_true: a tensor of shape (d0, .. dN-1).
        y_pred: a tensor of shape (d0, .. dN-1).
        keep_axis: a sequence of the dimensions to keep, use `None` to return a scalar value.
            Negative entries count from the last dimension, as in NumPy.

    Returns:
        tensor of shape (d_i, ..., for i in keep_axis) containing the mean squared logarithmic error.
    """

    FUNC_NAME = 'mean_squared_log_error'
    if y_true.shape != y_pred.shape:
        warnings.warn(' {} {} : arg1 {} and arg2 {}'.format(
            WARN_SHAPE_MISMATCH, FUNC_NAME, y_true.shape, y_pred.shape))
    loss = (jn.log1p(y_true) - jn.log1p(y_pred))**2
    # Materialize keep_axis once (a one-shot iterator would otherwise be consumed by
    # the first membership test) and normalize negative axes so that e.g. (-1,)
    # keeps the last axis instead of silently matching nothing.
    kept = {k + loss.ndim if k < 0 else k for k in (keep_axis or ())}
    axis = tuple(i for i in range(loss.ndim) if i not in kept)
    return loss.mean(axis)
コード例 #16
0
ファイル: jax.py プロジェクト: bmorris3/kelp
def I(alpha, Phi):
    """
    Equation 39

    Returns the tuple ``(I_S, I_L, I_C)``; all three share the arctanh-like
    term ``I_0`` built from ``z = sin(alpha/2 - Phi/2) / cos(Phi/2)``.
    """
    cos_alpha = jnp.cos(alpha)
    cos_alpha_2 = jnp.cos(alpha / 2)

    z = jnp.sin(alpha / 2 - Phi / 2) / jnp.cos(Phi / 2)

    # The following expression has the same behavior as I_0 = jnp.arctanh(z) for
    # |z| < 1 (and is 0 otherwise), but it doesn't blow up at alpha=0.
    # z is first replaced by 0 outside |z| < 1 ("double where"): otherwise the
    # masked branch evaluates log1p(-1) = -inf at |z| == 1, and its infinite
    # derivative turns gradients into NaN even though the value is discarded.
    inside = jnp.abs(z) < 1.0
    z_safe = jnp.where(inside, z, 0.0)
    I_0 = jnp.where(inside, 0.5 * (jnp.log1p(z_safe) - jnp.log1p(-z_safe)), 0)

    I_S = (-1 / (2 * cos_alpha_2) * (jnp.sin(alpha / 2 - Phi) +
                                     (cos_alpha - 1) * I_0))
    I_L = 1 / np.pi * (Phi * cos_alpha - 0.5 * jnp.sin(alpha - 2 * Phi))
    I_C = -1 / (24 * cos_alpha_2) * (
        -3 * jnp.sin(alpha / 2 - Phi) + jnp.sin(3 * alpha / 2 - 3 * Phi) +
        6 * jnp.sin(3 * alpha / 2 - Phi) - 6 * jnp.sin(alpha / 2 + Phi) +
        24 * jnp.sin(alpha / 2)**4 * I_0)

    return I_S, I_L, I_C
コード例 #17
0
ファイル: nonlinearities.py プロジェクト: jxzhangjhu/NuX
  def call(self,
           inputs: Mapping[str, jnp.ndarray],
           rng: jnp.ndarray=None,
           sample: Optional[bool]=False,
           **kwargs
  ) -> Mapping[str, jnp.ndarray]:
    """Inverse-softplus layer.

    When ``sample == False`` maps x -> z = log(exp(x) - 1) (the inverse of
    softplus); otherwise maps the input through ``jax.nn.softplus``. In both
    directions ``log_det`` is -sum(log(1 - exp(-x))) over the trailing event axes,
    evaluated at the data-space value x and broadcast to ``self.batch_shape``.
    """
    x_shape = self.get_unbatched_shapes(sample)["x"]
    sum_axes = util.last_axes(x_shape)

    if sample == False:
      x = inputs["x"]
      # Inverse softplus is only finite for x > 0: clamp non-positive inputs
      # (x == 0 would give z = -inf and an infinite log-det).
      x = jnp.where(x <= 0.0, 1e-5, x)
      # dx = log(1 - exp(-x)), so z = x + dx = log(exp(x) - 1).
      dx = jnp.log1p(-jnp.exp(-x))
      z = x + dx
      # log|dz/dx| = -log(1 - exp(-x)) = -dx.
      log_det = -dx.sum(axis=sum_axes)*jnp.ones(self.batch_shape)
      outputs = {"x": z, "log_det": log_det}
    else:
      x = jax.nn.softplus(inputs["x"])
      # Same log-det expression as the forward branch, at the produced x. The
      # exponent must be -x: for x > 0, log1p(-exp(x)) would be NaN.
      log_det = -jnp.log1p(-jnp.exp(-x)).sum(axis=sum_axes)*jnp.ones(self.batch_shape)
      outputs = {"x": x, "log_det": log_det}

    return outputs
コード例 #18
0
    def lbp_loop(_, alphabeta):
        # One message-passing sweep (presumably loopy belief propagation) on the
        # stacked (alpha, beta) log-messages; mu, groups, test_sign and gamma come
        # from the enclosing scope. The incoming alpha is not read: it is recomputed.
        beta = alphabeta[1, :, :]

        # alpha update from the column totals of beta
        alpha = jax.nn.log_sigmoid(np.sum(beta, axis=0) - beta + mu)
        alpha *= groups

        # beta update from the row totals of alpha
        row_totals = np.sum(alpha, axis=1, keepdims=True)
        beta = np.log1p(test_sign *
                        np.exp(-alpha + row_totals + gamma[:, np.newaxis]))
        beta *= groups
        return np.stack((alpha, beta), axis=0)
コード例 #19
0
def update_copula_single(logpmf1, log_v, y_new, logalpha, rho):
    """One copula update step of a Bernoulli log-pmf (all quantities in log space).

    Returns log(pmf_new) = log((1 - alpha) + alpha * k) + log(pmf), where k is the
    copula term for the observed y_new, floored at eps.
    """
    eps = 5e-5
    lo, hi = jnp.log(eps), jnp.log(1 - eps)
    # Clip u and v into [eps, 1 - eps] (in log space) before the copula term.
    logpmf1 = jnp.clip(logpmf1, lo, hi)
    log_v = jnp.clip(log_v, lo, hi)

    # log(1 - alpha) and log(1 - v), with 1 - exp(.) floored at eps.
    log1alpha = jnp.log1p(jnp.clip(-jnp.exp(logalpha), -1 + eps, jnp.inf))
    log1_v = jnp.log1p(jnp.clip(-jnp.exp(log_v), -1 + eps, jnp.inf))

    min_logu1v1 = jnp.min(jnp.array([logpmf1, log_v]))

    # Bernoulli update: copula term for y_new = 1 and for y_new = 0.
    ratio_one = jnp.exp(min_logu1v1 - logpmf1 - log_v)
    ratio_zero = 1 / jnp.exp(log1_v) - jnp.exp(min_logu1v1 - logpmf1 - log1_v)
    frac = y_new * ratio_one + (1 - y_new) * ratio_zero
    kyy_ = jnp.clip(1 - rho + rho * frac, eps, jnp.inf)

    return jnp.logaddexp(log1alpha, logalpha + jnp.log(kyy_)) + logpmf1
コード例 #20
0
def norm_logbicop_diag_approx(log_u,rho):
    """Approximate a bivariate-copula value on the diagonal, in log space.

    NOTE(review): by its name this is presumably log C(u, u; rho) for the
    Gaussian copula; `log_g_cop` is defined elsewhere — confirm.

    Args:
        log_u: log of the marginal value u; clipped to [log(1e-6), log(1 - 1e-6)].
        rho: copula correlation parameter (passed to arcsin, so |rho| <= 1).

    Returns:
        The approximation in log space.
    """
    eps = 1e-6
    log_u = jnp.clip(log_u,jnp.log(eps),jnp.log(1-eps))
    # 1 where u <= 0.5, else 0. x and y are passed positionally: recent JAX
    # releases make them positional-only in jnp.where.
    ind_true = jnp.where(log_u <= jnp.log(0.5), 1, 0)
    # Keep log(u) when u <= 0.5; otherwise replace it with log(1 - u), so the
    # approximation below is always evaluated at a point <= 0.5.
    log_u = ind_true*log_u + (1-ind_true)*jnp.log1p(-jnp.exp(log_u))

    u = jnp.exp(log_u)
    log_g = log_g_cop(u,rho) #for u<0.5
    log_interp = jnp.log((1+(rho/2)+(1/jnp.pi)*jnp.arcsin(rho)) + u*((2/jnp.pi)*jnp.arcsin(rho)- rho))
    logbicop = log_u + log_g +log_interp

    # For u > 0.5 add 2u - 1 (reflection C(u, u) = 2u - 1 + C(1 - u, 1 - u));
    # u currently holds 1 - u_original, so 2u_original - 1 == 1 - 2u.
    logbicop = jnp.log(ind_true*jnp.exp(logbicop)+ (1-ind_true)*((1-2*u)+jnp.exp(logbicop)))

    return logbicop
コード例 #21
0
    def log_prob(self, data, **kwargs):
        # Multivariate Student-t log density of each row of `data`. The
        # log-determinant uses only diag(scale), which presumes `scale` is a
        # triangular (Cholesky-style) factor — TODO confirm.
        loc, scale = self.loc, self.scale
        df, dim = self.df, self.dimension
        assert data.ndim == 2 and data.shape[1] == dim

        # Quadratic term: squared norm of the whitened residuals.
        whitened = np.linalg.solve(scale, (data - loc).T).T
        sq_dist = np.sum(whitened**2, axis=1)
        lp = -0.5 * (df + dim) * np.log1p(sq_dist / df)

        # Normalizing constant.
        lp += spsp.gammaln(0.5 * (df + dim)) - spsp.gammaln(0.5 * df)
        lp += -0.5 * dim * np.log(np.pi) - 0.5 * dim * np.log(df)
        lp += -np.sum(np.log(np.diag(scale)))
        return lp
コード例 #22
0
    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # compute stick-breaking logdet
        #   t1 -> t1
        #   t2 -> t2 * (1 - abs(t1))
        #   t3 -> t3 * (1 - abs(t1)) * (1 - abs(t2))
        # hence jacobian is triangular and logdet is the sum of the log
        # of the diagonal part of the jacobian
        one_minus_remainder = jnp.cumsum(jnp.abs(y[..., :-1]), axis=-1)
        eps = jnp.finfo(y.dtype).eps
        # Cap below 1 so log1p(-.) stays finite. The upper bound is passed
        # positionally: the `a_max=` keyword is deprecated in recent JAX.
        one_minus_remainder = jnp.clip(one_minus_remainder, None, 1 - eps)
        # log(remainder) = log1p(remainder - 1)
        stick_breaking_logdet = jnp.sum(jnp.log1p(-one_minus_remainder), axis=-1)

        # log(1 - tanh(x)^2) = -2 * log(cosh(x)) = -2 * (x + softplus(-2x) - log 2)
        tanh_logdet = -2 * jnp.sum(x + softplus(-2 * x) - jnp.log(2.0), axis=-1)
        return stick_breaking_logdet + tanh_logdet
コード例 #23
0
def _probs_and_log_probs(
    dist: Union[Bernoulli,
                tfd.Bernoulli]) -> Tuple[Array, Array, Array, Array]:
    """Calculates both `probs` and `log_probs`.

    Returns ``(probs0, probs1, log_probs0, log_probs1)`` for outcomes 0 and 1,
    computed from ``dist._logits`` when it is set, else from ``dist._probs``.
    """
    # pylint: disable=protected-access
    logits = dist._logits
    if logits is not None:
        # log(1 - sigmoid(l)) = -softplus(l); log(sigmoid(l)) = -softplus(-l).
        return (jax.nn.sigmoid(-1. * logits),
                jax.nn.sigmoid(logits),
                -jax.nn.softplus(logits),
                -jax.nn.softplus(-1. * logits))
    probs = dist._probs
    return (1 - probs, probs, jnp.log1p(-1. * probs), jnp.log(probs))
コード例 #24
0
ファイル: regressions.py プロジェクト: lindermanlab/jxf
    def log_prob(self, data, covariates):
        # Multivariate Student-t log density of `data` around the linear prediction
        # `covariates @ weights.T`. Only diag(scale) enters the determinant term,
        # which presumes `scale` is a triangular factor — TODO confirm.
        df, scale = self.df, self.scale
        dim = self.weights.shape[-2]
        mean = covariates @ self.weights.T

        # Quadratic term on the whitened residuals.
        whitened = np.linalg.solve(scale, (data - mean).T).T
        lp = -0.5 * (df + dim) * np.log1p(np.sum(whitened**2, axis=1) / df)

        # Normalizer. The diagonal of each trailing (dim, dim) block of `scale` is
        # read as every (dim + 1)-th entry of its flattened form.
        lp += spsp.gammaln(0.5 * (df + dim)) - spsp.gammaln(0.5 * df)
        lp += -0.5 * dim * np.log(np.pi) - 0.5 * dim * np.log(df)
        diag = np.reshape(scale, scale.shape[:-2] + (-1, ))[..., ::dim + 1]
        lp += -np.sum(np.log(diag), axis=-1).reshape(scale.shape[:-2])
        return lp
コード例 #25
0
def _binomial_inversion(key, p, n):
    def _binom_inv_body_fn(val):
        i, key, geom_acc = val
        key, key_u = random.split(key)
        u = random.uniform(key_u)
        geom = np.floor(np.log1p(-u) / log1_p) + 1
        geom_acc = geom_acc + geom
        return i + 1, key, geom_acc

    def _binom_inv_cond_fn(val):
        i, _, geom_acc = val
        return geom_acc <= n

    log1_p = np.log1p(-p)
    ret = lax.while_loop(_binom_inv_cond_fn, _binom_inv_body_fn, (-1, key, 0.))
    return ret[0]
コード例 #26
0
def _get_tr_params(n, p):
    # Constants of Table 1 (they match Hormann's BTRS transformed-rejection
    # binomial sampler). log(p), log1p(-p) and the part of log(f(k)) that depends
    # only on (n, p, m) (bottom of page 5) are precomputed here as well.
    mu = n * p
    spq = np.sqrt(mu * (1 - p))
    c = mu + 0.5
    b = 1.15 + 2.53 * spq
    a = -0.0873 + 0.0248 * b + 0.01 * p
    alpha = (2.83 + 5.1 / b) * spq
    u_r, v_r = 0.43, 0.92 - 4.2 / b
    # Mode of the binomial, floor((n + 1) p), in the dtype of n.
    m = np.floor((n + 1) * p).astype(n.dtype)
    log_p, log1_p = np.log(p), np.log1p(-p)
    log_ratio = np.log((m + 1.) / (n - m + 1.)) + log1_p - log_p
    stirling = stirling_approx_tail(m) + stirling_approx_tail(n - m)
    log_h = (m + 0.5) * log_ratio + stirling
    return _tr_params(c, b, a, alpha, u_r, v_r, m, log_p, log1_p, log_h)
コード例 #27
0
def compute_logalpha_x_single(x_plot, xn, rho_x):
    """Log of the per-observation weights alpha_x for a single evaluation point.

    With alpha_i = (2 - 1/i) / (i + 1) and k_i the value returned by
    ``mvcr.calc_logkxx_test`` (presumably a copula log-density), returns
    log(alpha_i k_i / ((1 - alpha_i) + alpha_i k_i)) for i = 1..n.
    """
    n = jnp.shape(xn)[0]
    d = jnp.shape(xn)[1]

    steps = jnp.arange(n) + 1
    logalpha_seq = (jnp.log(2. - (1 / (steps))) - jnp.log(steps + 1))
    log1alpha_seq = jnp.log1p(-jnp.exp(logalpha_seq))

    logk_xx = mvcr.calc_logkxx_test(x_plot.reshape(1, d), xn.reshape(-1, d),
                                    rho_x)[:, 0]

    log_weighted = logalpha_seq + logk_xx
    return log_weighted - jnp.logaddexp(log1alpha_seq, log_weighted)
コード例 #28
0
    def log_prob(self, value):
        """Log density of ``value`` under a Gaussian with precision tau * (D - alpha * W).

        Terms that do not depend on tau or alpha (log det D, 2*pi) are omitted.
        Presumably a proper CAR prior; ``eigenvals`` are assumed to be the
        eigenvalues that make sum(log1p(-alpha * eigenvals)) the alpha-dependent
        part of the log-determinant — TODO confirm.
        """
        n = value.shape[0]
        src, dst = self.W_sparse[:, 0], self.W_sparse[:, 1]
        # D is stored as its diagonal.
        phi_D = value * self.D_sparse
        # W @ value for the symmetric adjacency stored as an edge list with each
        # (i, j) pair listed once: add value[j] into row i and value[i] into row j.
        # `.at[].add` accumulates repeated indices and replaces the removed
        # `jax.ops.index_add`.
        phi_W = jnp.zeros(n).at[src].add(value[dst]).at[dst].add(value[src])
        ldet_terms = jnp.log1p(-self.alpha * self.eigenvals)

        return 0.5 * (
            n * jnp.log(self.tau)
            + jnp.sum(ldet_terms)
            - self.tau * (phi_D @ value - self.alpha * (phi_W @ value))
        )
コード例 #29
0
ファイル: utils.py プロジェクト: zeta1999/optax
def weibull_min(key, scale, concentration, shape=(), dtype=jnp.float32):
  """Draw samples from a Weibull distribution by inverse-CDF sampling.

  The scipy counterpart is `scipy.stats.weibull_min`.

  Args:
    key: a PRNGKey key.
    scale: The scale parameter of the distribution.
    concentration: The concentration (shape, k) parameter of the distribution.
    shape: The shape added to the parameters loc and scale broadcastable shape.
    dtype: The type used for samples.

  Returns:
    A jnp.array of samples.
  """
  u = jax.random.uniform(key=key, shape=shape, minval=0, maxval=1, dtype=dtype)
  # CDF is 1 - exp(-(x / scale)^k), so x = scale * (-log(1 - u))^(1/k);
  # log1p(-u) stays finite because u < 1.
  exponential = -jnp.log1p(-u)
  return jnp.power(exponential, 1.0/concentration) * scale
コード例 #30
0
def _load_data(num_seasons: int = 100,
               batch: int = 1,
               x_dim: int = 1) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Load sequential data with seasonality and trend.

    Returns ``(x, t)``. ``x`` has shape ``(7 * num_seasons, batch, x_dim)``.
    ``t`` is a sine covariate spanning three periods, with shape
    ``(7 * num_seasons, batch, 1)``.
    """
    num_steps = 7 * num_seasons

    # The time grid must have exactly as many points as x has steps, so it is
    # built from an integer range rather than a float-step arange.
    t = jnp.sin(jnp.arange(num_steps) * (6 * jnp.pi / num_steps))[:, None, None]
    # Share the covariate across the batch so x and t agree on their first two axes.
    t = jnp.broadcast_to(t, (num_steps, batch, 1))

    x = dist.Poisson(100).sample(random.PRNGKey(1234),
                                 (num_steps, batch, x_dim))
    # Upward trend from a cumulative sum of uniforms (NumPy's global RNG, unseeded here).
    x += jnp.array(np.random.rand(num_steps).cumsum(0)[:, None, None])
    # Repeating 7-step pattern: five high steps followed by two low ones.
    x += jnp.array(([50] * 5 + [1] * 2) * num_seasons)[:, None, None]
    x = jnp.log1p(x)
    x += t * 2

    assert isinstance(x, jnp.ndarray)
    assert isinstance(t, jnp.ndarray)
    assert x.shape[0] == t.shape[0]
    assert x.shape[1] == t.shape[1]

    return x, t