def _batch_to(self, batch: Tensor) -> Tensor:
     """Cast a floating-point batch according to ``self.precision``.

     Args:
         batch: Tensor to convert.

     Returns:
         ``batch.half()`` when ``self.precision`` equals
         ``PrecisionType.HALF``, ``batch.bfloat16()`` when it equals
         ``PrecisionType.BFLOAT``, and ``batch`` itself in every other
         case, including any tensor that is not floating point.
     """
     # Non-floating tensors are passed through without consulting the
     # configured precision at all.
     if not torch.is_floating_point(batch):
         return batch

     target = self.precision
     if target == PrecisionType.HALF:
         return batch.half()
     if target == PrecisionType.BFLOAT:
         return batch.bfloat16()

     # Any other precision setting leaves the batch untouched.
     return batch