def rloss(loss_func: Callable, input_rref: rpc.RRef,
          target_rref: rpc.RRef) -> rpc.RRef:
    """Apply ``loss_func`` to the values referenced by two RRefs.

    Both referenced values are fetched to the calling worker, and then
    ``loss_func(input_value, target_value)`` is evaluated locally.

    When the module-level flag ``BOUNCE_TENSORS`` is truthy, each value is
    first moved with ``.cpu()`` on its owning worker (through
    ``RRef.remote()``), and the resulting CPU copy is fetched. Otherwise the
    value is fetched as-is with ``to_here()``.

    Args:
        loss_func: Called with ``(input_value, target_value)``.
        input_rref: Reference to the model output / prediction.
        target_rref: Reference to the target / label value.

    Returns:
        Whatever ``loss_func`` returns, unwrapped.

    NOTE(review): the ``rpc.RRef`` return annotation does not match the
    body, which returns ``loss_func``'s result directly rather than an RRef.
    Confirm the intended type and fix the annotation.
    """
    bounce = BOUNCE_TENSORS

    def fetch(ref: rpc.RRef):
        # Presumably the CPU hop exists so that only host-memory tensors
        # travel over RPC. TODO confirm the motivation for BOUNCE_TENSORS.
        if bounce:
            return ref.remote().cpu().to_here()
        return ref.to_here()

    return loss_func(fetch(input_rref), fetch(target_rref))
 def forward(self, x_rref: rpc.RRef) -> Tensor:  # type: ignore
     """Fetch the value behind ``x_rref`` and move it onto ``self.device``.

     When the module-level flag ``BOUNCE_TENSORS`` is truthy, the owning
     worker is first asked to call ``.cpu()`` on the value (through
     ``RRef.remote()``), and that CPU copy is fetched. Otherwise the value
     is fetched directly with ``to_here()``. In both cases, the fetched
     result is then passed through ``.to(self.device)``.

     Args:
         x_rref: Remote reference to the input value. Assumed to be a
             tensor, given the ``Tensor`` return annotation; TODO confirm.

     Returns:
         The fetched value, moved to ``self.device``. ``self.device`` is
         assumed to be set on the instance elsewhere.
     """
     local_value = (x_rref.remote().cpu().to_here() if BOUNCE_TENSORS
                    else x_rref.to_here())
     return local_value.to(self.device)