def test_squeeze(self):
    """Test ArithmeticSharedTensor.unsqueeze / squeeze against torch, with benchmarks.

    For each ``dim`` in 0, 1, 2:

    * ``unsqueeze(dim)`` on an encrypted copy of a random float tensor is
      benchmarked and compared with ``tensor.unsqueeze(dim)``.
    * ``squeeze()`` on an encrypted copy of ``tensor.unsqueeze(0)`` is
      benchmarked and compared with ``reference.squeeze()``.
    * The squeezed output is then partially overwritten, and ``encrypted`` is
      squeezed again. The overwrite must be visible there, which checks that
      the squeezed result and its source refer to the same data.
    """
    tensor = get_random_test_tensor(is_float=True)
    for dim in [0, 1, 2]:
        # Test unsqueeze
        reference = tensor.unsqueeze(dim)
        encrypted = ArithmeticSharedTensor(tensor)
        with self.benchmark(type="unsqueeze", dim=dim) as bench:
            for _ in bench.iters:
                encrypted_out = encrypted.unsqueeze(dim)
        self._check(encrypted_out, reference, "unsqueeze failed")

        # Test squeeze
        # NOTE(review): the input is always tensor.unsqueeze(0) and squeeze()
        # is called without a dim, so this half does not vary with `dim`
        # (only the benchmark label does). Comparing with reference.squeeze()
        # presumably relies on the random tensor having no size-1 dims;
        # confirm against get_random_test_tensor.
        encrypted = ArithmeticSharedTensor(tensor.unsqueeze(0))
        with self.benchmark(type="squeeze", dim=dim) as bench:
            for _ in bench.iters:
                encrypted_out = encrypted.squeeze()
        self._check(encrypted_out, reference.squeeze(), "squeeze failed")

        # Check that the encrypted_out and encrypted point to the same
        # thing: write into the squeezed result, then re-squeeze the source.
        # Use the torch.tensor factory rather than the legacy
        # torch.FloatTensor type constructor (same float32 values).
        encrypted_out[0:2] = torch.tensor([0.0, 1.0], dtype=torch.float)
        ref = encrypted.squeeze().get_plain_text()
        self._check(encrypted_out, ref, "squeeze failed")
def test_squeeze(self):
    """Verify ArithmeticSharedTensor.unsqueeze and .squeeze match torch.

    For axes 0, 1 and 2:

    * an encrypted random float tensor is unsqueezed and compared with the
      plaintext ``unsqueeze`` result;
    * an encrypted ``plain.unsqueeze(0)`` is squeezed and compared with the
      squeezed plaintext expectation;
    * the squeezed output is overwritten in its first two rows and compared
      with a fresh squeeze of its source, checking that the overwrite is
      visible through the source.
    """
    plain = get_random_test_tensor(is_float=True)
    for axis in range(3):
        # unsqueeze along `axis` must agree with torch
        expected = plain.unsqueeze(axis)
        shared = ArithmeticSharedTensor(plain)
        unsqueezed = shared.unsqueeze(axis)
        self._check(unsqueezed, expected, "unsqueeze failed")

        # squeeze of a tensor carrying an extra leading singleton dim
        shared = ArithmeticSharedTensor(plain.unsqueeze(0))
        squeezed = shared.squeeze()
        self._check(squeezed, expected.squeeze(), "squeeze failed")

        # an in-place write to the squeezed view should show up when the
        # source is squeezed again
        patch = torch.tensor([0.0, 1.0], dtype=torch.float)
        squeezed[0:2] = patch
        round_trip = shared.squeeze().get_plain_text()
        self._check(squeezed, round_trip, "squeeze failed")