Example #1
def test_federated_dataloader(workers):
    bob = workers["bob"]
    alice = workers["alice"]
    datasets = [
        fl.BaseDataset(th.tensor([1, 2]), th.tensor([1, 2])).send(bob),
        fl.BaseDataset(th.tensor([3, 4, 5, 6]), th.tensor([3, 4, 5, 6])).send(alice),
    ]
    fed_dataset = sy.FederatedDataset(datasets)

    fdataloader = sy.FederatedDataLoader(fed_dataset, batch_size=2)
    counter = 0
    for batch_idx, (data, target) in enumerate(fdataloader):
        counter += 1

    assert counter == len(fdataloader), f"{counter} == {len(fdataloader)}"

    fdataloader = sy.FederatedDataLoader(fed_dataset,
                                         batch_size=2,
                                         drop_last=True)
    counter = 0
    for batch_idx, (data, target) in enumerate(fdataloader):
        counter += 1

    assert counter == len(fdataloader), f"{counter} == {len(fdataloader)}"
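
The test-style examples rely on a workers fixture plus the th, fl, and sy aliases used throughout PySyft's test suite. A minimal setup sketch for running them standalone follows; the TorchHook/VirtualWorker wiring and the fl import path are assumptions about a PySyft 0.2.x-style environment, not part of the original snippets.

import torch as th
import syft as sy
import syft.frameworks.torch.fl as fl  # assumed location of BaseDataset in PySyft 0.2.x

hook = sy.TorchHook(th)  # hook torch so tensors gain .send()/.get()
workers = {name: sy.VirtualWorker(hook, id=name) for name in ("bob", "alice", "james")}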
Example #2
def test_federated_dataloader_shuffle(workers):
    bob = workers["bob"]
    alice = workers["alice"]
    datasets = [
        fl.BaseDataset(th.tensor([1, 2]), th.tensor([1, 2])).send(bob),
        fl.BaseDataset(th.tensor([3, 4, 5, 6]), th.tensor([3, 4, 5, 6])).send(alice),
    ]
    fed_dataset = sy.FederatedDataset(datasets)

    fdataloader = sy.FederatedDataLoader(fed_dataset,
                                         batch_size=2,
                                         shuffle=True)
    for epoch in range(3):
        counter = 0
        for batch_idx, (data, target) in enumerate(fdataloader):
            if counter < 1:  # one batch for bob, two batches for alice (batch_size == 2)
                assert (
                    data.location.id == "bob"
                ), f"id should be bob, counter = {counter}, epoch = {epoch}"
            else:
                assert (
                    data.location.id == "alice"
                ), f"id should be alice, counter = {counter}, epoch = {epoch}"
            counter += 1
        assert counter == len(fdataloader), f"{counter} == {len(fdataloader)}"

    num_iterators = 2
    fdataloader = sy.FederatedDataLoader(fed_dataset,
                                         batch_size=2,
                                         num_iterators=num_iterators,
                                         shuffle=True)
    assert (fdataloader.num_iterators == num_iterators -
            1), f"{fdataloader.num_iterators} == {num_iterators - 1}"
Example #3
def test_extract_batches_per_worker(workers):
    bob = workers["bob"]
    alice = workers["alice"]

    datasets = [
        fl.BaseDataset(th.tensor([1, 2]), th.tensor([1, 2])).send(bob),
        fl.BaseDataset(th.tensor([3, 4, 5, 6]), th.tensor([3, 4, 5, 6])).send(alice),
    ]
    fed_dataset = sy.FederatedDataset(datasets)

    fdataloader = sy.FederatedDataLoader(fed_dataset, batch_size=2, shuffle=True)

    batches = utils.extract_batches_per_worker(fdataloader)

    assert len(batches.keys()) == len(
        datasets
    ), "each worker should appear as key in the batches dictionary"
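
extract_batches_per_worker returns a dictionary keyed by worker, each value holding that worker's list of batches. A hedged usage sketch follows; the import path for utils is an assumption mirroring the alias used above.

from syft.frameworks.torch.fl import utils  # assumed import path for the helper used above

batches = utils.extract_batches_per_worker(fdataloader)
for worker, worker_batches in batches.items():
    for data, target in worker_batches:
        # each pair is still a pointer to tensors that live on that worker
        assert data.location.id == worker.id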
Example #4
def test_federated_dataloader_num_iterators(workers):
    bob = workers["bob"]
    alice = workers["alice"]
    james = workers["james"]
    datasets = [
        fl.BaseDataset(th.tensor([1, 2]), th.tensor([1, 2])).send(bob),
        fl.BaseDataset(th.tensor([3, 4, 5, 6]), th.tensor([3, 4, 5, 6])).send(alice),
        fl.BaseDataset(th.tensor([7, 8, 9, 10]), th.tensor([7, 8, 9, 10])).send(james),
    ]

    fed_dataset = sy.FederatedDataset(datasets)
    num_iterators = len(datasets)
    fdataloader = sy.FederatedDataLoader(fed_dataset,
                                         batch_size=2,
                                         num_iterators=num_iterators,
                                         shuffle=True)
    assert (fdataloader.num_iterators == num_iterators -
            1), f"{fdataloader.num_iterators} == {num_iterators - 1}"
    counter = 0
    for batch_idx, batches in enumerate(fdataloader):
        assert (len(batches.keys()) == num_iterators -
                1), f"len(batches.keys()) == {num_iterators} - 1"
        if batch_idx < 1:
            data_bob, target_bob = batches[bob]
            assert data_bob.location.id == "bob", f"id should be bob, batch_idx = {batch_idx}"
        else:  # bob is replaced by james
            data_james, target_james = batches[james]
            assert data_james.location.id == "james", f"id should be james, batch_idx = {batch_idx}"
        if batch_idx < 2:
            data_alice, target_alice = batches[alice]
            assert data_alice.location.id == "alice", f"id should be alice, batch_idx = {batch_idx}"
        counter += 1
    epochs = num_iterators - 1
    assert counter * (num_iterators - 1) == epochs * len(
        fdataloader
    ), f"{counter * (num_iterators - 1)} == {epochs * len(fdataloader)}"
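
As the assertions above suggest, num_iterators is capped at len(datasets) - 1: the loader draws batches from all but one worker at a time, and when a worker's data runs out (bob, who holds the smallest dataset) it is swapped for the worker held in reserve (james), so each step yields a dictionary with one batch per active worker.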
Example #5
def test_federated_dataloader_one_worker(workers):
    bob = workers["bob"]

    datasets = [
        fl.BaseDataset(th.tensor([3, 4, 5, 6]), th.tensor([3, 4, 5, 6])).send(bob)
    ]

    fed_dataset = sy.FederatedDataset(datasets)
    num_iterators = len(datasets)
    fdataloader = sy.FederatedDataLoader(fed_dataset,
                                         batch_size=2,
                                         shuffle=True)
    assert fdataloader.num_iterators == 1, f"{fdataloader.num_iterators} == {1}"
Example #6
def test_federated_dataloader_iter_per_worker(workers):
    bob = workers["bob"]
    alice = workers["alice"]
    james = workers["james"]
    datasets = [
        fl.BaseDataset(th.tensor([1, 2]), th.tensor([1, 2])).send(bob),
        fl.BaseDataset(th.tensor([3, 4, 5, 6]), th.tensor([3, 4, 5, 6])).send(alice),
        fl.BaseDataset(th.tensor([7, 8, 9, 10]), th.tensor([7, 8, 9, 10])).send(james),
    ]

    fed_dataset = sy.FederatedDataset(datasets)
    fdataloader = sy.FederatedDataLoader(fed_dataset,
                                         batch_size=2,
                                         iter_per_worker=True,
                                         shuffle=True)
    nr_workers = len(datasets)
    assert (fdataloader.num_iterators == nr_workers
            ), "num_iterators should be equal to the number of workers"
    for batch_idx, batches in enumerate(fdataloader):
        assert len(batches.keys()) == nr_workers, "return a batch for each worker"
Example #7
# Make Syft federated dataset
client_datapair_dict = {}
datasets = []

logging.info("Load federated dataset")
for client_id in client_ids:
    # Pull each hospital's training split, keyed by its worker name
    x, y = eICU_data.get_train_data_from_hopital(client_id)
    client_datapair_dict["hospital_{}".format(client_id)] = (x, y)

for client_id in client_ids:
    tmp_tuple = client_datapair_dict["hospital_{}".format(client_id)]
    # Wrap each hospital's data in a BaseDataset and send it to that hospital's virtual worker
    datasets.append(
        fl.BaseDataset(
            torch.tensor(tmp_tuple[0], dtype=torch.float32),
            torch.tensor(tmp_tuple[1].squeeze(), dtype=torch.long),
        ).send(virtual_workers["hospital_{}".format(client_id)])
    )

fed_dataset = sy.FederatedDataset(datasets)
fdataloader = sy.FederatedDataLoader(fed_dataset,
                                     batch_size=args["batch_size"])

# Load test data
if args['split_strategy'] == 'trainN_testN':
    x, y = eICU_data.get_full_test_data()
elif args['split_strategy'] == 'trainNminus1_test1':
    x, y = eICU_data.get_test_data_from_hopital(args['test_hospital_id'])
x_pt = torch.tensor(x, dtype=torch.float32)  # transform to torch tensor
y_pt = torch.tensor(y.squeeze(), dtype=torch.long)
my_dataset = TensorDataset(x_pt, y_pt)  # create the test dataset
test_loader = DataLoader(my_dataset, batch_size=10)  # create the test dataloader
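
The snippet stops once the federated and test loaders exist. A typical training loop for this kind of setup sends the model to whichever worker holds the current batch, trains there, and pulls the weights back; the sketch below follows that common PySyft pattern, and the model and optimizer objects plus the nll_loss choice are assumptions rather than part of the original script.

import torch.nn.functional as F

def train_one_epoch(model, optimizer, fdataloader):
    model.train()
    for data, target in fdataloader:
        model.send(data.location)  # move the model to the worker holding this batch
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)  # assumes the model ends in log_softmax
        loss.backward()
        optimizer.step()
        model.get()  # pull the updated weights back to the local worker
        logging.info("training loss %.4f", loss.get().item())

The local test_loader built above can then be used for a plain, non-federated evaluation pass, since the test tensors never leave the local worker.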