Example #1
    def div_(self, y):
        """Divide two tensors element-wise"""
        # TODO: Add test coverage for this code path (next 4 lines)
        if isinstance(y, float) and int(y) == y:
            y = int(y)
        if is_float_tensor(y) and y.frac().eq(0).all():
            y = y.long()

        if isinstance(y, int) or is_int_tensor(y):
            # Truncate protocol for dividing by public integers:
            if comm.get().get_world_size() > 2:
                wraps = self.wraps()
                self.share //= y
                # NOTE: The multiplication here must be split into two parts
                # to avoid long out-of-bounds when y <= 2 since (2 ** 63) is
                # larger than the largest long integer.
                self -= wraps * 4 * (int(2 ** 62) // y)
            else:
                self.share //= y
            return self

        # Otherwise multiply by reciprocal
        if isinstance(y, float):
            y = torch.FloatTensor([y])

        assert is_float_tensor(y), "Unsupported type for div_: %s" % type(y)
        return self.mul_(y.reciprocal())
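
The NOTE above is worth unpacking: the exact correction factor would be 2 ** 64 // y, but that constant cannot be stored in a signed 64-bit long when y <= 2, so the code builds it from 2 ** 62, which always fits. A minimal plain-Python sketch (hypothetical values; the split is exact for power-of-two divisors):

import torch

y = 2
assert (2 ** 64) // y > torch.iinfo(torch.long).max      # naive constant overflows a long
assert int(2 ** 62) // y <= torch.iinfo(torch.long).max  # split constant fits
assert 4 * (int(2 ** 62) // y) == (2 ** 64) // y         # exact for power-of-two y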
Example #2
    def _check(self,
               encrypted_tensor,
               reference,
               msg,
               dst=None,
               tolerance=None):
        """Checks that `encrypted_tensor` decrypts to `reference` within `tolerance`."""
        if tolerance is None:
            tolerance = getattr(self, "default_tolerance", 0.05)
        tensor = encrypted_tensor.get_plain_text(dst=dst)
        if dst is not None and dst != self.rank:
            self.assertIsNone(tensor)
            return

        # Check sizes match
        self.assertTrue(tensor.size() == reference.size(), msg)

        self.assertTrue(is_float_tensor(reference),
                        "reference must be a float")
        diff = (tensor - reference).abs_()
        norm_diff = diff.div(tensor.abs() + reference.abs()).abs_()
        test_passed = norm_diff.le(tolerance) + diff.le(tolerance * 0.1)
        test_passed = test_passed.gt(0).all().item() == 1
        if not test_passed:
            logging.info(msg)
            logging.info("Result %s" % tensor)
            logging.info("Result - Reference = %s" % (tensor - reference))
        self.assertTrue(test_passed, msg=msg)
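
The pass condition mixes a relative and an absolute bound: the relative error covers large magnitudes, while the tighter absolute bound (tolerance * 0.1) rescues values near zero, where the relative error blows up. A standalone plain-torch sketch of the same test (hypothetical values):

import torch

tolerance = 0.05
tensor = torch.tensor([1.001, 0.0001])   # stands in for the decrypted result
reference = torch.tensor([1.0, 0.0])     # plaintext reference
diff = (tensor - reference).abs()
norm_diff = diff.div(tensor.abs() + reference.abs()).abs()
# an element passes if either bound holds
passed = (norm_diff.le(tolerance) | diff.le(tolerance * 0.1)).all()
assert passed.item()  # 0.0001 fails the relative test but passes the absolute one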
Example #3
    def _set_grad_to_zero(self, args, make_private=False):
        """Sets gradients for args to zero

        Args:
            args (list of torch.tensors): contains arguments
            make_private (bool): encrypt args using CrypTensor
        """
        args_zero_grad = []

        for arg in args:
            if is_float_tensor(arg) and make_private:
                arg = crypten.cryptensor(arg, requires_grad=True)
            elif is_float_tensor(arg):
                arg.requires_grad = True
                arg.grad = None

            args_zero_grad.append(arg)

        return args_zero_grad
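
Every example in this file branches on is_float_tensor / is_int_tensor. They come from CrypTen's tensor-type utilities; the sketch below is an assumption consistent with how they are used here, not a copy of the real implementation:

import torch

def is_float_tensor(obj):
    # floating-point torch tensors only
    return torch.is_tensor(obj) and obj.dtype.is_floating_point

def is_int_tensor(obj):
    # integer torch tensors (neither float nor bool)
    return (
        torch.is_tensor(obj)
        and not obj.dtype.is_floating_point
        and obj.dtype != torch.bool
    )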
Example #4
    def div_(self, y):
        """Divide two tensors element-wise"""
        # TODO: Add test coverage for this code path (next 4 lines)
        if isinstance(y, float) and int(y) == y:
            y = int(y)
        if is_float_tensor(y) and y.frac().eq(0).all():
            y = y.long()

        if isinstance(y, int) or is_int_tensor(y):
            if debug_mode():
                tolerance = 1.0
                tensor = self.get_plain_text()

            # Truncate protocol for dividing by public integers:
            if comm.get().get_world_size() > 2:
                wraps = self.wraps()
                self.share //= y
                # NOTE: The multiplication here must be split into two parts
                # to avoid long out-of-bounds when y <= 2 since (2 ** 63) is
                # larger than the largest long integer.
                self -= wraps * 4 * (int(2 ** 62) // y)
            else:
                self.share //= y

            if debug_mode():
                if not torch.lt(
                    torch.abs(self.get_plain_text() * y - tensor), tolerance
                ).all():
                    raise ValueError("Final result of division is incorrect.")

            return self

        # Otherwise multiply by reciprocal
        if isinstance(y, float):
            y = torch.tensor([y], dtype=torch.float, device=self.device)

        assert is_float_tensor(y), "Unsupported type for div_: %s" % type(y)
        return self.mul_(y.reciprocal())
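
The debug branch encodes a simple invariant: if the division is correct, decrypting the quotient and multiplying back by y must land within tolerance of the original plaintext. In plain torch (hypothetical values):

import torch

tolerance = 1.0
original = torch.tensor([10.0, -7.0])  # stands in for self.get_plain_text()
y = 3
result = original / y                  # stands in for the encrypted division
assert torch.lt(torch.abs(result * y - original), tolerance).all()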
Example #5
    def test_encrypt(self):
        for size, tensor in zip(self.sizes, self.tensors):
            if is_float_tensor(tensor):
                tensor_type = ArithmeticSharedTensor
            else:
                tensor_type = BinarySharedTensor
            tensor_name = tensor_type.__name__
            with self.benchmark(
                tensor_type=tensor_name, size=size, is_float=self.is_float(tensor)
            ) as bench:
                for _ in bench.iters:
                    encrypted_tensor = tensor_type(tensor)

            self.assertTrue(encrypted_tensor is not None)
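
The self.benchmark(...) context manager belongs to the benchmark harness these tests run under. A hypothetical minimal stand-in that matches its usage here (labels passed as kwargs, an iters iterable driving the timed runs):

import contextlib
import time

@contextlib.contextmanager
def benchmark(n_iters=10, **labels):
    class Bench:
        iters = range(n_iters)
    start = time.perf_counter()
    yield Bench()                 # body runs the workload n_iters times
    elapsed = time.perf_counter() - start
    print(labels, "%.6fs" % elapsed)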
Example #6
    def div_(self, y):
        """Divide two tensors element-wise"""
        # TODO: Add test coverage for this code path (next 4 lines)
        if isinstance(y, float) and int(y) == y:
            y = int(y)
        if is_float_tensor(y) and y.frac().eq(0).all():
            y = y.long()

        if isinstance(y, int) or is_int_tensor(y):
            validate = cfg.debug.validation_mode

            if validate:
                tolerance = 1.0
                tensor = self.get_plain_text()

            # Truncate protocol for dividing by public integers:
            if comm.get().get_world_size() > 2:
                protocol = globals()[cfg.mpc.protocol]
                protocol.truncate(self, y)
            else:
                self.share = self.share.div_(y, rounding_mode="trunc")

            # Validate
            if validate:
                if not torch.lt(torch.abs(self.get_plain_text() * y - tensor),
                                tolerance).all():
                    raise ValueError("Final result of division is incorrect.")

            return self

        # Otherwise multiply by reciprocal
        if isinstance(y, float):
            y = torch.tensor([y], dtype=torch.float, device=self.device)

        assert is_float_tensor(y), "Unsupported type for div_: %s" % type(y)
        return self.mul_(y.reciprocal())
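
Unlike the earlier versions, this one looks the truncation protocol up by name: cfg.mpc.protocol holds a string, and globals() maps it to a module-level object exposing truncate. A sketch of that dispatch with a hypothetical protocol class (the real primitives ship with CrypTen):

class beaver:
    # hypothetical stand-in; the real protocol performs the wrap-correcting
    # truncation shown in examples #1 and #4
    @staticmethod
    def truncate(tensor, divisor):
        tensor.share = tensor.share.div_(divisor, rounding_mode="trunc")
        return tensor

protocol_name = "beaver"              # would come from cfg.mpc.protocol
protocol = globals()[protocol_name]   # string -> protocol object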
Example #7
    def test_decrypt(self):
        for tensor_type in [ArithmeticSharedTensor, BinarySharedTensor]:
            tensor_name = tensor_type.__name__
            tensors = self.tensors
            if tensor_type == ArithmeticSharedTensor:
                tensors = [t for t in tensors if is_float_tensor(t)]
            else:
                tensors = [t for t in tensors if is_int_tensor(t)]
            encrypted_tensors = [tensor_type(tensor) for tensor in tensors]
            data = zip(self.sizes, tensors, encrypted_tensors)
            for size, tensor, encrypted_tensor in data:
                with self.benchmark(
                    tensor_type=tensor_name, size=size, float=self.is_float(tensor)
                ) as bench:
                    for _ in bench.iters:
                        tensor = encrypted_tensor.get_plain_text()
Example #8
    def _check(self, encrypted_tensor, reference, msg, tolerance=None):
        """Checks that `encrypted_tensor` decrypts to `reference` within `tolerance`."""
        if tolerance is None:
            tolerance = getattr(self, "default_tolerance", 0.05)
        tensor = encrypted_tensor.get_plain_text()

        # Check sizes match
        self.assertTrue(tensor.size() == reference.size(), msg)

        if is_float_tensor(reference):
            diff = (tensor - reference).abs_()
            norm_diff = diff.div(tensor.abs() + reference.abs()).abs_()
            test_passed = norm_diff.le(tolerance) + diff.le(tolerance * 0.2)
            test_passed = test_passed.gt(0).all().item() == 1
        else:
            test_passed = (tensor == reference).all().item() == 1
        if not test_passed:
            logging.info(msg)
            logging.info("Result = %s;\nreference = %s" % (tensor, reference))
        self.assertTrue(test_passed, msg=msg)
Example #9
    def _check_forward_backward(self,
                                fn_name,
                                input_tensor,
                                *args,
                                msg=None,
                                **kwargs):
        """Checks that `fn_name` produces the same forward values and input
        gradients as the PyTorch reference implementation."""
        if msg is None:
            msg = f"{fn_name} grad_fn incorrect"

        for requires_grad in [True]:
            # Setup input
            input = input_tensor.clone()
            input.requires_grad = requires_grad
            input_encr = AutogradCrypTensor(crypten.cryptensor(input),
                                            requires_grad=requires_grad)

            for private in [False, True]:
                input.grad = None
                input_encr.grad = None

                # Setup args
                args_encr = list(args)
                for i, arg in enumerate(args):
                    if private and is_float_tensor(arg):
                        args_encr[i] = AutogradCrypTensor(
                            crypten.cryptensor(arg),
                            requires_grad=requires_grad)
                        args_encr[i].grad = None  # zero grad
                    if is_float_tensor(arg):
                        args[i].requires_grad = requires_grad
                        args[i].grad = None  # zero grad

                # Check forward pass
                if hasattr(input, fn_name):
                    reference = getattr(input, fn_name)(*args, **kwargs)
                elif hasattr(F, fn_name):
                    reference = getattr(F, fn_name)(input, *args, **kwargs)
                elif fn_name == "square":
                    reference = input.pow(2)
                else:
                    raise ValueError("unknown PyTorch function: %s" % fn_name)

                encrypted_out = getattr(input_encr, fn_name)(*args_encr,
                                                             **kwargs)

                # Remove argmax output from max / min
                if isinstance(encrypted_out, (list, tuple)):
                    reference = reference[0]
                    encrypted_out = encrypted_out[0]

                self._check(encrypted_out, reference, msg + " in forward")

                # Check backward pass
                grad_output = get_random_test_tensor(max_value=2,
                                                     size=reference.size(),
                                                     is_float=True)
                grad_output_encr = crypten.cryptensor(grad_output)

                # Do not check backward if pytorch backward fails
                try:
                    reference.backward(grad_output)
                except RuntimeError:
                    logging.info("skipped")
                    continue
                encrypted_out.backward(grad_output_encr)

                self._check(input_encr.grad, input.grad, msg + " in backward")
                for i, arg_encr in enumerate(args_encr):
                    if crypten.is_encrypted_tensor(arg_encr):
                        self._check(arg_encr.grad, args[i].grad,
                                    msg + " in backward args")