def test_connect_close(hook, start_proc):
    """Check that a client can disconnect and reconnect without losing remote objects.

    A ``WebsocketServerWorker`` is started through ``start_proc`` and a
    ``WebsocketClientWorker`` is connected to it with the same id/host/port/hook.
    A tensor is sent to the remote side. The client then closes and re-opens its
    connection. The test asserts the remote object count is still 1 and that
    the tensor can be retrieved unchanged.

    Args:
        hook: hook fixture passed unchanged to both the server and the client.
        start_proc: fixture that starts the given worker class (apparently in a
            separate process) and returns a handle exposing ``terminate()``.
    """
    kwargs = {"id": "fed", "host": "localhost", "port": 8763, "hook": hook}
    process_remote_worker = start_proc(WebsocketServerWorker, **kwargs)
    try:
        # Give the server a moment to start listening before the client connects.
        time.sleep(0.1)

        # ``**kwargs`` unpacking copies the dict, so the same settings can be
        # reused for the client without being mutated by the server call.
        local_worker = WebsocketClientWorker(**kwargs)
        try:
            x = torch.tensor([1, 2, 3])
            x_ptr = x.send(local_worker)
            assert local_worker.objects_count_remote() == 1

            # Drop and re-establish the connection; the remote object must survive.
            local_worker.close()
            time.sleep(0.1)
            local_worker.connect()
            assert local_worker.objects_count_remote() == 1

            x_val = x_ptr.get()
            assert (x_val == x).all()
        finally:
            # Tear down the client socket even when an assertion above fails.
            local_worker.ws.shutdown()
            time.sleep(0.1)
    finally:
        # Always stop the server so a failure cannot leave port 8763 bound
        # (and the process running) for subsequent tests.
        process_remote_worker.terminate()
optimizer=optimizer, batch_size=batch_size, optimizer_args=optimizer_args, epochs=5, shuffle=shuffle) arw = {"host":"10.0.0.1","hook":hook} h1 = WebsocketClientWorker(id="h1",port=8778,**arw) train_config.send(h1) message = h1.create_message_execute_command(command_name="start_monitoring",command_owner="self") serialized_message = syft.serde.serialize(message) h1._recv_msg(serialized_message) time.sleep(3) for epoch in range(10): time.sleep(3) loss = h1.fit(dataset_key="train") # ask alice to train using "xor" dataset print("-" * 50) print("Iteration %s: h1's loss: %s" % (epoch, loss)) message = h1.create_message_execute_command(command_name="stop_monitoring",command_owner="self") serialized_message = syft.serde.serialize(message) h1._recv_msg(serialized_message) new_model = train_config.model_ptr.get() h1.close()