def test_cond_multiple_results(self):
    """Run ``self.ConvertAndCompare`` on a ``lax.cond`` whose branches return tuples.

    The true branch yields ``(x + 1., 1.)`` and the false branch yields
    ``(x + 2., 2.)``. The function is exercised with a ``True`` predicate and
    then with a ``False`` predicate, each time with the operand ``1.``.
    """

    def on_true(t):
        return t + 1., 1.

    def on_false(f):
        return f + 2., 2.

    def f_jax(pred, x):
        return lax.cond(pred, on_true, on_false, x)

    # Same order as before: the true branch first, then the false branch.
    for flag in (True, False):
        self.ConvertAndCompare(f_jax, jnp.bool_(flag), 1.)
def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
    """Override ``scores`` with a one-hot EOS row on the step ``cur_len == self.max_length - 1``.

    On that step, every entry of the result is ``-inf`` except column
    ``self.eos_token_id``, which is 0. On any other step, ``scores`` is
    returned unchanged.

    Args:
        input_ids: Not read by this method.
        scores: Indexed as ``[:, token_id]``; presumably (batch, vocab) —
            confirm against callers.
        cur_len: Compared against ``self.max_length - 1``; may be a Python
            int or a JAX scalar.

    Returns:
        The chosen array, selected element-wise by ``jnp.where``.
    """
    # NOTE(review): jnp.full is not given scores.dtype, so the -inf template
    # uses JAX's default float dtype and jnp.where may promote the result.
    forced = jnp.full(scores.shape, -float("inf")).at[:, self.eos_token_id].set(0)
    # Same truth value as `1 - jnp.bool_(cur_len - self.max_length + 1)`:
    # true exactly when cur_len == self.max_length - 1.
    on_final_step = cur_len == self.max_length - 1
    return jnp.where(on_final_step, forced, scores)
def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
    """Override ``scores`` with a one-hot BOS row on the step ``cur_len == 1``.

    On that step, every entry of the result is ``-inf`` except column
    ``self.bos_token_id``, which is 0. For any other ``cur_len``, ``scores``
    is returned unchanged.

    Args:
        input_ids: Not read by this method.
        scores: Indexed as ``[:, token_id]``; presumably (batch, vocab) —
            confirm against callers.
        cur_len: Compared against 1; may be a Python int or a JAX scalar
            (``jnp.bool_`` accepts either).

    Returns:
        The chosen array, selected element-wise by ``jnp.where``.
    """
    new_scores = jnp.full(scores.shape, -float("inf"))
    # jnp.bool_(cur_len - 1) is False only when cur_len == 1, so this is 1 on
    # that step and 0 otherwise. It is an array value (not a Python `if`), and
    # jnp.where makes the selection.
    apply_penalty = 1 - jnp.bool_(cur_len - 1)
    # Functional update via `.at[...].set(...)`. The former
    # jax.ops.index_update / jax.ops.index API has been removed from JAX.
    forced = new_scores.at[:, self.bos_token_id].set(0)
    scores = jnp.where(apply_penalty, forced, scores)
    return scores
def test_cond(self):
    """Run ``self.ConvertAndCompare`` on a ``lax.cond`` where only the true branch adds 1.

    The false branch returns its operand unchanged. The function is
    exercised with a ``True`` predicate and then with a ``False`` predicate,
    each time with the operand ``1.``.
    """

    def f_jax(pred, x):
        return lax.cond(pred, lambda operand: operand + 1., lambda operand: operand, x)

    # Same order as before: the true branch first, then the false branch.
    for flag in (True, False):
        self.ConvertAndCompare(f_jax, jnp.bool_(flag), 1.)