Example 1
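This listing begins mid-call and omits its setup. Below is a minimal sketch of the assumed preamble (hook, traced model, loss function, and optimizer settings); the names Net, model, loss_fn, batch_size, optimizer, optimizer_args, and shuffle are illustrative stand-ins, not from the original listing.

import time

import torch
import torch.nn as nn

import syft
from syft.workers.websocket_client import WebsocketClientWorker

hook = syft.TorchHook(torch)

# Illustrative two-input model; the original model definition is not shown.
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(2, 1)

    def forward(self, x):
        return self.fc1(x)

@hook.torch.jit.script
def loss_fn(pred, target):
    return ((target.view(pred.shape).float() - pred.float()) ** 2).mean()

# TrainConfig expects a jit-traced model so it can be serialized.
model = torch.jit.trace(Net(), torch.zeros(1, 2))

batch_size = 32
optimizer = "SGD"
optimizer_args = {"lr": 0.1}
shuffle = True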
# Build the training configuration to ship to the remote worker.
train_config = syft.TrainConfig(model=model,
                                loss_fn=loss_fn,
                                optimizer=optimizer,
                                batch_size=batch_size,
                                optimizer_args=optimizer_args,
                                epochs=5,
                                shuffle=shuffle)

# Connect to the remote worker listening at 10.0.0.1:8778.
arw = {"host": "10.0.0.1", "hook": hook}
h1 = WebsocketClientWorker(id="h1", port=8778, **arw)
train_config.send(h1)


# Ask the worker to start resource monitoring via a raw command message.
message = h1.create_message_execute_command(command_name="start_monitoring", command_owner="self")
serialized_message = syft.serde.serialize(message)
h1._recv_msg(serialized_message)


time.sleep(3)
for epoch in range(10):
    time.sleep(3)
    loss = h1.fit(dataset_key="train")  # ask h1 to train on its "train" dataset
    print("-" * 50)
    print("Iteration %s: h1's loss: %s" % (epoch, loss))

# Tell the worker to stop monitoring once training is done.
message = h1.create_message_execute_command(command_name="stop_monitoring", command_owner="self")
serialized_message = syft.serde.serialize(message)
h1._recv_msg(serialized_message)

# Retrieve the trained model back from the worker.
new_model = train_config.model_ptr.get()
h1.close()
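For this client to run, a matching WebsocketServerWorker must already be listening on 10.0.0.1:8778 and hosting a dataset under the key "train". Below is a minimal sketch of that server side, following the same pattern as Example 2; the toy tensors are illustrative.

import torch
import syft
from syft.workers.websocket_server import WebsocketServerWorker

hook = syft.TorchHook(torch)

# Illustrative toy data; the real server would hold the actual training set.
data = torch.tensor([[-1.0, 2.0], [0.0, 1.1], [-1.0, 2.1], [0.0, 1.2]], requires_grad=True)
target = torch.tensor([[1.0], [0.0], [1.0], [0.0]])

server = WebsocketServerWorker(id="h1", host="0.0.0.0", port=8778, hook=hook)
server.add_dataset(syft.BaseDataset(data, target), key="train")
server.start()  # blocks, serving fit() requests from clients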

Example 2
import time

import torch
import torch.nn as nn
import torch.nn.functional as F

import syft as sy
from syft.workers.websocket_client import WebsocketClientWorker
from syft.workers.websocket_server import WebsocketServerWorker

import utils  # test helper module providing create_gaussian_mixture_toy_data

PRINT_IN_UNITTESTS = False


# `hook` and `start_proc` are pytest fixtures supplied by the test suite.
def test_train_config_with_jit_trace_sync(hook,
                                          start_proc):  # pragma: no cover
    kwargs = {
        "id": "sync_fit",
        "host": "localhost",
        "port": 9000,
        "hook": hook
    }
    # data = torch.tensor([[-1, 2.0], [0, 1.1], [-1, 2.1], [0, 1.2]], requires_grad=True)
    # target = torch.tensor([[1], [0], [1], [0]])

    data, target = utils.create_gaussian_mixture_toy_data(100)

    dataset = sy.BaseDataset(data, target)

    dataset_key = "gaussian_mixture"
    # Launch the server worker in a separate process, hosting the toy dataset.
    process_remote_worker = start_proc(WebsocketServerWorker,
                                       dataset=(dataset, dataset_key),
                                       **kwargs)

    time.sleep(0.1)

    # Connect a client worker to the server started above (same id/host/port).
    local_worker = WebsocketClientWorker(**kwargs)

    @hook.torch.jit.script
    def loss_fn(pred, target):
        return ((target.view(pred.shape).float() - pred.float())**2).mean()

    class Net(torch.nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.fc1 = nn.Linear(2, 3)
            self.fc2 = nn.Linear(3, 2)
            self.fc3 = nn.Linear(2, 1)

        def forward(self, x):
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x

    model_untraced = Net()

    # Trace the model so it can be serialized and shipped in the TrainConfig.
    model = torch.jit.trace(model_untraced, data)

    pred = model(data)
    loss_before = loss_fn(pred=pred, target=target)

    # Create and send train config
    train_config = sy.TrainConfig(model=model,
                                  loss_fn=loss_fn,
                                  batch_size=2,
                                  epochs=1)
    train_config.send(local_worker)

    for epoch in range(5):
        loss = local_worker.fit(dataset_key=dataset_key)
        if PRINT_IN_UNITTESTS:  # pragma: no cover
            print("-" * 50)
            print("Iteration %s: alice's loss: %s" % (epoch, loss))

    new_model = train_config.model_ptr.get()

    # assert that the new model has updated (modified) parameters
    assert not (model.fc1._parameters["weight"]
                == new_model.obj.fc1._parameters["weight"]).all()
    assert not (model.fc2._parameters["weight"]
                == new_model.obj.fc2._parameters["weight"]).all()
    assert not (model.fc3._parameters["weight"]
                == new_model.obj.fc3._parameters["weight"]).all()
    assert not (model.fc1._parameters["bias"]
                == new_model.obj.fc1._parameters["bias"]).all()
    assert not (model.fc2._parameters["bias"]
                == new_model.obj.fc2._parameters["bias"]).all()
    assert not (model.fc3._parameters["bias"]
                == new_model.obj.fc3._parameters["bias"]).all()

    new_model.obj.eval()
    pred = new_model.obj(data)
    loss_after = loss_fn(pred=pred, target=target)

    if PRINT_IN_UNITTESTS:  # pragma: no cover
        print("Loss before training: {}".format(loss_before))
        print("Loss after training: {}".format(loss_after))

    # Shut down the client connection before stopping the remote process.
    local_worker.ws.shutdown()
    del local_worker

    time.sleep(0.1)
    process_remote_worker.terminate()

    assert loss_after < loss_before
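
The start_proc fixture is not shown in this listing. Below is a minimal sketch of what it plausibly does, assuming it simply constructs the worker in a child process, registers the dataset, and starts serving; the actual fixture in the PySyft test suite may differ.

import multiprocessing

import pytest


@pytest.fixture
def start_proc():
    def _start_proc(worker_class, dataset=None, **kwargs):
        def target():
            worker = worker_class(**kwargs)
            if dataset is not None:
                data, key = dataset
                worker.add_dataset(data, key=key)
            worker.start()  # serve until the process is terminated

        p = multiprocessing.Process(target=target)
        p.start()
        return p

    return _start_proc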