Exemplo n.º 1
0
        def on_request_update(*args):
            """Train the local model for one round on received weights and reply.

            Decodes the weights carried by the request, loads them into
            ``self.local_model``, trains one round, and emits a
            ``'client_update'`` event with the new weights and training
            metrics (plus validation metrics when ``run_validation`` is set).

            Args:
                *args: Event payload; ``args[0]`` is the request dict with keys
                    ``'model_id'``, ``'round_number'``, ``'current_weights'``,
                    ``'weights_format'`` and ``'run_validation'``.

            Raises:
                ValueError: If ``req['weights_format']`` is not ``'pickle'``,
                    the only format this handler can decode.
            """
            req = args[0]
            print("update requested")

            # Previously an unsupported format fell through and crashed later
            # with an UnboundLocalError on ``weights``; fail explicitly instead.
            if req['weights_format'] != 'pickle':
                raise ValueError(
                    "unsupported weights_format: %r" % (req['weights_format'],))
            # NOTE(review): unpickling data received over the network can run
            # arbitrary code; only connect this client to a trusted server.
            weights = pickle_string_to_obj(req['current_weights'])

            self.local_model.set_weights(weights)
            my_weights, train_loss, train_accuracy = self.local_model.train_one_round()
            resp = {
                'round_number': req['round_number'],
                'weights': obj_to_pickle_string(my_weights),
                # NOTE(review): the dataset sizes are hard-coded placeholders;
                # the real values (x_train / x_valid ``shape[0]``) are disabled.
                'train_size': 10,  # self.local_model.x_train.shape[0],
                'valid_size': 0,  # self.local_model.x_valid.shape[0],
                'train_loss': train_loss,
                'train_accuracy': train_accuracy,
            }
            if req['run_validation']:
                valid_loss, valid_accuracy = self.local_model.validate()
                resp['valid_loss'] = valid_loss
                resp['valid_accuracy'] = valid_accuracy

            self.sio.emit('client_update', resp)
Exemplo n.º 2
0
        def on_request_update(*args):  # handler for the request_update event
            """Train the local model for one round on received weights and reply.

            Optionally sleeps first (when ``FederatedClient.SLEEP_MODE`` is
            set), then decodes the weights carried by the request, loads them
            into ``self.local_model``, trains one round, and emits a
            ``'client_update'`` event with the new weights, dataset sizes and
            training metrics (plus validation metrics when requested).

            Args:
                *args: Event payload; ``args[0]`` is the request dict with keys
                    ``'model_id'``, ``'round_number'``, ``'current_weights'``,
                    ``'weights_format'`` and ``'run_validation'``.

            Raises:
                ValueError: If ``req['weights_format']`` is not ``'pickle'``,
                    the only format this handler can decode.
            """
            req = args[0]  # request data, including the current model weights
            print("update requested")
            if FederatedClient.SLEEP_MODE:
                # Presumably simulates an intermittently available client.
                self.intermittently_sleep(p=1., low=10,
                                          high=100)  # random sleep

            # Previously an unsupported format fell through and crashed later
            # with an UnboundLocalError on ``weights``; fail explicitly instead.
            if req['weights_format'] != 'pickle':
                raise ValueError(
                    "unsupported weights_format: %r" % (req['weights_format'],))
            # Decode the model weights.
            # NOTE(review): unpickling data received over the network can run
            # arbitrary code; only connect this client to a trusted server.
            weights = pickle_string_to_obj(req['current_weights'])

            self.local_model.set_weights(weights)  # load the received weights
            # Train one round; returns new weights and local training metrics.
            my_weights, train_loss, train_accuracy = self.local_model.train_one_round(
            )
            resp = {
                'round_number': req['round_number'],
                'weights': obj_to_pickle_string(my_weights),
                'train_size': self.local_model.x_train.shape[0],
                'valid_size': self.local_model.x_valid.shape[0],
                'train_loss': train_loss,
                'train_accuracy': train_accuracy,
            }
            if req['run_validation']:  # whether to evaluate on the validation set
                valid_loss, valid_accuracy = self.local_model.validate()
                resp['valid_loss'] = valid_loss
                resp['valid_accuracy'] = valid_accuracy

            # Reply with a client_update event carrying the required data.
            self.sio.emit('client_update', resp)
Exemplo n.º 3
0
        def on_request_update(*args):
            """Train the local model for one round on received weights and reply.

            Decodes the weights carried by the request, loads them into
            ``self.local_model``, trains one round (passing the round number
            and ``self.cid``), and emits a ``'client_update'`` event with the
            new weights, dataset sizes and training metrics.

            Unlike the other variants, this one does not inspect
            ``req['weights_format']``: the payload is always decoded with
            ``pickle_string_to_obj``. It also never runs validation, even when
            ``req['run_validation']`` is set.

            Args:
                *args: Event payload; ``args[0]`` is the request dict with keys
                    ``'model_id'``, ``'round_number'``, ``'current_weights'``,
                    ``'weights_format'`` and ``'run_validation'``.
            """
            req = args[0]
            print("update requested")

            # NOTE(review): unpickling data received over the network can run
            # arbitrary code; only connect this client to a trusted server.
            weights = pickle_string_to_obj(req['current_weights'])
            round_number = req['round_number']

            self.local_model.set_weights(weights)
            # This model variant's train_one_round takes the round number and
            # the client id.
            my_weights, train_loss, train_accuracy = self.local_model.train_one_round(
                round_number, self.cid)
            resp = {
                'round_number': round_number,
                'weights': obj_to_pickle_string(my_weights),
                'train_size': self.local_model.x_train.shape[0],
                'valid_size': self.local_model.x_valid.shape[0],
                'train_loss': train_loss,
                'train_accuracy': train_accuracy,
            }
            # NOTE: validation is disabled in this variant. To re-enable it, call
            # self.local_model.validate() when req['run_validation'] is truthy
            # and add 'valid_loss' / 'valid_accuracy' to resp.

            self.sio.emit('client_update', resp)
        def on_request_update(*args):
            """Train the local model for one round on received weights and reply,
            logging timing information along the way.

            Appends the request-start and emit-start timestamps to
            ``self.fo_name`` and the training duration to
            ``self.f_training_name``. It also prints the request-start,
            emit-start and emit-finish times relative to ``self.time_start``.
            Otherwise it behaves like the plain handler: decode the weights,
            load them into ``self.local_model``, train one round, and emit a
            ``'client_update'`` event with the new weights and metrics (plus
            validation metrics when requested).

            Args:
                *args: Event payload; ``args[0]`` is the request dict with keys
                    ``'model_id'``, ``'round_number'``, ``'current_weights'``,
                    ``'weights_format'`` and ``'run_validation'``.

            Raises:
                ValueError: If ``req['weights_format']`` is not ``'pickle'``,
                    the only format this handler can decode.
            """

            def _append_line(path, text):
                """Append ``text`` followed by a newline to the file at ``path``."""
                with open(path, 'a') as log_file:
                    log_file.write(text + "\n")

            time_start_request_update = time.time()
            _append_line(self.fo_name, "time_start_request_update:    " +
                         str(time_start_request_update))
            print(
                "------------------------------------------------time_start_request_update: ",
                time_start_request_update - self.time_start)

            req = args[0]
            print("update requested")

            # Previously an unsupported format fell through and crashed later
            # with an UnboundLocalError on ``weights``; fail explicitly instead.
            if req['weights_format'] != 'pickle':
                raise ValueError(
                    "unsupported weights_format: %r" % (req['weights_format'],))
            # NOTE(review): unpickling data received over the network can run
            # arbitrary code; only connect this client to a trusted server.
            weights = pickle_string_to_obj(req['current_weights'])

            self.local_model.set_weights(weights)

            time_start_training = time.time()
            my_weights, train_loss, train_accuracy = self.local_model.train_one_round(
            )
            time_end_training = time.time()

            _append_line(self.f_training_name, "time_training:    " +
                         str(time_end_training - time_start_training))

            resp = {
                'round_number': req['round_number'],
                'weights': obj_to_pickle_string(my_weights),
                'train_size': self.local_model.x_train.shape[0],
                'valid_size': self.local_model.x_valid.shape[0],
                'train_loss': train_loss,
                'train_accuracy': train_accuracy,
            }
            if req['run_validation']:
                valid_loss, valid_accuracy = self.local_model.validate()
                resp['valid_loss'] = valid_loss
                resp['valid_accuracy'] = valid_accuracy

            time_start_emit = time.time()
            print(
                "------------------------------------------------time_start_emit: ",
                time_start_emit - self.time_start)

            _append_line(self.fo_name, "time_start_emit:    " + str(time_start_emit))

            self.sio.emit('client_update', resp)

            time_finish_emit = time.time()
            # NOTE(review): the finish time is only printed, not written to
            # self.fo_name (the original file write here was commented out).
            print(
                "------------------------------------------------time_finish_emit: ",
                time_finish_emit - self.time_start)