コード例 #1
0
ファイル: __init__.py プロジェクト: yashk2000/PySyft
def load(tag: str, src: int, **kwargs):
    """Load the object tagged ``tag`` from party ``src`` into a CrypTen computation.

    On the party whose rank equals ``src``, the object is looked up by ``tag``
    on that party's worker and passed to ``crypten.load_from_party`` as the
    preloaded value. Every other party calls ``crypten.load_from_party`` with
    ``preloaded=-1`` instead of a real object.

    Args:
        tag: Search tag identifying the object stored on the ``src`` worker.
        src: Rank of the party that owns the object.
        **kwargs: Forwarded unchanged to ``crypten.load_from_party``.

    Returns:
        Whatever ``crypten.load_from_party`` returns.

    Raises:
        RuntimeError: If this party is ``src`` and the CrypTen computation id
            (``CID``) is not set.
        AssertionError: If this party is ``src`` and searching for ``tag`` does
            not yield exactly one result.

    TODO: Keep serialized models at the workers that take part in the
    computation, so that the worker that started the computation does not
    learn which model architecture is used (idea: for tags starting with
    ``"crypten_model"``, fetch the single ``OnnxModel`` from the src worker and
    convert its ``serialized_model`` with ``utils.onnx_to_crypten``).
    """
    if src != comm.get().get_rank():
        # Non-owning parties pass -1 in place of the object.
        # NOTE(review): presumably ignored by crypten on non-src ranks — confirm.
        return crypten.load_from_party(preloaded=-1, src=src, **kwargs)

    if CID is None:
        raise RuntimeError("CrypTen computation id is not set.")

    worker = get_worker_from_rank(src)
    results = worker.search(tag)

    # Explicit check rather than ``assert`` so it is not stripped under
    # ``python -O``; AssertionError is kept for backward compatibility.
    if len(results) != 1:
        raise AssertionError(
            f"Expected exactly one object tagged {tag!r} on party {src}, "
            f"found {len(results)}."
        )

    return crypten.load_from_party(preloaded=results[0], src=src, **kwargs)
コード例 #2
0
def jointly_train():
    """Jointly train an encrypted logistic-regression model on Alice's and Bob's data.

    Calls ``encrypt_digits()`` first. It then loads Alice's images (party
    ``ALICE``) and Bob's labels (party ``BOB``) as encrypted tensors. Finally,
    it trains an encrypted ``LogisticRegression`` model on them with
    ``train_model``.

    NOTE(review): ``encrypt_digits()`` presumably writes the two ``/tmp/data``
    files loaded below — confirm.

    Returns:
        The model returned by ``train_model``. Returning it is new; callers
        that ignored the previous ``None`` return value are unaffected.
    """
    encrypt_digits()
    alice_images_enc = crypten.load_from_party("/tmp/data/alice_images.pth",
                                               src=ALICE)
    bob_labels_enc = crypten.load_from_party("/tmp/data/bob_labels.pth",
                                             src=BOB)

    model = LogisticRegression().encrypt()
    model = train_model(model, alice_images_enc, bob_labels_enc)
    # Hand the trained model back instead of discarding it.
    return model
コード例 #3
0
def load(tag: str, src: int, **kwargs):
    """Load the object tagged ``tag`` from party ``src`` via ``crypten.load_from_party``.

    On the party whose rank equals ``src``, the object is looked up by ``tag``
    on the worker returned by ``syft.local_worker.get_worker_from_rank(src)``
    and passed as the preloaded value. Every other party passes
    ``preloaded=-1`` instead of a real object.

    Args:
        tag: Search tag identifying the object stored on the ``src`` worker.
        src: Rank of the party that owns the object.
        **kwargs: Forwarded unchanged to ``crypten.load_from_party``.

    Returns:
        Whatever ``crypten.load_from_party`` returns.

    Raises:
        AssertionError: If this party is ``src`` and searching for ``tag`` does
            not yield exactly one result.
    """
    if src != comm.get().get_rank():
        # Non-owning parties pass -1 in place of the object.
        # NOTE(review): presumably ignored by crypten on non-src ranks — confirm.
        return crypten.load_from_party(preloaded=-1, src=src, **kwargs)

    worker = syft.local_worker.get_worker_from_rank(src)
    results = worker.search(tag)

    # Explicit check rather than ``assert`` so it is not stripped under
    # ``python -O``; AssertionError is kept for backward compatibility.
    if len(results) != 1:
        raise AssertionError(
            f"Expected exactly one object tagged {tag!r} on party {src}, "
            f"found {len(results)}."
        )

    return crypten.load_from_party(preloaded=results[0], src=src, **kwargs)
コード例 #4
0
    def test_plaintext_save_load_module_from_party(self):
        """Check that plaintext modules round-trip through
        crypten.save_from_party / crypten.load_from_party.

        For each module class, every rank builds a model whose parameters are
        all set to its own rank. Each rank then takes a turn as ``src``. The
        src rank verifies the loaded parameters, and all ranks verify that the
        loaded module's ``src`` attribute matches.
        """
        import tempfile

        comm = crypten.communicator
        for module_cls in (TestModule, NestedTestModule):
            # Parameter values encode the owning rank, so a mix-up is visible.
            my_rank = comm.get().get_rank()

            model = module_cls(200, 10)
            model.set_all_parameters(my_rank)
            serial.register_safe_class(module_cls)

            path = tempfile.NamedTemporaryFile(delete=True).name
            world_size = comm.get().get_world_size()
            for src in range(world_size):
                crypten.save_from_party(model, path, src=src)
                loaded = crypten.load_from_party(path, src=src)

                if src == my_rank:
                    for param in loaded.parameters(recurse=True):
                        self.assertTrue(
                            param.eq(my_rank).all().item(), "Model load failed"
                        )
                self.assertEqual(loaded.src, src)
コード例 #5
0
    def test_save_load(self):
        """Test that crypten.save_from_party and crypten.load_from_party
        properly save and load tensors.

        Both torch and a custom NumPy-based save/load closure pair are
        covered. For each tensor rank from 1 to 4 dimensions and each ``src``
        party, the test checks four things: the file is written in the
        expected format, the loaded encrypted tensor has ``src``'s shape and
        values, an invalid ``load_closure`` raises ``TypeError``, and loading
        via ``preloaded`` gives the same values.
        """
        import tempfile
        import numpy as np

        def custom_load_function(f):
            # Read a .npy file back as a torch tensor.
            return torch.from_numpy(np.load(f))

        def custom_save_function(obj, f):
            np.save(f, obj.numpy())

        comm = crypten.communicator
        filename = tempfile.NamedTemporaryFile(delete=True).name
        # Each case: (save_closure, load_closure, file suffix, plain loader used
        # to confirm the file really was written in the expected format).
        serialization_cases = [
            (torch.save, torch.load, ".pth", torch.load),
            (custom_save_function, custom_load_function, ".npy", np.load),
        ]
        for dimensions in range(1, 5):
            # Create tensors with different sizes on each rank
            size = tuple([self.rank + 1] * dimensions)
            tensor = torch.randn(size=size)

            for (save_closure, load_closure, suffix,
                 test_load_fn) in serialization_cases:
                complete_file = filename + suffix
                for src in range(comm.get().get_world_size()):
                    is_src = self.rank == src
                    crypten.save_from_party(tensor,
                                            complete_file,
                                            src=src,
                                            save_closure=save_closure)

                    # the following line will throw an error if an object saved with
                    # torch.save is attempted to be loaded with np.load
                    if is_src:
                        test_load_fn(complete_file)

                    encrypted_load = crypten.load_from_party(
                        complete_file, src=src, load_closure=load_closure)

                    # src's tensor has side length src + 1 in every dimension.
                    reference_size = tuple([src + 1] * dimensions)
                    self.assertEqual(encrypted_load.size(), reference_size)

                    # Non-src ranks receive src's plaintext via broadcast.
                    reference = (tensor if is_src else torch.empty(
                        size=reference_size))
                    comm.get().broadcast(reference, src=src)
                    self._check(encrypted_load, reference,
                                "crypten.load() failed")

                    # test for invalid load_closure
                    with self.assertRaises(TypeError):
                        crypten.load_from_party(complete_file,
                                                src=src,
                                                load_closure=(lambda f: None))

                    # test pre-loaded
                    encrypted_preloaded = crypten.load_from_party(
                        src=src, preloaded=tensor)
                    self._check(
                        encrypted_preloaded,
                        reference,
                        "crypten.load() failed using preloaded",
                    )