Esempio n. 1
0
    def accuracy_fed_avg(self, reports, factor=8):
        """Aggregate client updates, weighting each client by its accuracy.

        Each client's weight is ``accuracy ** factor`` divided by the sum of
        those powers over all reports, so a larger ``factor`` concentrates
        the average on the most accurate clients.

        Args:
            reports: Client reports. Each must expose ``accuracy``; the list
                is also handed to ``self.extract_client_updates`` to obtain
                the per-client ``(name, delta)`` updates.
            factor: Exponent applied to the accuracies before normalising.
                Defaults to 8, the value that used to be hard-coded.

        Returns:
            A list of ``(name, weight)`` pairs: the baseline weights of
            ``self.model`` plus the accuracy-weighted average update.

        Raises:
            IndexError: If ``reports`` yields no updates.
        """
        import fl_model  # pylint: disable=import-error

        # Extract updates from reports
        updates = self.extract_client_updates(reports)

        # Extract client accuracies
        accuracies = np.array([report.accuracy for report in reports])

        # Determine weighting based on accuracies (computed once, reused)
        powered = accuracies ** factor
        total = powered.sum()
        if total == 0:
            # Every accuracy is zero: dividing would give NaN weights that
            # end up in the model, so fall back to an unweighted mean.
            w = np.full(powered.shape, 1.0 / max(powered.size, 1))
        else:
            w = powered / total

        # Perform weighted averaging
        avg_update = [
            torch.zeros(x.size())  # pylint: disable=no-member
            for _, x in updates[0]
        ]
        for i, update in enumerate(updates):
            for j, (_, delta) in enumerate(update):
                # Use weighted average by accuracy of the reporting client
                avg_update[j] += delta * w[i]

        # Extract baseline model weights
        baseline_weights = fl_model.extract_weights(self.model)

        # Add the averaged update onto each baseline weight
        return [
            (name, weight + avg_update[i])
            for i, (name, weight) in enumerate(baseline_weights)
        ]
Esempio n. 2
0
    def federated_averaging(self, reports):
        """Aggregate client updates, weighting each by its sample count.

        Args:
            reports: Client reports. Each must expose ``num_samples``; the
                list is also handed to ``self.extract_client_updates`` to
                obtain the per-client ``(name, delta)`` updates.

        Returns:
            A list of ``(name, weight)`` pairs: the baseline weights of
            ``self.model`` plus the sample-weighted average update.
        """
        import fl_model  # pylint: disable=import-error

        # Per-client deltas relative to the current global model
        client_deltas = self.extract_client_updates(reports)

        # Number of samples summed over all reporting clients
        sample_total = sum(report.num_samples for report in reports)

        # One zero accumulator per entry of the first client's update
        accumulated = []
        for _, tensor in client_deltas[0]:
            accumulated.append(
                torch.zeros(tensor.size()))  # pylint: disable=no-member

        # Add each client's delta scaled by its share of the samples
        for client_idx, client_delta in enumerate(client_deltas):
            client_samples = reports[client_idx].num_samples
            for layer_idx, (_, delta) in enumerate(client_delta):
                accumulated[layer_idx] += delta * (
                    client_samples / sample_total)

        # Apply the averaged update on top of the baseline weights
        baseline = fl_model.extract_weights(self.model)
        return [
            (name, weight + accumulated[layer_idx])
            for layer_idx, (name, weight) in enumerate(baseline)
        ]
Esempio n. 3
0
    def magnetude_fed_avg(self, reports):
        """Aggregate client updates, weighting each by its update magnitude.

        A client's magnitude is the square root of the summed squared norms
        of all deltas in its update. Its weight is that magnitude divided by
        the sum of all clients' magnitudes.

        Args:
            reports: Client reports, handed to
                ``self.extract_client_updates`` to obtain the per-client
                ``(name, delta)`` updates.

        Returns:
            A list of ``(name, weight)`` pairs: the baseline weights of
            ``self.model`` plus the magnitude-weighted average update.

        Raises:
            IndexError: If ``reports`` yields no updates.
        """
        import fl_model  # pylint: disable=import-error

        # Extract updates from reports
        updates = self.extract_client_updates(reports)

        # Extract update magnitudes
        magnitudes = []
        for update in updates:
            squared = 0
            for _, delta in update:
                squared += delta.norm() ** 2
            magnitudes.append(np.sqrt(squared))

        # Normalise once here rather than re-summing the magnitudes for every
        # layer of every client inside the accumulation loop.
        total = sum(magnitudes)
        if total == 0:
            # All deltas are zero, so any coefficient gives a zero update;
            # using 0 avoids the 0/0 -> NaN that would corrupt the weights.
            coefficients = [0] * len(magnitudes)
        else:
            coefficients = [magnitude / total for magnitude in magnitudes]

        # Perform weighted averaging
        avg_update = [torch.zeros(x.size())  # pylint: disable=no-member
                      for _, x in updates[0]]
        for i, update in enumerate(updates):
            for j, (_, delta) in enumerate(update):
                # Use weighted average by magnitude of updates
                avg_update[j] += delta * coefficients[i]

        # Extract baseline model weights
        baseline_weights = fl_model.extract_weights(self.model)

        # Add the averaged update onto each baseline weight
        return [
            (name, weight + avg_update[i])
            for i, (name, weight) in enumerate(baseline_weights)
        ]
Esempio n. 4
0
    def extract_client_updates(self, reports):
        """Compute each client's weight deltas against the current model.

        Args:
            reports: Client reports. Each must expose ``weights``, an
                iterable of ``(name, weight)`` pairs in the same order as
                the baseline weights extracted from ``self.model``.

        Returns:
            One list per report, each holding ``(name, delta)`` pairs where
            ``delta`` is the client weight minus the baseline weight.

        Raises:
            AssertionError: If a client weight name does not match the
                baseline name at the same position.
        """
        import fl_model  # pylint: disable=import-error

        # Reference weights the clients started from
        baseline_weights = fl_model.extract_weights(self.model)

        # Gather every client's reported weights up front
        reported = [report.weights for report in reports]

        client_updates = []
        for client_weights in reported:
            deltas = []
            for idx, (name, client_weight) in enumerate(client_weights):
                bl_name, baseline = baseline_weights[idx]

                # Positions must line up between client and baseline
                assert name == bl_name

                deltas.append((name, client_weight - baseline))
            client_updates.append(deltas)

        return client_updates
Esempio n. 5
0
    def save_reports(self, round, reports):
        """Store flattened client and global weights for a round.

        Args:
            round: Round identifier, formatted into the keys
                ``'round<round>'`` and ``'w<round>'`` of
                ``self.saved_reports``.
            reports: Client reports exposing ``client_id`` and ``weights``.
                When falsy, no per-client entry is written.
        """
        import fl_model  # pylint: disable=import-error

        if reports:
            # (client id, flattened weights) for every reporting client
            per_client = []
            for report in reports:
                flat = self.flatten_weights(report.weights)
                per_client.append((report.client_id, flat))
            self.saved_reports['round{}'.format(round)] = per_client

        # The global model's weights are saved whether or not clients reported
        global_weights = fl_model.extract_weights(self.model)
        self.saved_reports['w{}'.format(round)] = self.flatten_weights(
            global_weights)
Esempio n. 6
0
    def profiling(self):
        """Train every client once and derive a direction profile for each.

        All clients are configured and trained in parallel threads. Their
        reports are aggregated into the global model via
        ``self.aggregation``. Each client's profile is the normalised
        difference between its flattened weights and the flattened model
        weights taken before aggregation.

        Side effects: sets ``self.w_previous`` (copy of the pre-aggregation
        weights), ``self.punishment`` (all zeros), ``self.profiles``, and
        loads the aggregated weights into ``self.model``.

        Returns:
            A list of ``(client, director)`` pairs, also stored as
            ``self.profiles``.
        """
        import fl_model  # pylint: disable=import-error

        # Use all clients for profiling
        clients = self.clients

        # Configure clients for training
        self.configuration(clients)

        # Train on clients to generate profile weights. Plain loops are used
        # because start()/join() are called only for their side effects.
        threads = [Thread(target=client.train) for client in clients]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

        # Receive client reports
        reports = self.reporting(clients)

        # Extract flattened weights from reports
        weights = [self.flatten_weights(report.weights) for report in reports]

        # Extract initial model weights (before aggregation changes them)
        w0 = self.flatten_weights(fl_model.extract_weights(self.model))

        # Save as initial previous model weights
        self.w_previous = w0.copy()

        # Update initial model using results of profiling
        # Perform weight aggregation
        logging.info('Aggregating updates')
        updated_weights = self.aggregation(reports)

        # Load updated weights
        fl_model.load_weights(self.model, updated_weights)

        # Calculate direction vectors (directors), normalised to unit length
        directors = []
        for w in weights:
            direction = w - w0
            norm = np.sqrt(np.dot(direction, direction))
            # A client whose weights equal the baseline has a zero-length
            # direction; keep the zero vector instead of dividing by zero,
            # which would give a NaN profile.
            directors.append(direction / norm if norm > 0 else direction)

        # Initialize punishment factors
        self.punishment = [0] * len(clients)

        # Use directors for client profiles
        self.profiles = [(client, directors[i])
                         for i, client in enumerate(clients)]
        return self.profiles
Esempio n. 7
0
    def selection(self):
        """Pick the clients whose profile best matches the model's movement.

        The model direction is the normalised difference between the current
        flattened weights of ``self.model`` and ``self.w_previous``. Each
        client's score is the dot product of its director (from
        ``self.profiles``) with that direction, scaled by
        ``0.9 ** punishment``. The ``self.config.clients.per_round``
        highest-scoring clients are chosen; ties go to the lower index.

        Side effects: ``self.w_previous`` becomes the current weights, and
        ``self.punishment`` is incremented for selected clients and reset
        to 0 for the rest.

        Returns:
            The selected clients, best score first. At most
            ``len(self.profiles)`` clients are returned and none is
            repeated.
        """
        import fl_model  # pylint: disable=import-error

        clients = self.clients
        clients_per_round = self.config.clients.per_round
        w_previous = self.w_previous

        # Extract directors from profiles
        directors = [d for _, d in self.profiles]

        # Extract most recent model weights
        w_current = self.flatten_weights(fl_model.extract_weights(self.model))
        model_direction = w_current - w_previous

        # Normalize model direction. If the model has not moved, the norm is
        # zero; dividing would make every score NaN and the ranking
        # meaningless, so the zero vector is kept (all scores become 0).
        norm = np.sqrt(np.dot(model_direction, model_direction))
        if norm > 0:
            model_direction = model_direction / norm

        # Update previous model weights
        self.w_previous = w_current

        # Generate client director scores (closer direction is better)
        scores = [np.dot(director, model_direction) for director in directors]

        # Apply punishment for repeatedly selected clients
        # NOTE(review): for a negative score this factor moves the value
        # toward zero, i.e. raises its rank -- confirm this is intended.
        p = self.punishment
        scores = [x * (0.9)**p[i] for i, x in enumerate(scores)]

        # Select clients with highest scores. One stable descending sort
        # replaces the repeated max/index/min scans and cannot pick the same
        # index twice; equal scores keep ascending index order.
        ranked = sorted(range(len(scores)), key=scores.__getitem__,
                        reverse=True)
        sample_clients_index = ranked[:clients_per_round]

        # Extract selected sample clients
        sample_clients = [clients[i] for i in sample_clients_index]

        # Update punishment factors
        selected = set(sample_clients_index)
        self.punishment = [
            p[i] + 1 if i in selected else 0
            for i in range(len(clients))
        ]

        return sample_clients
Esempio n. 8
0
    def train(self):
        """Train the local model and build the report for the server.

        Trains ``self.model`` on ``self.trainset``, then stores a new
        ``Report`` in ``self.report`` with the extracted model weights.
        When ``self.do_test`` is truthy, the model is also evaluated on
        ``self.testset`` and the result is stored as the report's
        ``accuracy``; otherwise ``accuracy`` is not set here.
        """
        import fl_model  # pylint: disable=import-error

        logging.info('Training on client #{}'.format(self.client_id))

        # Run local training
        loader = fl_model.get_trainloader(self.trainset, self.batch_size)
        fl_model.train(self.model, loader, self.optimizer, self.epochs)

        # Snapshot the trained weights and biases
        trained_weights = fl_model.extract_weights(self.model)

        # Build the report that the server will collect
        self.report = Report(self)
        self.report.weights = trained_weights

        # Testing is optional
        if not self.do_test:
            return

        # 1000 is passed as get_testloader's second argument
        # (presumably the batch size -- confirm against fl_model)
        eval_loader = fl_model.get_testloader(self.testset, 1000)
        self.report.accuracy = fl_model.test(self.model, eval_loader)