import math
from collections import defaultdict

import numpy as np
import pytest
import torch
import torch as th

import syft as sy
from syft.frameworks.torch.fl import BaseDataset, FederatedDataset
from syft.workers.base import BaseWorker


def test_federated_dataset_search(workers):
    bob = workers["bob"]
    alice = workers["alice"]
    grid = sy.PrivateGridNetwork(*[bob, alice])

    # Tagged tensors live on the workers and are discoverable via grid search.
    train_bob = th.zeros(1000, 100).tag("data").send(bob)
    target_bob = th.zeros(1000, 100).tag("target").send(bob)
    train_alice = th.zeros(1000, 100).tag("data").send(alice)
    target_alice = th.zeros(1000, 100).tag("target").send(alice)

    data = grid.search("data")
    target = grid.search("target")

    datasets = [
        BaseDataset(data["bob"][0], target["bob"][0]),
        BaseDataset(data["alice"][0], target["alice"][0]),
    ]

    fed_dataset = sy.FederatedDataset(datasets)
    train_loader = sy.FederatedDataLoader(fed_dataset, batch_size=4, shuffle=False, drop_last=False)

    counter = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        counter += 1

    assert counter == len(train_loader), f"{counter} == {len(train_loader)}"
def test_illegal_get(workers):
    """Test that a ValueError is raised when calling .get() on a dataset that is
    part of a FederatedDataset."""
    bob = workers["bob"]
    alice = workers["alice"]

    alice_base_dataset = BaseDataset(th.tensor([3, 4, 5, 6]), th.tensor([3, 4, 5, 6]))
    datasets = [
        BaseDataset(th.tensor([1, 2]), th.tensor([1, 2])).send(bob),
        alice_base_dataset.send(alice),
    ]
    fed_dataset = sy.FederatedDataset(datasets)

    with pytest.raises(ValueError):
        fed_dataset["alice"].get()
def test_get_dataset(workers):
    bob = workers["bob"]
    alice = workers["alice"]

    alice_base_dataset = BaseDataset(th.tensor([3, 4, 5, 6]), th.tensor([3, 4, 5, 6]))
    datasets = [
        BaseDataset(th.tensor([1, 2]), th.tensor([1, 2])).send(bob),
        alice_base_dataset.send(alice),
    ]
    fed_dataset = sy.FederatedDataset(datasets)
    dataset = fed_dataset.get_dataset("alice")

    # get_dataset returns alice's 4-element dataset and removes it from the
    # federation, leaving only bob's 2-element dataset behind.
    assert len(fed_dataset) == 2
    assert len(dataset) == 4
def test_abstract_dataset():
    inputs = th.tensor([1, 2, 3, 4.0])
    targets = th.tensor([1, 2, 3, 4.0])
    dataset = BaseDataset(inputs, targets, id=1)

    assert dataset.id == 1
    assert dataset.description is None
def test_base_dataset(workers):
    bob = workers["bob"]
    inputs = th.tensor([1, 2, 3, 4.0])
    targets = th.tensor([1, 2, 3, 4.0])
    dataset = BaseDataset(inputs, targets)

    assert len(dataset) == 4
    assert dataset[2] == (3, 3)

    dataset = dataset.send(bob)
    assert dataset.data.location.id == "bob"
    assert dataset.targets.location.id == "bob"
    assert dataset.location.id == "bob"
def test_dataset_to_federate(workers):
    bob = workers["bob"]
    alice = workers["alice"]

    dataset = BaseDataset(th.tensor([1.0, 2, 3, 4, 5, 6]), th.tensor([1.0, 2, 3, 4, 5, 6]))
    fed_dataset = dataset.federate((bob, alice))

    assert isinstance(fed_dataset, sy.FederatedDataset)
    assert fed_dataset.workers == ["bob", "alice"]
    assert fed_dataset["bob"].location.id == "bob"
    assert len(fed_dataset) == 6
def test_federated_dataset(workers):
    bob = workers["bob"]
    alice = workers["alice"]

    alice_base_dataset = BaseDataset(th.tensor([3, 4, 5, 6]), th.tensor([3, 4, 5, 6]))
    datasets = [
        BaseDataset(th.tensor([1, 2]), th.tensor([1, 2])).send(bob),
        alice_base_dataset.send(alice),
    ]
    fed_dataset = sy.FederatedDataset(datasets)

    assert fed_dataset.workers == ["bob", "alice"]
    assert len(fed_dataset) == 6

    alice_remote_data = fed_dataset.get_dataset("alice")
    assert (alice_remote_data.data == alice_base_dataset.data).all()
    assert alice_remote_data[2] == (5, 5)
    assert len(alice_remote_data) == 4
    # get_dataset removed alice's dataset, so only bob's 2 elements remain.
    assert len(fed_dataset) == 2

    assert isinstance(fed_dataset.__str__(), str)
def test_base_dataset_transform():
    inputs = th.tensor([1, 2, 3, 4.0])
    targets = th.tensor([1, 2, 3, 4.0])
    transform_dataset = BaseDataset(inputs, targets)

    def func(x):
        return x * 2

    transform_dataset.transform(func)

    expected_val = th.tensor([2, 4, 6, 8])
    transformed_val = [val[0].item() for val in transform_dataset]
    assert expected_val.equal(th.tensor(transformed_val).long())
def test_base_dataset_send_get(workers):
    bob = workers["bob"]
    inputs = th.tensor([1, 2, 3, 4.0])
    targets = th.tensor([1, 2, 3, 4.0])
    dataset = BaseDataset(inputs, targets)

    assert len(dataset) == 4
    assert dataset[2] == (3, 3)

    dataset.send(bob)
    assert dataset.data.location.id == "bob"
    assert dataset.targets.location.id == "bob"
    assert dataset.location.id == "bob"

    dataset.get()
    # After .get(), data and targets are local tensors again and have no .location.
    with pytest.raises(AttributeError):
        assert dataset.data.location.id == 0
    with pytest.raises(AttributeError):
        assert dataset.targets.location.id == 0
def federate_dataset(dataset, workers, classnum, scheme, class_per_worker=1,
                     uniqueness_threshold=0, custom_mapping=None, use_pysyft=True):
    """Split a dataset into class-based shards and distribute them across workers.

    Adapted from https://github.com/OpenMined/PySyft/blob/master/syft/frameworks/torch/fl/dataset.py

    dataset: a PyTorch Dataset exposing enumerate_by_class()
    workers: List[Any]; entries may be PySyft VirtualWorkers or plain worker ids
    classnum: int, number of classes
    scheme: one of 'naive', 'permuted', 'choose-unique', 'custom'
    custom_mapping: when scheme is 'custom', provide Dict[worker, List[int]]

    Invariant: shard_per_class * classnum == class_per_worker * len(workers)

    Example configurations (shard_per_class, class_per_worker):
    1) 1 worker to 1 *full* class, implies classnum == len(workers)        (1, 1)
    2) 1 worker to 2 *full* classes, implies classnum == 2 * len(workers)  (1, 2)
    3) 1 worker to 1 class, a class has 2 shards, implies 2 * classnum == len(workers)  (2, 1)
    4) 1 worker to 1 class, a class has n shards, implies n * classnum == len(workers)  (n, 1)
    5) 1 worker to 2 classes, a class has 2 shards, implies classnum == len(workers),
       but a worker should not receive the same 2 classes                  (2, 2)
    6) ... and so on

    TODO: this implementation requires classnum to be divisible by class_per_worker.

    Returns (FederatedDataset, meta_data) when use_pysyft is True, otherwise
    (Dict {worker: [x: Tensor, y: Tensor]}, meta_data).
    """
    assert (class_per_worker * len(workers)) % classnum == 0, \
        "class per worker (%d) * number of workers (%d) must be a multiple of number of classes (%d)" \
        % (class_per_worker, len(workers), classnum)
    assert classnum % class_per_worker == 0, \
        "Limitation: requires number of classes (%d) be a multiple of class per worker (%d)" \
        % (classnum, class_per_worker)
    shard_per_class = int(class_per_worker * len(workers) / classnum)
    meta_data = {}

    with torch.no_grad():
        # Prepare the shards
        gen = dataset.enumerate_by_class()  # input: tensor, target: int, output: tensor
        class_shards = defaultdict(list)
        class_index_map = []
        for data in gen:
            class_idx = data[0]['target']
            class_index_map.append(class_idx)
            x = torch.stack(list(map(lambda d: d['input'], data)), 0)
            y = torch.stack(list(map(lambda d: d['output'], data)), 0)
            shard_size = math.ceil(len(x) / shard_per_class)
            x = torch.split(x, shard_size)
            y = torch.split(y, shard_size)  # tuple
            for shard in zip(x, y):
                class_shards[class_idx].append((shard[0], shard[1]))

        # Distribute shards across workers
        range_temp = None
        fed_datasets = {}
        for idx, worker in enumerate(workers):
            # Determine the classes for this worker
            if scheme == 'naive':
                offset = math.floor(idx / shard_per_class) * class_per_worker
                class_indices_range = range(offset, offset + class_per_worker)
            elif scheme == 'permuted':
                if not range_temp:
                    range_temp = np.split(np.random.permutation(classnum),
                                          int(classnum / class_per_worker))
                class_indices_range = range_temp.pop()
            elif scheme == 'choose-unique':
                # NOTE: careful of the upper bound on the number of workers:
                # C(classnum, class_per_worker).
                # NOTE: if parameters are too stringent this can loop forever...
                # needs a mathematical upper bound?
                if not range_temp:
                    range_temp = {'history': [], 'buffer': []}
                if not range_temp['buffer']:
                    range_temp['buffer'] = list(range(classnum))
                candidate = []
                while len(candidate) == 0 or \
                        set(candidate) in range_temp['history'] or \
                        any([len(set(candidate) & set(h)) > uniqueness_threshold
                             for h in range_temp['history']]):
                    candidate = np.random.choice(range_temp['buffer'], class_per_worker, replace=False)
                range_temp['buffer'] = [e for e in range_temp['buffer'] if e not in candidate]
                range_temp['history'].append(set(candidate))
                class_indices_range = candidate
            elif scheme == 'custom':
                assert custom_mapping is not None
                map_index = worker.id if isinstance(worker, BaseWorker) else worker
                assert map_index in custom_mapping, \
                    "Worker id %s not in custom mapping %s" \
                    % (str(map_index), str(list(custom_mapping.keys())))
                class_indices_range = [class_index_map.index(cls) for cls in custom_mapping[map_index]]
            else:
                raise ValueError('Bad scheme provided for federate_dataset')

            # Collect this worker's shards
            xs = []
            ys = []
            for class_index in class_indices_range:
                shard = class_shards[class_index_map[class_index]].pop()
                xs.append(shard[0])
                ys.append(shard[1])
            xs = torch.cat(xs, 0)
            ys = torch.cat(ys, 0)
            fed_datasets[worker] = [xs, ys]

            # Gather some meta-data
            meta_data[worker] = {
                'xshape': list(fed_datasets[worker][0].shape),
                'yshape': list(fed_datasets[worker][1].shape),
                'yset': torch.unique(fed_datasets[worker][1], sorted=True).int().tolist(),
            }
        # End of loop over workers
    # End of torch.no_grad()

    # Send the shards if using PySyft
    if use_pysyft:
        for worker in fed_datasets:
            print("Sending data to worker %s ..." % worker.id)
            fed_datasets[worker][0] = fed_datasets[worker][0].send(worker)
            fed_datasets[worker][1] = fed_datasets[worker][1].send(worker)
        return FederatedDataset([BaseDataset(data[0], data[1]) for data in fed_datasets.values()]), meta_data
    else:
        return fed_datasets, meta_data
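# --- Usage sketch (illustrative only) -----------------------------------------
# A minimal, self-contained example of calling federate_dataset with the
# 'custom' scheme. `ToyClassDataset` and `_demo` are hypothetical helpers, not
# part of the original code: ToyClassDataset supplies the enumerate_by_class()
# generator that federate_dataset consumes, yielding one list of
# {'input', 'target', 'output'} dicts per class.

class ToyClassDataset:
    """Hypothetical dataset: `classnum` classes with `per_class` samples each."""

    def __init__(self, classnum=4, per_class=6, dim=8):
        self.classnum, self.per_class, self.dim = classnum, per_class, dim

    def enumerate_by_class(self):
        for cls in range(self.classnum):
            yield [
                {"input": torch.randn(self.dim), "target": cls, "output": torch.tensor(cls)}
                for _ in range(self.per_class)
            ]


def _demo():
    hook = sy.TorchHook(torch)
    bob = sy.VirtualWorker(hook, id="bob")
    alice = sy.VirtualWorker(hook, id="alice")

    # 4 classes, 2 workers, 2 classes per worker -> shard_per_class == 1, so each
    # worker receives its 2 mapped classes whole.
    fed_dataset, meta = federate_dataset(
        ToyClassDataset(classnum=4),
        workers=[bob, alice],
        classnum=4,
        scheme="custom",
        class_per_worker=2,
        custom_mapping={"bob": [0, 1], "alice": [2, 3]},
    )
    print(meta[bob]["yset"])    # expected: [0, 1]
    print(meta[alice]["yset"])  # expected: [2, 3]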