예제 #1
0
    def load_data(self) -> None:
        """Generating data and loading them onto this client.

        Fetches the datasource when it has not been created yet, or when the
        configuration asks for a reload via `data.reload_data`. Then builds
        the training sampler and train set and, when local testing is
        enabled, the test set (plus its sampler when configured).
        """
        logging.info("[%s] Loading its data source...", self)

        # getattr with a default collapses the hasattr-and-value check:
        # reload only when `data.reload_data` exists and is truthy.
        if self.datasource is None or getattr(Config().data, 'reload_data',
                                              False):
            self.datasource = datasources_registry.get(
                client_id=self.client_id)

        self.data_loaded = True

        logging.info("[%s] Dataset size: %s", self,
                     self.datasource.num_train_examples())

        # Setting up the data sampler
        self.sampler = samplers_registry.get(self.datasource, self.client_id)

        if hasattr(Config().trainer, 'use_mindspore'):
            # MindSpore requires samplers to be used while constructing
            # the dataset
            self.trainset = self.datasource.get_train_set(self.sampler)
        else:
            # PyTorch uses samplers when loading data with a data loader
            self.trainset = self.datasource.get_train_set()

        if Config().clients.do_test:
            # Set the testset if local testing is needed
            self.testset = self.datasource.get_test_set()
            if hasattr(Config().data, 'testset_sampler'):
                # Set the sampler for test set
                self.testset_sampler = samplers_registry.get(self.datasource,
                                                             self.client_id,
                                                             testing=True)
예제 #2
0
    def load_data(self) -> None:
        """Generating data and loading them onto this client."""
        started_at = time.time()
        logging.info("[Client #%d] Loading its data source...", self.client_id)

        self.data_loaded = True

        logging.info("[Client #%d] Dataset size: %s", self.client_id,
                     self.datasource.num_train_examples())

        # Sampler over the training examples assigned to this client
        self.sampler = samplers_registry.get(self.datasource, self.client_id)

        # Sampler that selects which data modalities this client uses
        self.modality_sampler = samplers_registry.multimodal_get(
            datasource=self.datasource, client_id=self.client_id)

        # PyTorch applies samplers in the data loader, so the train set is
        # built from the modalities selected above
        self.trainset = self.datasource.get_train_set(
            self.modality_sampler.get())

        self.valset = self.datasource.get_val_set()

        if Config().clients.do_test:
            # Only materialize the test set when local testing is enabled
            self.testset = self.datasource.get_test_set()

        self.data_loading_time = time.time() - started_at
예제 #3
0
    def configure(self):
        """
        Booting the federated learning server by setting up the data, model, and
        creating the clients.

        Edge servers additionally prepare their own trainer, payload
        processors, and (when configured) a local test set; the central
        server delegates to the parent class before setting up its test set.
        """
        if Config().is_edge_server():
            logging.info("Configuring edge server #%d as a %s server.",
                         Config().args.id,
                         Config().algorithm.type)
            logging.info("Training with %s local aggregation rounds.",
                         Config().algorithm.local_rounds)

            self.load_trainer()
            self.trainer.set_client_id(Config().args.id)

            # Prepares this server for processors that processes outbound and inbound
            # data payloads
            self.outbound_processor, self.inbound_processor = processor_registry.get(
                "Server", server_id=os.getpid(), trainer=self.trainer)

            if hasattr(Config().server,
                       'edge_do_test') and Config().server.edge_do_test:
                self.datasource = datasources_registry.get(
                    client_id=Config().args.id)
                self.testset = self.datasource.get_test_set()

                if hasattr(Config().data, 'edge_testset_sampler'):
                    # Set the sampler for test set
                    self.testset_sampler = samplers_registry.get(
                        self.datasource, Config().args.id, testing='edge')
                elif hasattr(Config().data, 'testset_size'):
                    # Sample a random subset of the configured size
                    self.testset_sampler = self._random_testset_sampler()

            if hasattr(Config(), 'results'):
                result_dir = Config().params['result_dir']
                result_csv_file = f'{result_dir}/edge_{os.getpid()}.csv'
                csv_processor.initialize_csv(result_csv_file,
                                             self.recorded_items, result_dir)

        else:
            super().configure()
            if hasattr(Config().server, 'do_test') and Config().server.do_test:
                self.datasource = datasources_registry.get(client_id=0)
                self.testset = self.datasource.get_test_set()

                if hasattr(Config().data, 'testset_size'):
                    # Sample a random subset of the configured size
                    self.testset_sampler = self._random_testset_sampler()

    def _random_testset_sampler(self):
        """Build a SubsetRandomSampler over `data.testset_size` randomly
        chosen test examples; `self.testset` must already be set."""
        from torch.utils.data import SubsetRandomSampler

        # Reuse the test set already stored on this server instead of
        # fetching it from the datasource a second time.
        all_inclusive = range(len(self.testset))
        test_samples = random.sample(all_inclusive,
                                     Config().data.testset_size)
        return SubsetRandomSampler(test_samples)
예제 #4
0
    def eval_test(self):
        """ Test if needs to update pruning mask and conduct pruning. """
        if not self.testset_loaded:
            # Lazily fetch the datasource and test set on first use
            self.datasource = datasources_registry.get(
                client_id=self.client_id)
            self.testset = self.datasource.get_test_set()
            if hasattr(Config().data, 'testset_sampler'):
                # A dedicated sampler for the test set, when configured
                self.testset_sampler = samplers_registry.get(self.datasource,
                                                             self.client_id,
                                                             testing=True)
            self.testset_loaded = True

        self.model.eval()

        # Initialize accuracy to be returned to -1, so that the client can disconnect
        # from the server when testing fails
        accuracy = -1

        try:
            # Build the loader arguments once; the sampler (which follows the
            # same distribution as the training set) is only attached when one
            # has been configured.
            loader_args = {
                'batch_size': Config().trainer.batch_size,
                'shuffle': False,
            }
            if self.testset_sampler is not None:
                loader_args['sampler'] = self.testset_sampler
            test_loader = torch.utils.data.DataLoader(self.testset,
                                                      **loader_args)

            num_correct = 0
            num_seen = 0

            with torch.no_grad():
                for batch, targets in test_loader:
                    batch = batch.to(self.device)
                    targets = targets.to(self.device)

                    logits = self.model(batch)

                    _, predictions = torch.max(logits.data, 1)
                    num_seen += targets.size(0)
                    num_correct += (predictions == targets).sum().item()

            accuracy = num_correct / num_seen

        except Exception as testing_exception:
            # Log and re-raise so the caller sees the original failure
            logging.info("Testing on client #%d failed.", self.client_id)
            raise testing_exception

        self.model.cpu()

        return accuracy