Ejemplo n.º 1
0
    def test_squeeze(self):
        """Test and benchmark unsqueeze/squeeze on ArithmeticSharedTensor.

        For each ``dim`` in [0, 1, 2], a size-1 dimension is inserted at
        ``dim`` with ``unsqueeze``; the tensor carrying that singleton is then
        squeezed back. Both results are compared against the plaintext torch
        reference. Finally the test checks that the squeezed output shares
        its data with the tensor it was squeezed from.
        """
        tensor = get_random_test_tensor(is_float=True)

        for dim in [0, 1, 2]:
            # Plaintext reference with a singleton dimension inserted at `dim`.
            reference = tensor.unsqueeze(dim)

            # Test unsqueeze
            encrypted = ArithmeticSharedTensor(tensor)
            with self.benchmark(type="unsqueeze", dim=dim) as bench:
                for _ in bench.iters:
                    encrypted_out = encrypted.unsqueeze(dim)
            self._check(encrypted_out, reference, "unsqueeze failed")

            # Test squeeze on the tensor whose singleton sits at `dim`, so each
            # iteration exercises (and benchmarks) a different layout.
            # squeeze() with no argument drops every size-1 dimension, so the
            # expected value is reference.squeeze(), not the original tensor.
            encrypted = ArithmeticSharedTensor(reference)
            with self.benchmark(type="squeeze", dim=dim) as bench:
                for _ in bench.iters:
                    encrypted_out = encrypted.squeeze()
            self._check(encrypted_out, reference.squeeze(), "squeeze failed")

            # squeeze() is expected to return a view of `encrypted`: a write
            # through encrypted_out must show up when `encrypted` is squeezed
            # again.
            # NOTE(review): assumes encrypted_out[0:2] can accept a 2-element
            # tensor, which depends on get_random_test_tensor's default size.
            # TODO: confirm.
            encrypted_out[0:2] = torch.tensor([0.0, 1.0], dtype=torch.float)
            ref = encrypted.squeeze().get_plain_text()
            self._check(encrypted_out, ref, "squeeze failed")
Ejemplo n.º 2
0
    def test_squeeze(self):
        """Test unsqueeze/squeeze on ArithmeticSharedTensor.

        For each ``dim`` in [0, 1, 2], a size-1 dimension is inserted at
        ``dim`` with ``unsqueeze``; the tensor carrying that singleton is then
        squeezed back. Both results are compared against the plaintext torch
        reference. Finally the test checks that the squeezed output shares
        its data with the tensor it was squeezed from.
        """
        tensor = get_random_test_tensor(is_float=True)

        for dim in [0, 1, 2]:
            # Plaintext reference with a singleton dimension inserted at `dim`.
            reference = tensor.unsqueeze(dim)

            # Test unsqueeze
            encrypted = ArithmeticSharedTensor(tensor)
            encrypted_out = encrypted.unsqueeze(dim)
            self._check(encrypted_out, reference, "unsqueeze failed")

            # Test squeeze on the tensor whose singleton sits at `dim`, so each
            # iteration exercises a different layout. squeeze() with no
            # argument drops every size-1 dimension, so the expected value is
            # reference.squeeze(), not the original tensor.
            encrypted = ArithmeticSharedTensor(reference)
            encrypted_out = encrypted.squeeze()
            self._check(encrypted_out, reference.squeeze(), "squeeze failed")

            # squeeze() is expected to return a view of `encrypted`: a write
            # through encrypted_out must show up when `encrypted` is squeezed
            # again.
            # NOTE(review): assumes encrypted_out[0:2] can accept a 2-element
            # tensor, which depends on get_random_test_tensor's default size.
            # TODO: confirm.
            encrypted_out[0:2] = torch.tensor([0.0, 1.0], dtype=torch.float)
            ref = encrypted.squeeze().get_plain_text()
            self._check(encrypted_out, ref, "squeeze failed")