Example #1
File: job_manager.py  Project: xl3783/GFL
    def generate_job(self,
                     work_mode=WorkModeStrategy.WORKMODE_STANDALONE,
                     fed_strategy=FederateStrategy.FED_AVG,
                     epoch=0,
                     model=None,
                     distillation_alpha=None,
                     l2_dist=False):
        """
        Generate job with user-defined strategy
        :param work_mode:
        :param train_strategy:
        :param fed_strategy:
        :param model:
        :param distillation_alpha:
        :return: job object
        """
        with lock:
            # Job arguments: server_host, job_id, train_model (source path), train_model_class_name, fed_strategy, epoch, distillation_alpha, l2_dist
            if fed_strategy == FederateStrategy.FED_DISTILLATION and distillation_alpha is None:
                raise PFLException(
                    "generate_job() missing 1 required argument: 'distillation_alpha'"
                )
            if epoch == 0:
                raise PFLException(
                    "generate_job() missing 1 required argument: 'epoch'")

            job = Job(None,
                      JobUtils.generate_job_id(),
                      inspect.getsourcefile(model),
                      model.__name__,
                      fed_strategy,
                      epoch,
                      distillation_alpha=distillation_alpha,
                      l2_dist=l2_dist)

            if work_mode == WorkModeStrategy.WORKMODE_STANDALONE:
                job.set_server_host("localhost:8080")
            else:
                job.set_server_host("")

            return job
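For context, a minimal call sketch (the JobManager instance and MyModel class are illustrative assumptions, not from the source). Note that model is passed as a class, not an instance, since generate_job records its source file via inspect.getsourcefile and its name via __name__:

    import torch.nn as nn

    class MyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(10, 2)

        def forward(self, x):
            return self.fc(x)

    job_manager = JobManager()
    job = job_manager.generate_job(work_mode=WorkModeStrategy.WORKMODE_STANDALONE,
                                   fed_strategy=FederateStrategy.FED_AVG,
                                   epoch=10,
                                   model=MyModel)
    # In standalone mode the job's server host is set to "localhost:8080".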
Example #2
    def _generate_new_scheduler(self, model, scheduler):
        # Collect the class names of all supported schedulers from the enum.
        scheduler_names = [item.value for item in SchedulerStrategy.__members__.values()]
        if scheduler.__class__.__name__ not in scheduler_names:
            raise PFLException("scheduler got wrong type value")
        # Rebuild the underlying optimizer against the new model's parameters,
        # then recreate the scheduler with the same hyperparameters.
        optimizer = scheduler.optimizer
        params = scheduler.state_dict()
        new_optimizer = self._generate_new_optimizer(model, optimizer)
        # A freshly constructed scheduler must start from last_epoch=-1.
        last_epoch = -1 if params['last_epoch'] == 0 else params['last_epoch']
        if isinstance(scheduler, torch.optim.lr_scheduler.CyclicLR):
            return torch.optim.lr_scheduler.CyclicLR(
                new_optimizer, base_lr=params['base_lrs'], max_lr=params['max_lrs'],
                step_size_up=params['total_size'] * params['step_ratio'],
                step_size_down=params['total_size'] - params['total_size'] * params['step_ratio'],
                mode=params['mode'], gamma=params['gamma'],
                scale_fn=params['scale_fn'], scale_mode=params['scale_mode'],
                cycle_momentum=params['cycle_momentum'],
                base_momentum=params['base_momentums'],
                max_momentum=params['max_momentums'],
                last_epoch=last_epoch)
        elif isinstance(scheduler, torch.optim.lr_scheduler.CosineAnnealingLR):
            return torch.optim.lr_scheduler.CosineAnnealingLR(
                new_optimizer, T_max=params['T_max'], eta_min=params['eta_min'],
                last_epoch=last_epoch)
        elif isinstance(scheduler, torch.optim.lr_scheduler.ExponentialLR):
            return torch.optim.lr_scheduler.ExponentialLR(
                new_optimizer, gamma=params['gamma'], last_epoch=last_epoch)
        elif isinstance(scheduler, torch.optim.lr_scheduler.LambdaLR):
            return torch.optim.lr_scheduler.LambdaLR(
                new_optimizer, lr_lambda=params['lr_lambdas'], last_epoch=last_epoch)
        elif isinstance(scheduler, torch.optim.lr_scheduler.MultiStepLR):
            return torch.optim.lr_scheduler.MultiStepLR(
                new_optimizer, milestones=params['milestones'], gamma=params['gamma'],
                last_epoch=last_epoch)
        elif isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            # ReduceLROnPlateau is not an epoch-indexed scheduler and takes no last_epoch.
            return torch.optim.lr_scheduler.ReduceLROnPlateau(
                new_optimizer, mode=params['mode'], factor=params['factor'],
                patience=params['patience'], verbose=params['verbose'],
                threshold=params['threshold'], threshold_mode=params['threshold_mode'],
                cooldown=params['cooldown'], min_lr=params['min_lrs'], eps=params['eps'])
        elif isinstance(scheduler, torch.optim.lr_scheduler.StepLR):
            return torch.optim.lr_scheduler.StepLR(
                new_optimizer, step_size=params['step_size'], gamma=params['gamma'],
                last_epoch=last_epoch)
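To see what this rebuilding amounts to for one branch, here is a standalone sketch with plain PyTorch, mirroring the StepLR case above (model and variable names are illustrative):

    import torch

    model_a = torch.nn.Linear(4, 2)
    model_b = torch.nn.Linear(4, 2)

    opt_a = torch.optim.SGD(model_a.parameters(), lr=0.1, momentum=0.9)
    sched_a = torch.optim.lr_scheduler.StepLR(opt_a, step_size=5, gamma=0.5)

    # Recreate an equivalent scheduler bound to model_b's parameters.
    params = sched_a.state_dict()
    opt_b = torch.optim.SGD(model_b.parameters(), lr=0.1, momentum=0.9)
    sched_b = torch.optim.lr_scheduler.StepLR(
        opt_b, step_size=params['step_size'], gamma=params['gamma'],
        last_epoch=-1 if params['last_epoch'] == 0 else params['last_epoch'])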
Example #3
    def _generate_new_optimizer(self, model, optimizer):
        # Validate the type before touching the optimizer's state.
        if not isinstance(optimizer, torch.optim.Optimizer):
            raise PFLException("optimizer got wrong type value")
        # Copy the hyperparameters of the first param group and rebind
        # the optimizer to the new model's parameters.
        optimizer_class = optimizer.__class__
        params = optimizer.state_dict()['param_groups'][0]

        if isinstance(optimizer, torch.optim.SGD):
            return optimizer_class(model.parameters(), lr=params['lr'], momentum=params['momentum'],
                                   dampening=params['dampening'], weight_decay=params['weight_decay'],
                                   nesterov=params['nesterov'])
        else:
            # Assumes an Adam-style optimizer (betas/eps/amsgrad hyperparameters).
            return optimizer_class(model.parameters(), lr=params['lr'], betas=params['betas'],
                                   eps=params['eps'], weight_decay=params['weight_decay'],
                                   amsgrad=params['amsgrad'])
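The same rebinding idea for the SGD branch, as a standalone sketch (names are illustrative):

    import torch

    src_model = torch.nn.Linear(4, 2)
    dst_model = torch.nn.Linear(4, 2)

    src_opt = torch.optim.SGD(src_model.parameters(), lr=0.01, momentum=0.9)
    hp = src_opt.state_dict()['param_groups'][0]

    # Same hyperparameters, but parameters taken from the destination model.
    dst_opt = torch.optim.SGD(dst_model.parameters(), lr=hp['lr'],
                              momentum=hp['momentum'], dampening=hp['dampening'],
                              weight_decay=hp['weight_decay'],
                              nesterov=hp['nesterov'])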
Example #4
    def start(self):
        if self.work_mode == WorkModeStrategy.WORKMODE_STANDALONE:
            self._trainer_standalone_exec()
        else:
            # Register this client with the parameter server:
            # POST <server_url>/register/<client_ip>/<client_port>/<client_id>
            response = requests.post("/".join([
                self.server_url, "register", self.client_ip,
                '%s' % self.client_port,
                '%s' % self.client_id
            ]))
            response_json = response.json()
            if response_json['code'] == 200 or response_json['code'] == 201:
                self.trainer_executor_pool.submit(
                    communicate_client.start_communicate_client,
                    self.client_ip, self.client_port)
                self._trainer_mpc_exec()
            else:
                raise PFLException(
                    "failed to connect to parameter server, please check your network connection"
                )
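For reference, the registration URL assembled by the non-standalone branch has this shape (the concrete values are illustrative):

    server_url = "http://localhost:9090"
    client_ip, client_port, client_id = "192.168.1.5", 8081, 0
    url = "/".join([server_url, "register", client_ip,
                    '%s' % client_port, '%s' % client_id])
    # -> "http://localhost:9090/register/192.168.1.5/8081/0"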
Example #5
File: strategy.py  Project: xl3783/GFL
    def set_optimizer(self, optimizer):
        optim_strategies = self.get_optim_strategies()
        if optimizer in optim_strategies:
            self.optimizer = optimizer
        else:
            raise PFLException("optimizer strategy not found")
Example #6
File: strategy.py  Project: xl3783/GFL
    def set_loss_function(self, loss_function):
        loss_functions = self.get_loss_functions()
        if loss_function in loss_functions:
            # Store the enum's underlying value rather than the member itself.
            self.loss_function = loss_function.value
        else:
            raise PFLException("loss strategy not found")