示例#1
0
def test_sandbox():
    """Create a sandbox with datasets downloaded and check what it provides.

    ``sy.create_sandbox`` is given this module's ``globals()``; the worker,
    ``hook``, ``grid`` and ``workers`` names used below are not defined in
    this file and are expected to be injected by that call (hence the
    ``noqa: F821`` markers).
    """
    sy.create_sandbox(globals(), download_data=True)

    for worker in (alice, andy, bob, jason, jon, theo):  # noqa: F821
        assert worker == worker
        assert isinstance(worker, sy.VirtualWorker)

    assert hook == hook  # noqa: F821
    assert isinstance(hook, TorchHook)  # noqa: F821

    assert grid == grid  # noqa: F821
    assert isinstance(grid, PrivateGridNetwork)  # noqa: F821

    assert workers == [bob, theo, jason, alice, andy, jon]  # noqa: F821

    for tag in (
        "#boston",
        "#diabetes",
        "#breast_cancer_dataset",
        "#digits_dataset",
        "#iris_dataset",
        "#wine_dataset",
        "#linnerrud_dataset",
    ):
        # Include the tag in the failure message so a missing dataset is obvious.
        assert bob.search([tag]), tag  # noqa: F821

    # Only the import is guarded: an ImportError raised by the search
    # assertions themselves must not be silently swallowed.
    try:
        from tensorflow.keras import datasets  # noqa: F401
    except ImportError:  # pragma: no cover
        pass  # pragma: no cover
    else:
        # Keras-backed datasets are only checked when TensorFlow is installed.
        assert bob.search(["#fashion_mnist"])  # noqa: F821
        assert bob.search(["#cifar10"])  # noqa: F821
示例#2
0
def test_sandbox():
    """Build a sandbox without downloading datasets and verify its globals.

    The sandbox is created against this module's ``globals()``; every name
    checked below comes from that call rather than from this file.
    """
    sy.create_sandbox(globals(), download_data=False)

    sandbox_workers = (alice, andy, bob, jason, jon, theo)  # noqa: F821
    for member in sandbox_workers:
        # Each injected worker must compare equal to itself and be virtual.
        assert member == member
        assert isinstance(member, sy.VirtualWorker)

    the_hook = hook  # noqa: F821
    assert the_hook == the_hook
    assert isinstance(the_hook, TorchHook)  # noqa: F821

    the_grid = grid  # noqa: F821
    assert the_grid == the_grid
    assert isinstance(the_grid, PrivateGridNetwork)  # noqa: F821

    expected_order = [bob, theo, jason, alice, andy, jon]  # noqa: F821
    assert workers == expected_order  # noqa: F821
示例#3
0
def test_sandbox():
    """Create a data-less sandbox and confirm ``alice`` was injected."""
    sy.create_sandbox(globals(), download_data=False)

    # The sandbox call should have placed a global ``alice`` worker here.
    injected = alice  # noqa: F821
    assert injected == injected
    assert isinstance(injected, sy.VirtualWorker)
import syft as sy
import torch as th

# Populate this module's globals with the sandbox objects (``grid`` is used below).
sy.create_sandbox(globals(), verbose=False)

# Fall back to the CPU when CUDA is unavailable; a hard-coded "cuda" device
# makes ``.to(device)`` fail on machines without a GPU.
device = th.device("cuda" if th.cuda.is_available() else "cpu")

# Then search for a dataset
boston_data = grid.search("#boston", "#data")  # noqa: F821
boston_target = grid.search("#boston", "#target")  # noqa: F821

# We load a model and an optimizer
# The input width is taken from alice's first Boston data tensor.
n_features = boston_data['alice'][0].shape[1]
n_targets = 1

model = th.nn.Linear(n_features, n_targets).to(device)

# Here we cast the data fetched in a `FederatedDataset`. See the workers
# which hold part of the data.
print(f"A total of {len(boston_data.keys())} workers")

# Cast the result in BaseDatasets, one per worker returned by the search.
datasets = [
    sy.BaseDataset(boston_data[worker][0], boston_target[worker][0])
    for worker in boston_data.keys()
]

# Build the FederatedDataset object
dataset = sy.FederatedDataset(datasets)
print(dataset.workers)

optimizers = {}
for worker in dataset.workers:
    # NOTE(review): every optimizer wraps the same ``model.parameters()``, so
    # the parameters are shared while each Adam instance keeps its own state;
    # confirm this is intended.
    optimizers[worker] = th.optim.Adam(params=model.parameters(), lr=1e-2)
    # Alternative optimizer: th.optim.SGD(model.parameters(), lr=0.05)
示例#5
0
文件: server.py 项目: zwvews/PySyft
from flask import Flask
from flask import request
import torch as th
import syft as sy

# Create the PySyft sandbox, handing it this module's globals() so it can
# add its objects here (presumably workers/hook — confirm against syft docs).
sy.create_sandbox(globals())

app = Flask(__name__)

# Initialize a toy model: a 2x1 tensor of zeros.
model = th.zeros([2, 1])
# Module-level pointer rebound by the /get_model and /send_data handlers;
# None until one of them has been called.
ptr = None


@app.route("/get_model", methods=["GET"])
def get_model():
    """Handle ``GET /get_model``.

    Creates a pointer to the module-level ``model`` and stores it in the
    module-level ``ptr``, then responds with ``model.ser()``.
    """
    # ``model`` is only read here, so it needs no ``global`` declaration;
    # ``ptr`` is rebound and therefore must be declared global.
    global ptr
    ptr = model.create_pointer()
    return model.ser()


@app.route("/send_data", methods=["POST"])
def send_data():
    """Handle ``POST /send_data``.

    Deserializes the raw request body with ``sy.serde.deserialize``, stores
    the result in the module-level ``ptr``, and responds with ``model.ser()``.
    """
    global ptr
    # NOTE(review): the body comes straight from the client; confirm that
    # sy.serde.deserialize is safe to run on untrusted input.
    payload = request.data
    ptr = sy.serde.deserialize(payload)

    return model.ser()
示例#6
0
        "_protocol": str(protocolPb),
        "jason": str(simplifiedJasonPlan),
        "_jason": str(jasonPlanPb),
        "andy": str(simplifiedAndyPlan),
        "_andy": str(andyPlanPb),
    }


def serializeToBase64Pb(worker, obj):
    """Serialize ``obj`` to protobuf and Base64-encode the binary form.

    Args:
        worker: Worker passed through to ``protobuf.serde._bufferize``.
        obj: Object to convert to a protobuf message.

    Returns:
        A ``(pb, b64)`` tuple: the protobuf message and its
        ``SerializeToString()`` bytes Base64-encoded as a UTF-8 ``str``.
    """
    pb = protobuf.serde._bufferize(worker, obj)
    # Named ``raw`` rather than ``bin`` to avoid shadowing the builtin.
    raw = pb.SerializeToString()
    return pb, base64.b64encode(raw).decode('utf-8')

# Replace ID_PROVIDER so we get the same readable IDs on every run. It is set
# before the sandbox is created, presumably so sandbox objects draw from it too.
sy.ID_PROVIDER = [int(1e10) + i for i in reversed(range(100))]

# ``hook`` below is not defined in this file; it is expected to be injected
# into globals() by create_sandbox.
sy.create_sandbox(globals(), download_data=False)
hook.local_worker.is_client_worker = False  # noqa: F821
hook.local_worker.framework = None  # noqa: F821
me = hook.local_worker  # noqa: F821

first = generateThreeWayProtocol(me)
second = generateTwoWayProtocol(me)

data = {"three-way": first, "two-way": second}

print(data)

with open('./seed/data.json', 'w', encoding='utf-8') as f:
    json.dump(data, f)

# Report success only after the file has actually been written; previously
# this was printed before the write and would show even if it failed.
print("\n----------\n\nSeed file created successfully!")