def rloss(loss_func: Callable, input_rref: rpc.RRef, target_rref: rpc.RRef) -> rpc.RRef:
    """Fetch the values behind two RRefs and apply ``loss_func`` to them.

    When the module-level flag ``BOUNCE_TENSORS`` is set, each value is first
    turned into a CPU copy on its owning worker (``rref.remote().cpu()``) and
    that copy is fetched. Otherwise the value is fetched directly with
    ``to_here()``.

    Args:
        loss_func: Callable invoked as ``loss_func(input_value, target_value)``.
        input_rref: RRef whose value is passed as the first argument.
        target_rref: RRef whose value is passed as the second argument.

    Returns:
        Whatever ``loss_func`` returns.

    NOTE(review): the return annotation says ``rpc.RRef``, but this function
    returns ``loss_func``'s result itself. Presumably callers invoke it via
    ``rpc.remote`` (which wraps the result in an RRef) — confirm at call sites.
    """
    if BOUNCE_TENSORS:
        # ``RRef.remote().cpu()`` launches an asynchronous remote call and
        # returns an RRef right away. Issuing both before blocking on either
        # lets the two round trips overlap instead of running back to back.
        input_source = input_rref.remote().cpu()
        target_source = target_rref.remote().cpu()
    else:
        input_source, target_source = input_rref, target_rref
    return loss_func(input_source.to_here(), target_source.to_here())
def _rloss(loss_func: Callable, input_rref: rpc.RRef, target_rref: rpc.RRef) -> rpc.RRef: return loss_func(input_rref.to_here(), target_rref.to_here())
def forward(self, x_rref: rpc.RRef) -> Tensor:  # type: ignore
    """Fetch the value behind ``x_rref`` and move it to ``self.device``.

    When the module-level flag ``BOUNCE_TENSORS`` is set, the owner first
    produces a CPU copy via ``x_rref.remote().cpu()``, and that copy is what
    gets fetched. Otherwise ``x_rref`` itself is fetched.

    Args:
        x_rref: RRef to the incoming value.

    Returns:
        The fetched value after ``.to(self.device)``.
    """
    source = x_rref.remote().cpu() if BOUNCE_TENSORS else x_rref
    return source.to_here().to(self.device)
def _parameter_rrefs(module: rpc.RRef) -> List[rpc.RRef]: return [rpc.RRef(p) for p in module.to_here().parameters()]
def forward(self, x_rref: rpc.RRef) -> Tensor:  # type: ignore
    """Fetch and return the value behind ``x_rref``.

    Args:
        x_rref: RRef to the incoming value.

    Returns:
        The result of ``x_rref.to_here()``, with no further processing.
    """
    fetched = x_rref.to_here()
    return fetched