Пример #1
0
    def fit(self, ins: flwr.FitIns) -> flwr.FitRes:
        """Train the local model on this client's dataset and report back.

        ``ins[0]`` holds the global parameters, which are loaded into
        ``self.model``; ``ins[1]`` is the config mapping, which must provide
        ``epochs``, ``batch_size``, ``timeout`` and ``partial_updates``.

        Returns a ``(parameters, num_examples, num_examples_ceil)`` tuple,
        where ``num_examples`` is the count reported by ``custom_fit`` and
        ``num_examples_ceil`` is ``self.num_examples_train * epochs``. If
        ``custom_fit`` reports the run as not completed and partial updates
        are disabled, the returned parameters are empty.
        """
        global_weights: flwr.Weights = flwr.parameters_to_weights(ins[0])
        fit_config = ins[1]
        log(
            DEBUG,
            "fit on %s (examples: %s), config %s",
            self.cid,
            self.num_examples_train,
            fit_config,
        )

        # Hyper-parameters supplied by the server for this round
        num_epochs = int(fit_config["epochs"])
        batch_size = int(fit_config["batch_size"])
        timeout = int(fit_config["timeout"])
        # Encoded as an integer flag ("0"/"1" or 0/1) in the config
        allow_partial = bool(int(fit_config["partial_updates"]))

        # Start local training from the received global model
        self.model.set_weights(global_weights)

        completed, fit_duration, num_examples = custom_fit(
            model=self.model,
            dataset=self.ds_train,
            num_epochs=num_epochs,
            batch_size=batch_size,
            callbacks=[],
            delay_factor=self.delay_factor,
            timeout=timeout,
        )
        log(DEBUG, "client %s had fit_duration %s", self.cid, fit_duration)

        # Upper bound on processed examples: every local example, every epoch
        num_examples_ceil = self.num_examples_train * num_epochs

        # Weights are only shared if training finished or partial results are
        # accepted; otherwise an empty update is sent back.
        share_weights = completed or allow_partial
        parameters = flwr.weights_to_parameters(
            self.model.get_weights() if share_weights else []
        )
        return parameters, num_examples, num_examples_ceil
Пример #2
0
    def test_fit(self):
        """Exercise ``GrpcClientProxy.fit`` against ``self.bridge_mock``.

        The expected values below (empty weights, 10 examples) presumably
        come from the mock bridge's canned reply rather than from the sent
        parameters. This test is currently quite simple and should be
        improved.
        """
        # Prepare
        proxy = GrpcClientProxy(cid="1", bridge=self.bridge_mock)
        fit_ins: flower.FitIns = (
            flower.weights_to_parameters([np.ones((2, 2))]),
            {},
        )

        # Execute
        returned_parameters, returned_num_examples, _ = proxy.fit(ins=fit_ins)

        # Assert
        assert returned_parameters.tensor_type == "np"
        assert flower.parameters_to_weights(returned_parameters) == []
        assert returned_num_examples == 10
Пример #3
0
    def fit(self, ins: fl.FitIns) -> fl.FitRes:
        """Train the local model, starting from the server-provided weights.

        ``ins[0]`` holds the parameters to load into ``self.model``;
        ``ins[1]`` is the config mapping, which must provide ``epochs`` and
        ``batch_size``. The model is trained on ``self.x_train`` /
        ``self.y_train``.

        Returns the updated parameters followed by ``len(self.x_train)``,
        which is repeated in both count positions of the result tuple.
        """
        server_weights: fl.Weights = fl.parameters_to_weights(ins[0])
        train_config = ins[1]

        # Get training configuration
        epochs = int(train_config["epochs"])
        batch_size = int(train_config["batch_size"])

        # Load the received weights, then refine them on the local data
        self.model.set_weights(server_weights)
        self.model.fit(
            self.x_train,
            self.y_train,
            epochs=epochs,
            batch_size=batch_size,
            verbose=2,
        )

        updated_parameters = fl.weights_to_parameters(self.model.get_weights())
        local_count = len(self.x_train)
        return updated_parameters, local_count, local_count
Пример #4
0
 def get_parameters(self) -> flwr.ParametersRes:
     """Return the local model's current weights wrapped in ``ParametersRes``."""
     return flwr.ParametersRes(
         parameters=flwr.weights_to_parameters(self.model.get_weights())
     )