def test_owner_equality(self):
    """Check that ``RRef.owner()`` results compare and hash consistently.

    Two RRefs are built directly on this worker and two more are created
    through ``rpc.remote`` on the next rank (wrapping around
    ``self.world_size``). The assertions verify that owners of RRefs
    living on the same worker are equal, that owners on different workers
    are not, and that owners can be used interchangeably as dict keys.
    """
    local_first = RRef(40)
    local_second = RRef(50)

    peer_rank = (self.rank + 1) % self.world_size
    peer_name = "worker{}".format(peer_rank)

    def _remote_add():
        # Each call issues its own remote ``torch.add`` on the peer worker.
        return rpc.remote(peer_name, torch.add, args=(torch.ones(1), 1))

    peer_first = _remote_add()
    peer_second = _remote_add()

    # to ensure clean termination
    peer_first.to_here()
    peer_second.to_here()

    # An owner must never compare equal to an unrelated object.
    self.assertNotEqual(local_first.owner(), 23)

    # Owners on the peer worker match each other but not the local owner.
    self.assertEqual(peer_first.owner(), peer_second.owner())
    self.assertNotEqual(local_first.owner(), peer_first.owner())
    self.assertEqual(peer_first.owner(), peer_first.owner())
    self.assertEqual(peer_first.owner(), peer_second.owner())

    # Owners of the locally built RRefs match each other and this worker.
    self.assertEqual(local_first.owner(), local_first.owner())
    self.assertEqual(local_first.owner(), local_second.owner())
    self.assertEqual(local_first.owner(), rpc.get_worker_info())

    # Equal owners must land on the same dict key.
    by_owner = {
        local_first.owner(): local_first,
        peer_first.owner(): peer_first,
    }
    self.assertEqual(by_owner[local_first.owner()], local_first)
    self.assertEqual(by_owner[local_second.owner()], local_first)
    self.assertEqual(by_owner[peer_first.owner()], peer_first)
    self.assertEqual(by_owner[peer_second.owner()], peer_first)
    self.assertEqual(len(by_owner), 2)
def dloss(input_rref: rpc.RRef, target_rref: rpc.RRef) -> rpc.RRef:
    """Run ``_rloss`` on the worker that owns ``input_rref``.

    Args:
        input_rref: RRef whose owner is used as the destination worker.
        target_rref: RRef forwarded to ``_rloss`` together with ``input_rref``.

    Returns:
        The RRef produced by ``rpc.remote`` for the ``_rloss`` invocation.
    """
    # NOTE(review): ``loss_func`` and ``_rloss`` are defined outside this
    # block; presumably ``_rloss`` applies ``loss_func`` to the two RRefs'
    # values -- confirm against its definition.
    destination = input_rref.owner()
    remote_args = (loss_func, input_rref, target_rref)
    return rpc.remote(destination, _rloss, args=remote_args)