Example #1
0
 def cond_fn(position, ids, *unused_states):
   """Return a boolean that is true while the loop should keep iterating.

   An entry counts as finished when `position` has reached
   `length_dim.size`, or, if `stop_at_token` is set, when `ids` contains
   `stop_at_token` anywhere along `length_dim`. The loop continues until
   every entry is finished.

   Args:
     position: current position, compared against `length_dim.size`.
     ids: token ids, compared against `stop_at_token` along `length_dim`.
       NOTE(review): presumably an mtf.Tensor with `length_dim` among its
       dimensions - confirm against the caller.
     *unused_states: extra loop variables; ignored here.

   Returns:
     The logical negation of "all entries are finished".
   """
   finished = mtf.greater_equal(position, length_dim.size)
   if stop_at_token is not None:
     # An entry is also finished once the stop token shows up in its ids.
     matches_stop = mtf.equal(ids, stop_at_token)
     found_stop = mtf.reduce_any(matches_stop, reduced_dim=length_dim)
     finished = mtf.logical_or(finished, found_stop)
   return mtf.logical_not(mtf.reduce_all(finished))
Example #2
0
    def cond_fn(position, ids, *unused_states):
        """Return a boolean that is true while the loop should keep iterating.

        An entry counts as finished when any of the following holds:
          * `position` has reached `length_dim.size`;
          * `max_steps` is truthy and `position - initial_position` has
            reached `max_steps`;
          * `stop_at_token` is set and `ids` holds more occurrences of it
            along `length_dim` than `partial_sequences_eos_count`.

        Args:
            position: current position, compared against `length_dim.size`.
            ids: token ids, scanned for `stop_at_token` along `length_dim`.
                NOTE(review): presumably an mtf.Tensor with `length_dim`
                among its dimensions - confirm against the caller.
            *unused_states: extra loop variables; ignored here.

        Returns:
            The logical negation of "all entries are finished".
        """
        finished = mtf.greater_equal(position, length_dim.size)
        if max_steps:
            # Also stop once the step budget measured from the starting
            # position has been used up.
            steps_taken = position - initial_position
            budget_spent = mtf.greater_equal(steps_taken, max_steps)
            finished = mtf.logical_or(finished, budget_spent)

        if stop_at_token is not None:
            # Count stop tokens instead of testing for presence: only a count
            # above `partial_sequences_eos_count` marks the entry as finished.
            is_stop = mtf.to_int32(mtf.equal(ids, stop_at_token))
            stop_total = mtf.reduce_sum(is_stop, reduced_dim=length_dim)
            new_stop_seen = mtf.greater(stop_total, partial_sequences_eos_count)
            finished = mtf.logical_or(finished, new_stop_seen)

        return mtf.logical_not(mtf.reduce_all(finished))