Example #1
def test_federated_dataset_search(workers):

    bob = workers["bob"]
    alice = workers["alice"]

    grid = sy.PrivateGridNetwork(bob, alice)

    train_bob = th.zeros(1000, 100).tag("data").send(bob)
    target_bob = th.zeros(1000, 100).tag("target").send(bob)

    train_alice = th.zeros(1000, 100).tag("data").send(alice)
    target_alice = th.zeros(1000, 100).tag("target").send(alice)

    data = grid.search("data")
    target = grid.search("target")

    datasets = [
        BaseDataset(data["bob"][0], target["bob"][0]),
        BaseDataset(data["alice"][0], target["alice"][0]),
    ]

    fed_dataset = sy.FederatedDataset(datasets)
    train_loader = sy.FederatedDataLoader(fed_dataset,
                                          batch_size=4,
                                          shuffle=False,
                                          drop_last=False)

    counter = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        counter += 1

    assert counter == len(train_loader), f"{counter} != {len(train_loader)}"
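
The tests in these examples assume PyTorch is imported as th, PySyft as sy, and that a workers pytest fixture supplies in-process VirtualWorker objects keyed by name. A minimal sketch of that setup, assuming PySyft 0.2.x import paths (the fixture shape shown here is an assumption, not taken from the original suite):

import torch as th
import syft as sy
from syft.frameworks.torch.fl import BaseDataset  # FederatedDataset / FederatedDataLoader are also exposed on sy

hook = sy.TorchHook(th)                  # patch torch so tensors gain .send() / .get() / .tag()
bob = sy.VirtualWorker(hook, id="bob")   # in-process stand-ins for remote machines
alice = sy.VirtualWorker(hook, id="alice")
workers = {"bob": bob, "alice": alice}   # same mapping shape the fixture above is assumed to return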
Example #2
def test_dataset_to_federate(workers):
    bob = workers["bob"]
    alice = workers["alice"]

    dataset = BaseDataset(th.tensor([1.0, 2, 3, 4, 5, 6]),
                          th.tensor([1.0, 2, 3, 4, 5, 6]))

    fed_dataset = dataset.federate((bob, alice))

    assert isinstance(fed_dataset, sy.FederatedDataset)

    assert fed_dataset.workers == ["bob", "alice"]
    assert fed_dataset["bob"].location.id == "bob"
    assert len(fed_dataset) == 6
Example #3
def test_base_dataset(workers):

    bob = workers["bob"]
    inputs = th.tensor([1, 2, 3, 4.0])
    targets = th.tensor([1, 2, 3, 4.0])
    dataset = BaseDataset(inputs, targets)

    assert len(dataset) == 4
    assert dataset[2] == (3, 3)

    dataset = dataset.send(bob)
    assert dataset.data.location.id == "bob"
    assert dataset.targets.location.id == "bob"
    assert dataset.location.id == "bob"
Example #4
def test_illegal_get(workers):
    """Test that calling .get() on a dataset that is part of a FederatedDataset raises a ValueError."""
    bob = workers["bob"]
    alice = workers["alice"]

    alice_base_dataset = BaseDataset(th.tensor([3, 4, 5, 6]),
                                     th.tensor([3, 4, 5, 6]))
    datasets = [
        BaseDataset(th.tensor([1, 2]), th.tensor([1, 2])).send(bob),
        alice_base_dataset.send(alice),
    ]
    fed_dataset = sy.FederatedDataset(datasets)
    with pytest.raises(ValueError):
        fed_dataset["alice"].get()
Example #5
def test_get_dataset(workers):
    bob = workers["bob"]
    alice = workers["alice"]

    alice_base_dataset = BaseDataset(th.tensor([3, 4, 5, 6]),
                                     th.tensor([3, 4, 5, 6]))
    datasets = [
        BaseDataset(th.tensor([1, 2]), th.tensor([1, 2])).send(bob),
        alice_base_dataset.send(alice),
    ]
    fed_dataset = sy.FederatedDataset(datasets)
    dataset = fed_dataset.get_dataset("alice")

    assert len(fed_dataset) == 2
    assert len(dataset) == 4
Example #6
def test_abstract_dataset():
    inputs = th.tensor([1, 2, 3, 4.0])
    targets = th.tensor([1, 2, 3, 4.0])
    dataset = BaseDataset(inputs, targets, id=1)

    assert dataset.id == 1
    assert dataset.description is None
Example #7
def test_base_dataset_transform():

    inputs = th.tensor([1, 2, 3, 4.0])
    targets = th.tensor([1, 2, 3, 4.0])

    transform_dataset = BaseDataset(inputs, targets)

    def func(x):
        return x * 2

    transform_dataset.transform(func)

    expected_val = th.tensor([2, 4, 6, 8])
    transformed_val = [val[0].item() for val in transform_dataset]

    assert expected_val.equal(th.tensor(transformed_val).long())
Example #8
def test_federated_dataset(workers):
    bob = workers["bob"]
    alice = workers["alice"]

    alice_base_dataset = BaseDataset(th.tensor([3, 4, 5, 6]),
                                     th.tensor([3, 4, 5, 6]))
    datasets = [
        BaseDataset(th.tensor([1, 2]), th.tensor([1, 2])).send(bob),
        alice_base_dataset.send(alice),
    ]

    fed_dataset = sy.FederatedDataset(datasets)

    assert fed_dataset.workers == ["bob", "alice"]
    assert len(fed_dataset) == 6

    alice_remote_data = fed_dataset.get_dataset("alice")
    assert (alice_remote_data.data == alice_base_dataset.data).all()
    assert alice_remote_data[2] == (5, 5)
    assert len(alice_remote_data) == 4
    assert len(fed_dataset) == 2

    assert isinstance(fed_dataset.__str__(), str)
Example #9
def test_base_dataset(workers):

    bob = workers["bob"]
    inputs = th.tensor([1, 2, 3, 4.0])
    targets = th.tensor([1, 2, 3, 4.0])
    dataset = BaseDataset(inputs, targets)

    assert len(dataset) == 4
    assert dataset[2] == (3, 3)

    dataset.send(bob)
    assert dataset.data.location.id == "bob"
    assert dataset.targets.location.id == "bob"
    assert dataset.location.id == "bob"

    dataset.get()
    with pytest.raises(AttributeError):
        assert dataset.data.location.id == 0
    with pytest.raises(AttributeError):
        assert dataset.targets.location.id == 0
Example #10
def federate_dataset(dataset, workers, classnum, scheme,
                     class_per_worker=1, uniqueness_threshold=0, custom_mapping=None, use_pysyft=True):
    """Adapted from https://github.com/OpenMined/PySyft/blob/master/syft/frameworks/torch/fl/dataset.py
    dataset: PyTorch Dataset
    workers: List[Any], may contain PySyft VirtualWorkers
    classnum: int, number of classes
    scheme: one of 'naive', 'permuted', 'choose-unique', 'custom'
    custom_mapping: when scheme is 'custom', provide Dict[worker, List[int]]
    Invariant: shard_per_class * len(classes) == class_per_worker * len(workers)
    Example configurations:
    1) 1 worker to 1 *full* class, implies len(classes) == len(workers)  (1,1)
    2) 1 worker to 2 *full* classes, implies len(classes) == 2*len(workers)  (1,2)
    3) 1 worker to 1 class, a class has 2 shards, implies 2*len(classes) == len(workers)  (2,1)
    4) 1 worker to 1 class, a class has n shards, implies n*len(classes) == len(workers)  (n,1)
    5) 1 worker to 2 classes, a class has 2 shards, implies len(classes) == len(workers), but a worker
       should not hold the same 2 classes  (2,2)
    6) ... and so on
    TODO: this implementation requires len(classes) to be divisible by class_per_worker.

    Returns a PySyft FederatedDataset or a Dict {worker: [x: Tensor, y: Tensor]},
    plus a metadata dict keyed by worker.
    """
    
    assert (class_per_worker * len(workers)) % classnum == 0, \
           "class per worker (%d) * number of workers (%d) must be a multiple of number of classes (%d)" \
           % (class_per_worker, len(workers), classnum)
    assert classnum % class_per_worker == 0, \
           "Limitation: requires number of classes (%d) be a multiple of class per worker (%d)" \
           % (classnum, class_per_worker)
    
    shard_per_class = int(class_per_worker * len(workers) / classnum)
    meta_data = {}
    with torch.no_grad():
        # Prepare the shards
        gen = dataset.enumerate_by_class()  # yields, per class, a list of {'input': Tensor, 'target': int, 'output': Tensor}
        class_shards = defaultdict(list)
        class_index_map = []
        for data in gen:
            class_idx = data[0]['target']
            class_index_map.append(class_idx)
            x = torch.stack(list(map(lambda d: d['input'], data)), 0)
            y = torch.stack(list(map(lambda d: d['output'], data)), 0)
            shard_size = math.ceil(len(x) / shard_per_class)
            x = torch.split(x, shard_size)
            y = torch.split(y, shard_size) # tuple
            for shard in zip(x, y):
                class_shards[class_idx].append((shard[0], shard[1]))
        
        # Distribute and send shards
        class_indices = list(class_shards.keys())
        range_temp = None
        fed_datasets = {}
        
        for idx, worker in enumerate(workers):            
            # Determine classes for worker
            if scheme == 'naive':
                offset = math.floor(idx / shard_per_class) * class_per_worker
                class_indices_range = range(offset, offset+class_per_worker)
            elif scheme == 'permuted':
                if not range_temp:
                    range_temp = np.split(np.random.permutation(classnum), int(classnum / class_per_worker))
                class_indices_range = range_temp.pop()
            elif scheme == 'choose-unique':
                # NOTE: careful of upper bound for # workers = {classnum} C {class_per_worker}
                # NOTE: if parameters too stringent can loop forever... need mathematical upper bound?
                if not range_temp:
                    range_temp = {'history': [], 'buffer': []}
                if not range_temp['buffer']:
                    range_temp['buffer'] = list(range(classnum))
                    
                candidate = []
                while len(candidate) == 0 or \
                      set(candidate) in range_temp['history'] or \
                      any([len(set(candidate)&set(h)) > uniqueness_threshold for h in range_temp['history']]):
                    candidate = np.random.choice(range_temp['buffer'], class_per_worker, replace=False)
                    
                range_temp['buffer'] = [e for e in range_temp['buffer'] if e not in candidate]
                range_temp['history'].append(set(candidate))
                class_indices_range = candidate
            elif scheme == 'custom':
                assert custom_mapping is not None
                map_index = worker.id if isinstance(worker, BaseWorker) else worker
                assert map_index in custom_mapping, "Worker id %s not in custom mapping %s" % (str(map_index), str(list(custom_mapping.keys())))
                class_indices_range = [class_index_map.index(cls) for cls in custom_mapping[map_index]]
            else:
                raise ValueError('Bad scheme provided for federate_dataset')
        
            # Collect shards
            xs = []
            ys = []
            for class_index in class_indices_range:
                shard = class_shards[class_index_map[class_index]].pop()
                xs.append(shard[0])
                ys.append(shard[1])
            xs = torch.cat(xs, 0)
            ys = torch.cat(ys, 0)
            
            fed_datasets[worker] = [xs, ys]
            
            # Get some meta-data
            meta_data[worker] = {
                'xshape': list(fed_datasets[worker][0].shape),
                'yshape': list(fed_datasets[worker][1].shape),
                'yset': torch.unique(fed_datasets[worker][1], sorted=True).int().tolist()
            }
        # End looping over workers
    # End torch.no_grad

    # Send shards if PySyft
    if use_pysyft:
        for worker in fed_datasets:
            print("Sending data to worker %s ..." % worker.id)
            fed_datasets[worker][0] = fed_datasets[worker][0].send(worker)
            fed_datasets[worker][1] = fed_datasets[worker][1].send(worker)
        return FederatedDataset([BaseDataset(data[0], data[1]) for data in fed_datasets.values()]), meta_data
    else:
        return fed_datasets, meta_data
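
Below is a minimal, hypothetical usage sketch of federate_dataset with the 'naive' scheme and use_pysyft=False. The ToyClasswiseDataset class and its enumerate_by_class() method are illustrative assumptions that only mimic the per-class list of {'input', 'target', 'output'} dicts the function consumes; the commented module-level imports are the ones federate_dataset itself relies on (PySyft paths per version 0.2.x).

# Module-level imports assumed by federate_dataset above:
# import math
# import numpy as np
# import torch
# from collections import defaultdict
# from syft.workers.base import BaseWorker
# from syft.frameworks.torch.fl import BaseDataset, FederatedDataset

import torch

class ToyClasswiseDataset:
    """Hypothetical dataset exposing enumerate_by_class() in the shape expected above."""

    def __init__(self, xs, ys):
        self.xs, self.ys = xs, ys  # xs: (N, D) float tensor, ys: (N,) long tensor

    def enumerate_by_class(self):
        # Yield, per class, a list of {'input', 'target', 'output'} dicts.
        for cls in torch.unique(self.ys).tolist():
            idx = (self.ys == cls).nonzero(as_tuple=True)[0].tolist()
            yield [{"input": self.xs[i], "target": cls, "output": self.ys[i]} for i in idx]

xs = torch.randn(40, 8)
ys = torch.arange(40) % 4  # 4 classes, 10 samples each

# 2 workers * 2 classes per worker == 4 classes -> exactly one shard per class.
fed_datasets, meta = federate_dataset(ToyClasswiseDataset(xs, ys),
                                      workers=["w0", "w1"],
                                      classnum=4,
                                      scheme="naive",
                                      class_per_worker=2,
                                      use_pysyft=False)
print(meta)  # e.g. {'w0': {'xshape': [20, 8], 'yshape': [20], 'yset': [0, 1]}, ...}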