예제 #1
0
    def test_cond_multiple_results(self):
        def f_jax(pred, x):
            return lax.cond(pred, lambda t: (t + 1., 1.), lambda f:
                            (f + 2., 2.), x)

        self.ConvertAndCompare(f_jax, jnp.bool_(True), 1.)
        self.ConvertAndCompare(f_jax, jnp.bool_(False), 1.)
예제 #2
0
    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray,
                 cur_len: int) -> jnp.ndarray:
        """Force ``self.eos_token_id`` when ``cur_len == self.max_length - 1``.

        At that step every score is replaced by ``-inf`` except column
        ``self.eos_token_id``, which is set to 0. That token is then the only
        finite choice. For any other ``cur_len``, ``scores`` is returned unchanged.

        Args:
            input_ids: Not read by this method; kept for call-signature
                compatibility.
            scores: Floating-point scores indexed as ``scores[:, token_id]``
                (presumably ``(batch, vocab)`` — confirm against callers).
            cur_len: Current length counter (annotated as ``int``).

        Returns:
            An array with the same shape and dtype as ``scores``.
        """
        # Build the mask in scores' own dtype. A default-dtype mask would make
        # jnp.where promote low-precision (e.g. bfloat16) scores to float32.
        forced_scores = jnp.full(scores.shape, -jnp.inf, dtype=scores.dtype)
        forced_scores = forced_scores.at[:, self.eos_token_id].set(0)

        # Direct form of ``1 - jnp.bool_(cur_len - self.max_length + 1)``:
        # true exactly when cur_len - self.max_length + 1 == 0.
        apply_penalty = cur_len == self.max_length - 1

        return jnp.where(apply_penalty, forced_scores, scores)
예제 #3
0
    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
        """Force ``self.bos_token_id`` when ``cur_len == 1``.

        At that step every score is replaced by ``-inf`` except column
        ``self.bos_token_id``, which is set to 0. That token is then the only
        finite choice. For any other ``cur_len``, ``scores`` is returned unchanged.

        Args:
            input_ids: Not read by this method; kept for call-signature
                compatibility.
            scores: Floating-point scores indexed as ``scores[:, token_id]``
                (presumably ``(batch, vocab)`` — confirm against callers).
            cur_len: Current length counter (annotated as ``int``).

        Returns:
            An array with the same shape and dtype as ``scores``.
        """
        # ``jax.ops.index_update`` / ``jax.ops.index`` were removed from JAX;
        # ``.at[...].set(...)`` is the supported functional-update form.
        # Build the mask in scores' own dtype. A default-dtype mask would make
        # jnp.where promote low-precision (e.g. bfloat16) scores to float32.
        forced_scores = jnp.full(scores.shape, -jnp.inf, dtype=scores.dtype)
        forced_scores = forced_scores.at[:, self.bos_token_id].set(0)

        # Direct form of ``1 - jnp.bool_(cur_len - 1)``: true only for cur_len == 1.
        apply_penalty = cur_len == 1

        return jnp.where(apply_penalty, forced_scores, scores)
예제 #4
0
    def test_cond(self):
        def f_jax(pred, x):
            return lax.cond(pred, lambda t: t + 1., lambda f: f, x)

        self.ConvertAndCompare(f_jax, jnp.bool_(True), 1.)
        self.ConvertAndCompare(f_jax, jnp.bool_(False), 1.)