# Example #1
# 0
class FLServer(object):
    """Coordinates federated-learning rounds.

    Collects per-round client updates, aggregates them into the global
    model (attention-based or baseline), tracks a non-improvement counter
    for early stopping, and broadcasts evaluation requests.

    NOTE(review): `stop_and_eval` calls `self.emit`, which is not defined
    in this class — presumably provided by a socket/transport mixin.
    Confirm against the rest of the project.
    """

    # Minimum number of connected workers required before training starts.
    MIN_NUM_WORKERS = 4
    # Hard cap on the number of federated rounds.
    MAX_NUM_ROUNDS = 36
    # Number of clients contacted (asked for updates) each round.
    NUM_CLIENTS_CONTACTED_PER_ROUND = 4
    # Run a validation pass every N rounds.
    ROUNDS_BETWEEN_VALIDATIONS = 1

    def __init__(self, global_model, aggregation="normal_atten"):
        """
        Args:
            global_model: zero-argument factory/class producing the global
                model object used for weight aggregation and metrics.
            aggregation: "normal_atten" (uniform attention), "atten"
                (learned attention), "rule_out" (binarized attention), or
                any other string treated as a baseline method identifier.
        """
        self.global_model = global_model()
        self.ready_client_sids = set()

        # Per-client resource info reported by clients (not used in this chunk).
        self.client_resource = {}
        self.wait_time = 0

        # Unique id so clients can detect a stale or mismatched model.
        self.model_id = str(uuid.uuid4())

        self.aggregation = aggregation
        self.attention_mechanism = Attention()

        # Training state: -1 means training has not started yet.
        self.current_round = -1
        self.current_round_client_updates = []
        self.eval_client_updates = []

        # Consecutive rounds in which aggregated train loss failed to improve.
        self.invalid_tolerate = 0

    def handle_client_update(self, data):
        """Aggregate one round of client updates into the global model.

        Args:
            data: list of dicts, each with keys 'weights', 'train_size',
                'train_loss', 'train_accuracy', 'attack_mode', and
                'assigned_label'.
        """
        self.current_round_client_updates = data
        uploaded_weights = [
            x['weights'] for x in self.current_round_client_updates
        ]
        # Hoisted: used by both aggregation paths and the metrics call below.
        train_sizes = [
            x['train_size'] for x in self.current_round_client_updates
        ]

        if self.aggregation in ("normal_atten", "atten", "rule_out"):
            if self.aggregation == "normal_atten":
                # Uniform attention: every client weighted equally.
                print("### Update with normal attention mechanism! ###")
                attention = np.ones(len(uploaded_weights))
            else:
                print("### Update with calculated attention mechanism! ###")
                attention = self.attention_mechanism.cal_weights(
                    np.array(uploaded_weights))
                print("old attention", attention)

                if self.aggregation == "rule_out":
                    # Binarize: keep only clients whose attention is
                    # strictly above the mean, drop the rest entirely.
                    attention = (attention > attention.mean()).astype(float)
                    print("new attention", attention)

            attack_label = [
                "{}_{}".format(x['attack_mode'], x['assigned_label'])
                for x in self.current_round_client_updates
            ]
            self.global_model.update_weights_with_attention(
                uploaded_weights, train_sizes, attention, attack_label)
        else:
            print("### Update with baseline methods! ###")
            self.global_model.update_weights_baseline(
                uploaded_weights, train_sizes, self.aggregation)

        aggr_train_loss, aggr_train_accuracy = self.global_model.aggregate_train_loss_accuracy(
            [x['train_loss'] for x in self.current_round_client_updates],
            [x['train_accuracy'] for x in self.current_round_client_updates],
            train_sizes, self.current_round)

        # Count consecutive non-improving rounds (early-stop signal).
        if (self.global_model.prev_train_loss is not None
                and self.global_model.prev_train_loss < aggr_train_loss):
            self.invalid_tolerate += 1
        else:
            self.invalid_tolerate = 0
        self.global_model.prev_train_loss = aggr_train_loss

    def handle_client_eval(self, data):
        """Aggregate client test metrics; no-op if evaluation already ran.

        Args:
            data: list of dicts with keys 'test_loss', 'test_accuracy',
                and 'test_size'.
        """
        if self.eval_client_updates is None:
            # Sentinel set at the end of this method: evaluation already
            # happened once, so ignore late or duplicate eval data.
            return

        self.eval_client_updates = data

        aggr_test_loss, aggr_test_accuracy = self.global_model.aggregate_loss_accuracy(
            [x['test_loss'] for x in self.eval_client_updates],
            [x['test_accuracy'] for x in self.eval_client_updates],
            [x['test_size'] for x in self.eval_client_updates],
        )
        print("\n--------Aggregating test loss---------\n")
        print("aggr_test_loss", aggr_test_loss)
        print("aggr_test_accuracy", aggr_test_accuracy)
        print("best model at round ", self.global_model.best_round,
              ", get the best loss ", self.global_model.best_loss)
        print("== done ==")
        self.eval_client_updates = None  # special value, forbid evaling again

    # Note: we assume that during training the #workers will be >= MIN_NUM_WORKERS
    def train_next_round(self, clients):
        """Advance to the next round and build the broadcast payload.

        Args:
            clients: accepted for interface compatibility; not used here.

        Returns:
            dict with 'model_id', 'round_number', 'current_weights',
            'weights_format', and 'run_validation' (True every
            ROUNDS_BETWEEN_VALIDATIONS rounds).
        """
        self.current_round += 1
        # Reset the buffer that will collect this round's client updates.
        self.current_round_client_updates = []

        print("\n ### Round ", self.current_round, "### \n")

        return {
            'model_id': self.model_id,
            'round_number': self.current_round,
            'current_weights': self.global_model.current_weights,
            'weights_format': 'not pickle',
            'run_validation':
                self.current_round % FLServer.ROUNDS_BETWEEN_VALIDATIONS == 0,
        }

    def stop_and_eval(self):
        """Ask every ready client to evaluate the best global weights."""
        self.eval_client_updates = []
        for rid in self.ready_client_sids:
            # NOTE(review): `self.emit` is not defined in this class —
            # presumably inherited from a socket layer; confirm.
            self.emit(
                'stop_and_eval',
                {
                    'model_id': self.model_id,
                    'current_weights': self.global_model.best_weight,
                    'weights_format': 'not pickle'
                },
                room=rid)