def load_data(self) -> None:
    """Generate data and load it onto this client."""
    logging.info("[%s] Loading its data source...", self)

    if self.datasource is None or (hasattr(Config().data, 'reload_data')
                                   and Config().data.reload_data):
        self.datasource = datasources_registry.get(client_id=self.client_id)

    self.data_loaded = True

    logging.info("[%s] Dataset size: %s", self,
                 self.datasource.num_train_examples())

    # Setting up the data sampler
    self.sampler = samplers_registry.get(self.datasource, self.client_id)

    if hasattr(Config().trainer, 'use_mindspore'):
        # MindSpore requires samplers to be used while constructing
        # the dataset
        self.trainset = self.datasource.get_train_set(self.sampler)
    else:
        # PyTorch uses samplers when loading data with a data loader
        self.trainset = self.datasource.get_train_set()

    if Config().clients.do_test:
        # Set the testset if local testing is needed
        self.testset = self.datasource.get_test_set()

        if hasattr(Config().data, 'testset_sampler'):
            # Set the sampler for the test set
            self.testset_sampler = samplers_registry.get(self.datasource,
                                                         self.client_id,
                                                         testing=True)
def load_data(self) -> None:
    """Generate data and load it onto this client."""
    data_loading_start_time = time.time()
    logging.info("[Client #%d] Loading its data source...", self.client_id)

    self.data_loaded = True

    logging.info("[Client #%d] Dataset size: %s", self.client_id,
                 self.datasource.num_train_examples())

    # Setting up the data sampler
    self.sampler = samplers_registry.get(self.datasource, self.client_id)

    # Setting up the modality sampler
    self.modality_sampler = samplers_registry.multimodal_get(
        datasource=self.datasource, client_id=self.client_id)

    # PyTorch uses samplers when loading data with a data loader
    self.trainset = self.datasource.get_train_set(
        self.modality_sampler.get())

    self.valset = self.datasource.get_val_set()

    if Config().clients.do_test:
        # Set the testset if local testing is needed
        self.testset = self.datasource.get_test_set()

    self.data_loading_time = time.time() - data_loading_start_time
def configure(self):
    """
    Boot the federated learning server by setting up the data and model,
    and creating the clients.
    """
    if Config().is_edge_server():
        logging.info("Configuring edge server #%d as a %s server.",
                     Config().args.id, Config().algorithm.type)
        logging.info("Training with %s local aggregation rounds.",
                     Config().algorithm.local_rounds)

        self.load_trainer()
        self.trainer.set_client_id(Config().args.id)

        # Prepare the processors that handle this server's outbound and
        # inbound data payloads
        self.outbound_processor, self.inbound_processor = processor_registry.get(
            "Server", server_id=os.getpid(), trainer=self.trainer)

        if hasattr(Config().server,
                   'edge_do_test') and Config().server.edge_do_test:
            self.datasource = datasources_registry.get(
                client_id=Config().args.id)
            self.testset = self.datasource.get_test_set()

            if hasattr(Config().data, 'edge_testset_sampler'):
                # Set the sampler for the test set
                self.testset_sampler = samplers_registry.get(
                    self.datasource, Config().args.id, testing='edge')
            elif hasattr(Config().data, 'testset_size'):
                # Set a random sampler for the test set
                from torch.utils.data import SubsetRandomSampler

                all_inclusive = range(len(self.datasource.get_test_set()))
                test_samples = random.sample(all_inclusive,
                                             Config().data.testset_size)
                self.testset_sampler = SubsetRandomSampler(test_samples)

        if hasattr(Config(), 'results'):
            result_dir = Config().params['result_dir']
            result_csv_file = f'{result_dir}/edge_{os.getpid()}.csv'
            csv_processor.initialize_csv(result_csv_file,
                                         self.recorded_items, result_dir)
    else:
        super().configure()

        if hasattr(Config().server, 'do_test') and Config().server.do_test:
            self.datasource = datasources_registry.get(client_id=0)
            self.testset = self.datasource.get_test_set()

            if hasattr(Config().data, 'testset_size'):
                from torch.utils.data import SubsetRandomSampler

                # Set a random sampler for the test set
                all_inclusive = range(len(self.datasource.get_test_set()))
                test_samples = random.sample(all_inclusive,
                                             Config().data.testset_size)
                self.testset_sampler = SubsetRandomSampler(test_samples)
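# The sketch below exercises, in isolation, the random test-subset pattern used
# in configure(): draw `testset_size` indices with random.sample and wrap them
# in a SubsetRandomSampler. It is an illustrative, self-contained example, not
# part of the framework; the synthetic TensorDataset stands in for the test set
# normally returned by datasource.get_test_set().
import random

import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset

# A synthetic test set: 1000 examples with 10 features and binary labels.
testset = TensorDataset(torch.randn(1000, 10), torch.randint(0, 2, (1000,)))

# Sample a fixed number of indices uniformly at random, as configure() does
# with Config().data.testset_size, and let the sampler restrict the loader.
testset_size = 100
test_samples = random.sample(range(len(testset)), testset_size)
testset_sampler = SubsetRandomSampler(test_samples)

test_loader = DataLoader(testset, batch_size=32, sampler=testset_sampler)

# The loader now iterates over exactly testset_size examples.
assert sum(labels.size(0) for _, labels in test_loader) == testset_size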
def eval_test(self):
    """Evaluate the model on the local test set and return its accuracy."""
    if not self.testset_loaded:
        self.datasource = datasources_registry.get(client_id=self.client_id)
        self.testset = self.datasource.get_test_set()

        if hasattr(Config().data, 'testset_sampler'):
            # Set the sampler for the test set
            self.testset_sampler = samplers_registry.get(self.datasource,
                                                         self.client_id,
                                                         testing=True)
        self.testset_loaded = True

    # Move the model to the evaluation device before running inference
    self.model.to(self.device)
    self.model.eval()

    # Initialize the accuracy to be returned to -1, so that the client can
    # disconnect from the server when testing fails
    accuracy = -1

    try:
        if self.testset_sampler is None:
            test_loader = torch.utils.data.DataLoader(
                self.testset,
                batch_size=Config().trainer.batch_size,
                shuffle=False)
        else:
            # Use a testing set following the same distribution as the
            # training set
            test_loader = torch.utils.data.DataLoader(
                self.testset,
                batch_size=Config().trainer.batch_size,
                shuffle=False,
                sampler=self.testset_sampler)

        correct = 0
        total = 0

        with torch.no_grad():
            for examples, labels in test_loader:
                examples, labels = examples.to(self.device), labels.to(
                    self.device)

                outputs = self.model(examples)

                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        accuracy = correct / total
    except Exception as testing_exception:
        logging.info("Testing on client #%d failed.", self.client_id)
        raise testing_exception

    self.model.cpu()

    return accuracy
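# For reference, the accuracy computation inside eval_test() can be reproduced
# on its own. Everything below is an illustrative sketch: a toy linear model
# and a synthetic dataset replace self.model and the registry-provided test
# set, and only the torch.no_grad() / torch.max pattern mirrors the code above.
import torch
from torch.utils.data import DataLoader, TensorDataset

model = torch.nn.Linear(10, 2)
testset = TensorDataset(torch.randn(256, 10), torch.randint(0, 2, (256,)))
test_loader = DataLoader(testset, batch_size=32, shuffle=False)

model.eval()
correct, total = 0, 0

with torch.no_grad():
    for examples, labels in test_loader:
        outputs = model(examples)

        # The predicted class is the index of the largest logit
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = correct / total
print(f"Accuracy on the synthetic test set: {accuracy:.2%}")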