def test_plus_is_minus_variable_local(self):
    x = sy.Variable(torch.FloatTensor([5, 6]))
    y = sy.Variable(torch.FloatTensor([3, 4]))
    x = sy._PlusIsMinusTensor().on(x)
    y = sy._PlusIsMinusTensor().on(y)

    display = 'Variable > _PlusIsMinusTensor > _LocalTensor\n' \
              ' - FloatTensor > _PlusIsMinusTensor > _LocalTensor\n' \
              ' - - Variable > _PlusIsMinusTensor > _LocalTensor\n' \
              ' - FloatTensor > _PlusIsMinusTensor > _LocalTensor'
    assert torch_utils.chain_print(x, display=False) == display

    z = x.add(y)
    assert torch_utils.chain_print(z, display=False) == 'Variable > _PlusIsMinusTensor > ' \
                                                        '_LocalTensor\n - FloatTensor >' \
                                                        ' _PlusIsMinusTensor > _LocalTensor'

    # cut chain for the equality check
    z.data.child = z.data.child.child
    assert torch.equal(z.data, torch.FloatTensor([2, 2]))

    z = torch.add(x, y)

    # cut chain for the equality check
    z.data.child = z.data.child.child
    assert torch.equal(z.data, torch.FloatTensor([2, 2]))
def test_plus_is_minus_tensor_local(self):
    x = torch.FloatTensor([5, 6])
    y = torch.FloatTensor([3, 4])
    x = sy._PlusIsMinusTensor().on(x)
    y = sy._PlusIsMinusTensor().on(y)

    assert torch_utils.chain_print(
        x, display=False) == 'FloatTensor > _PlusIsMinusTensor > _LocalTensor'

    z = x.add(y)
    assert torch_utils.chain_print(
        z, display=False) == 'FloatTensor > _PlusIsMinusTensor > _LocalTensor'

    # cut chain for the equality check
    z.child = z.child.child
    assert torch.equal(z, torch.FloatTensor([2, 2]))

    z = torch.add(x, y)

    # cut chain for the equality check
    z.child = z.child.child
    assert torch.equal(z, torch.FloatTensor([2, 2]))
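# Note (illustrative addition, not part of the original suite): the [2, 2] expectations
# in these tests rely on _PlusIsMinusTensor overriding add() to perform elementwise
# subtraction. A minimal sketch of that arithmetic with plain torch; the test name
# below is hypothetical and exists only to document the assumption:
def test_plus_is_minus_expected_arithmetic(self):
    # [5, 6] "plus" [3, 4] under the plus-is-minus convention is [5 - 3, 6 - 4] = [2, 2]
    expected = torch.FloatTensor([5, 6]) - torch.FloatTensor([3, 4])
    assert torch.equal(expected, torch.FloatTensor([2, 2]))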
def test_plus_is_minus_tensor_remote(self):
    x = torch.FloatTensor([5, 6])
    y = torch.FloatTensor([3, 4])
    x = sy._PlusIsMinusTensor().on(x)
    y = sy._PlusIsMinusTensor().on(y)

    id1 = random.randint(0, 10e10)
    id2 = random.randint(0, 10e10)
    x.send(bob, ptr_id=id1)
    y.send(bob, ptr_id=id2)

    z = x.add(y)
    assert torch_utils.chain_print(
        z, display=False) == 'FloatTensor > _PointerTensor'

    # Check chain on remote
    ptr_id = z.child.id_at_location
    assert torch_utils.chain_print(
        bob._objects[ptr_id].parent,
        display=False) == 'FloatTensor > _PlusIsMinusTensor > _LocalTensor'

    z.get()
    assert torch_utils.chain_print(
        z, display=False) == 'FloatTensor > _PlusIsMinusTensor > _LocalTensor'

    # cut chain for the equality check
    z.child = z.child.child
    assert torch.equal(z, torch.FloatTensor([2, 2]))
def test_plus_is_minus_variable_remote(self):
    x = sy.Variable(torch.FloatTensor([5, 6]))
    y = sy.Variable(torch.FloatTensor([3, 4]))
    x = sy._PlusIsMinusTensor().on(x)
    y = sy._PlusIsMinusTensor().on(y)

    id1 = random.randint(0, 10e10)
    id2 = random.randint(0, 10e10)
    id11 = random.randint(0, 10e10)
    id21 = random.randint(0, 10e10)
    x.send(bob, new_id=id1, new_data_id=id11)
    y.send(bob, new_id=id2, new_data_id=id21)

    z = x.add(y)
    assert torch_utils.chain_print(z, display=False) == 'Variable > _PointerTensor\n' \
                                                        ' - FloatTensor > _PointerTensor\n' \
                                                        ' - - Variable > _PointerTensor\n' \
                                                        ' - FloatTensor > _PointerTensor'
    assert bob._objects[z.id_at_location].owner.id == 'bob'
    assert bob._objects[z.data.id_at_location].owner.id == 'bob'

    # Check chain on remote for the sent variable x
    ptr_id = x.child.id_at_location
    display = 'Variable > _PlusIsMinusTensor > _LocalTensor\n' \
              ' - FloatTensor > _PlusIsMinusTensor > _LocalTensor\n' \
              ' - - Variable > _PlusIsMinusTensor > _LocalTensor\n' \
              ' - FloatTensor > _PlusIsMinusTensor > _LocalTensor'
    assert torch_utils.chain_print(bob._objects[ptr_id].parent, display=False) == display

    # Check chain on remote for the result z
    # TODO: for now we don't reconstruct the grad chain on non-leaf variables (in our case,
    # a leaf variable is a variable that we sent), because we don't care about their
    # gradients. If we ever do, this needs to be revisited.
    ptr_id = z.child.id_at_location
    display = 'Variable > _PlusIsMinusTensor > _LocalTensor\n' \
              ' - FloatTensor > _PlusIsMinusTensor > _LocalTensor\n' \
              ' - - Variable > _LocalTensor\n' \
              ' - FloatTensor > _LocalTensor'
    assert torch_utils.chain_print(bob._objects[ptr_id].parent, display=False) == display

    z.get()
    display = 'Variable > _PlusIsMinusTensor > _LocalTensor\n' \
              ' - FloatTensor > _PlusIsMinusTensor > _LocalTensor\n' \
              ' - - Variable > _LocalTensor\n' \
              ' - FloatTensor > _LocalTensor'
    assert torch_utils.chain_print(z, display=False) == display

    # cut chain for the equality check
    z.data.child = z.data.child.child
    assert torch.equal(z.data, torch.FloatTensor([2, 2]))